[Distributed] implement adhoc requirements properly for Encoder

This commit is contained in:
Konrad `ktoso` Malawski
2022-02-14 16:14:41 +09:00
parent 0fec1b31da
commit 005743c92c
11 changed files with 596 additions and 147 deletions

View File

@@ -82,22 +82,7 @@ Type swift::getDistributedActorSystemType(NominalTypeDecl *actor) {
Type swift::getDistributedActorIDType(NominalTypeDecl *actor) {
auto &C = actor->getASTContext();
return C.getAssociatedTypeOfDistributedSystem(actor, C.Id_ActorID);
}
Type swift::getDistributedActorSystemSerializationRequirementType(NominalTypeDecl *system) {
assert(!system->isDistributedActor());
auto &ctx = system->getASTContext();
auto protocol = ctx.getProtocol(KnownProtocolKind::DistributedActorSystem);
if (!protocol)
return Type();
// Dig out the serialization requirement type.
auto module = system->getParentModule();
Type selfType = system->getSelfInterfaceType();
auto conformance = module->lookupConformance(selfType, protocol);
return conformance.getTypeWitnessByName(selfType, ctx.Id_SerializationRequirement);
return C.getAssociatedTypeOfDistributedSystemOfActor(actor, C.Id_ActorID);
}
Type swift::getDistributedActorSystemActorIDRequirementType(NominalTypeDecl *system) {
@@ -115,8 +100,24 @@ Type swift::getDistributedActorSystemActorIDRequirementType(NominalTypeDecl *sys
return conformance.getTypeWitnessByName(selfType, ctx.Id_ActorID);
}
Type ASTContext::getAssociatedTypeOfDistributedSystem(NominalTypeDecl *actor,
Identifier member) {
Type swift::getDistributedSerializationRequirementType(
NominalTypeDecl *nominal, ProtocolDecl *protocol) {
assert(!nominal->isDistributedActor());
assert(protocol);
auto &ctx = nominal->getASTContext();
// Dig out the serialization requirement type.
auto module = nominal->getParentModule();
Type selfType = nominal->getSelfInterfaceType();
auto conformance = module->lookupConformance(selfType, protocol);
if (conformance.isInvalid())
return Type();
return conformance.getTypeWitnessByName(selfType, ctx.Id_SerializationRequirement);
}
Type ASTContext::getAssociatedTypeOfDistributedSystemOfActor(
NominalTypeDecl *actor, Identifier member) {
auto &ctx = actor->getASTContext();
auto actorProtocol = ctx.getProtocol(KnownProtocolKind::DistributedActor);
@@ -148,11 +149,11 @@ Type ASTContext::getAssociatedTypeOfDistributedSystem(NominalTypeDecl *actor,
actorProtocol, selfType, conformance));
}
Type ASTContext::getDistributedSerializationRequirementType(
NominalTypeDecl *nominal) {
return getAssociatedTypeOfDistributedSystem(nominal,
Id_SerializationRequirement);
}
//Type ASTContext::getDistributedSerializationRequirementType(
// NominalTypeDecl *nominal) {
// return getAssociatedTypeOfDistributedSystemOfActor(nominal,
// Id_SerializationRequirement);
//}
FuncDecl*
ASTContext::getDistributedActorArgumentDecodingMethod(NominalTypeDecl *actor) {
@@ -174,10 +175,11 @@ ASTContext::getDistributedActorInvocationDecoder(NominalTypeDecl *actor) {
bool
swift::getDistributedActorSystemSerializationRequirements(
NominalTypeDecl *systemNominal,
NominalTypeDecl *nominal,
ProtocolDecl *protocol,
llvm::SmallPtrSetImpl<ProtocolDecl *> &requirementProtos) {
auto existentialRequirementTy =
getDistributedActorSystemSerializationRequirementType(systemNominal);
getDistributedSerializationRequirementType(nominal, protocol);
if (existentialRequirementTy->hasError()) {
fprintf(stderr, "[%s:%d] (%s) if (SerializationRequirementTy->hasError())\n", __FILE__, __LINE__, __FUNCTION__);
return false;
@@ -265,7 +267,6 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
return false;
}
// === Must be declared in a 'DistributedActorSystem' conforming type
ProtocolDecl *systemProto =
C.getProtocol(KnownProtocolKind::DistributedActorSystem);
@@ -310,7 +311,7 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
// === Get the SerializationRequirement
SmallPtrSet<ProtocolDecl*, 2> requirementProtos;
if (!getDistributedActorSystemSerializationRequirements(
systemNominal, requirementProtos)) {
systemNominal, systemProto, requirementProtos)) {
return false;
}
@@ -473,6 +474,8 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
return false;
}
}
} else {
return false;
}
}
@@ -491,6 +494,345 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
return true;
}
bool
AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const {
auto &C = getASTContext();
auto module = getParentModule();
// === Check base name
if (getBaseIdentifier() != C.Id_recordArgument)
return false;
// === Must be declared in a 'DistributedTargetInvocationEncoder' conforming type
ProtocolDecl *encoderProto =
C.getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
auto encoderNominal = getDeclContext()->getSelfNominalTypeDecl();
auto protocolConformance = module->lookupConformance(
encoderNominal->getDeclaredInterfaceType(), encoderProto);
if (protocolConformance.isInvalid()) {
return false;
}
// === Check modifiers
// --- must not be async
if (hasAsync()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- must be throwing
if (!hasThrows()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Check generics
if (!isGeneric()) {
fprintf(stderr, "[%s:%d] (%s) if (!isGeneric())\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- Check number of generic parameters
auto genericParams = getGenericParams();
unsigned int expectedGenericParamNum = 1;
if (genericParams->size() != expectedGenericParamNum) {
fprintf(stderr, "[%s:%d] (%s) if (genericParams->size() != expectedGenericParamNum)\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Get the SerializationRequirement
SmallPtrSet<ProtocolDecl*, 2> requirementProtos;
if (!getDistributedActorSystemSerializationRequirements(
encoderNominal, encoderProto, requirementProtos)) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// -- Check number of generic requirements
size_t serializationRequirementsNum = requirementProtos.size();
size_t expectedRequirementsNum = serializationRequirementsNum;
// === Check all parameters
auto params = getParameters();
if (params->size() != 1) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- Check parameter: _ argument
auto argumentParam = params->get(0);
if (!argumentParam->getArgumentName().is("_")) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Check generic parameters in detail
// --- Check: Argument: SerializationRequirement
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
auto sig = getGenericSignature();
auto requirements = sig.getRequirements();
if (requirements.size() != expectedRequirementsNum) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- Check the expected requirements
// --- all the Argument requirements ---
// conforms_to: Argument Decodable
// conforms_to: Argument Encodable
// ...
auto func = dyn_cast<FuncDecl>(this);
if (!func) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType())
->getMetatypeInstanceType()
->getDesugaredType();
auto resultParamType = func->mapTypeIntoContext(
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
// The result of the function must be the `Res` generic argument.
if (!resultType->isEqual(resultParamType)) {
return false;
}
for (auto requirementProto : requirementProtos) {
auto conformance = module->lookupConformance(resultType, requirementProto);
if (conformance.isInvalid()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
}
// === Check result type: Void
if (!func->getResultInterfaceType()->isVoid()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
return true;
}
bool
AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordReturnType() const {
auto &C = getASTContext();
auto module = getParentModule();
// === Check base name
if (getBaseIdentifier() != C.Id_recordReturnType)
return false;
// === Must be declared in a 'DistributedTargetInvocationEncoder' conforming type
ProtocolDecl *encoderProto =
C.getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
auto encoderNominal = getDeclContext()->getSelfNominalTypeDecl();
auto protocolConformance = module->lookupConformance(
encoderNominal->getDeclaredInterfaceType(), encoderProto);
if (protocolConformance.isInvalid()) {
return false;
}
// === Check modifiers
// --- must not be async
if (hasAsync()) {
return false;
}
// --- must be throwing
if (!hasThrows()) {
return false;
}
// === Check generics
if (!isGeneric()) {
fprintf(stderr, "[%s:%d] (%s) if (!isGeneric())\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- Check number of generic parameters
auto genericParams = getGenericParams();
unsigned int expectedGenericParamNum = 1;
if (genericParams->size() != expectedGenericParamNum) {
fprintf(stderr, "[%s:%d] (%s) if (genericParams->size() != expectedGenericParamNum)\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Get the SerializationRequirement
SmallPtrSet<ProtocolDecl*, 2> requirementProtos;
if (!getDistributedActorSystemSerializationRequirements(
encoderNominal, encoderProto, requirementProtos)) {
return false;
}
// -- Check number of generic requirements
size_t serializationRequirementsNum = requirementProtos.size();
size_t expectedRequirementsNum = serializationRequirementsNum;
// === Check all parameters
auto params = getParameters();
if (params->size() != 1) {
return false;
}
// --- Check parameter: _ argument
auto argumentParam = params->get(0);
if (!argumentParam->getArgumentName().is("_")) {
return false;
}
// === Check generic parameters in detail
// --- Check: Argument: SerializationRequirement
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
auto sig = getGenericSignature();
auto requirements = sig.getRequirements();
if (requirements.size() != expectedRequirementsNum) {
return false;
}
// --- Check the expected requirements
// --- all the Argument requirements ---
// conforms_to: Argument Decodable
// conforms_to: Argument Encodable
// ...
auto func = dyn_cast<FuncDecl>(this);
if (!func)
return false;
auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType())
->getMetatypeInstanceType()
->getDesugaredType();
auto resultParamType = func->mapTypeIntoContext(
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
// The result of the function must be the `Res` generic argument.
if (!resultType->isEqual(resultParamType)) {
return false;
}
for (auto requirementProto : requirementProtos) {
auto conformance = module->lookupConformance(resultType, requirementProto);
if (conformance.isInvalid()) {
return false;
}
}
// === Check result type: Void
if (!func->getResultInterfaceType()->isVoid()) {
return false;
}
return true;
}
bool
AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() const {
auto &C = getASTContext();
auto module = getParentModule();
// === Check base name
if (getBaseIdentifier() != C.Id_recordReturnType)
return false;
// === Must be declared in a 'DistributedTargetInvocationEncoder' conforming type
ProtocolDecl *encoderProto =
C.getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
auto encoderNominal = getDeclContext()->getSelfNominalTypeDecl();
auto protocolConformance = module->lookupConformance(
encoderNominal->getDeclaredInterfaceType(), encoderProto);
if (protocolConformance.isInvalid()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Check modifiers
// --- must not be async
if (hasAsync()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- must be throwing
if (!hasThrows()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Check generics
if (!isGeneric()) {
fprintf(stderr, "[%s:%d] (%s) if (!isGeneric())\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- Check number of generic parameters
auto genericParams = getGenericParams();
unsigned int expectedGenericParamNum = 1;
if (genericParams->size() != expectedGenericParamNum) {
fprintf(stderr, "[%s:%d] (%s) if (genericParams->size() != expectedGenericParamNum)\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Check all parameters
auto params = getParameters();
if (params->size() != 1) {
return false;
}
// --- Check parameter: _ argument
auto argumentParam = params->get(0);
if (!argumentParam->getArgumentName().is("_")) {
return false;
}
// --- Check: Argument: SerializationRequirement
auto sig = getGenericSignature();
auto requirements = sig.getRequirements();
if (requirements.size() != 1) {
return false;
}
// === Check generic parameters in detail
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
// --- Check requirement: conforms_to: Err Error
auto errorReq = requirements[0];
auto errorTy = C.getProtocol(KnownProtocolKind::Error)
->getInterfaceType()
->getMetatypeInstanceType();
if (errorReq.getKind() != RequirementKind::Conformance) {
return false;
}
if (!errorReq.getSecondType()->isEqual(errorTy)) {
return false;
}
// === Check result type: Void
auto func = dyn_cast<FuncDecl>(this);
if (!func)
return false;
if (!func->getResultInterfaceType()->isVoid()) {
return false;
}
return true;
}
llvm::SmallPtrSet<ProtocolDecl *, 2>
swift::extractDistributedSerializationRequirements(
ASTContext &C, ArrayRef<Requirement> allRequirements) {