mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[Distributed] implement adhoc requirements properly for Encoder
This commit is contained in:
@@ -82,22 +82,7 @@ Type swift::getDistributedActorSystemType(NominalTypeDecl *actor) {
|
||||
|
||||
Type swift::getDistributedActorIDType(NominalTypeDecl *actor) {
|
||||
auto &C = actor->getASTContext();
|
||||
return C.getAssociatedTypeOfDistributedSystem(actor, C.Id_ActorID);
|
||||
}
|
||||
|
||||
Type swift::getDistributedActorSystemSerializationRequirementType(NominalTypeDecl *system) {
|
||||
assert(!system->isDistributedActor());
|
||||
auto &ctx = system->getASTContext();
|
||||
|
||||
auto protocol = ctx.getProtocol(KnownProtocolKind::DistributedActorSystem);
|
||||
if (!protocol)
|
||||
return Type();
|
||||
|
||||
// Dig out the serialization requirement type.
|
||||
auto module = system->getParentModule();
|
||||
Type selfType = system->getSelfInterfaceType();
|
||||
auto conformance = module->lookupConformance(selfType, protocol);
|
||||
return conformance.getTypeWitnessByName(selfType, ctx.Id_SerializationRequirement);
|
||||
return C.getAssociatedTypeOfDistributedSystemOfActor(actor, C.Id_ActorID);
|
||||
}
|
||||
|
||||
Type swift::getDistributedActorSystemActorIDRequirementType(NominalTypeDecl *system) {
|
||||
@@ -115,8 +100,24 @@ Type swift::getDistributedActorSystemActorIDRequirementType(NominalTypeDecl *sys
|
||||
return conformance.getTypeWitnessByName(selfType, ctx.Id_ActorID);
|
||||
}
|
||||
|
||||
Type ASTContext::getAssociatedTypeOfDistributedSystem(NominalTypeDecl *actor,
|
||||
Identifier member) {
|
||||
Type swift::getDistributedSerializationRequirementType(
|
||||
NominalTypeDecl *nominal, ProtocolDecl *protocol) {
|
||||
assert(!nominal->isDistributedActor());
|
||||
assert(protocol);
|
||||
auto &ctx = nominal->getASTContext();
|
||||
|
||||
// Dig out the serialization requirement type.
|
||||
auto module = nominal->getParentModule();
|
||||
Type selfType = nominal->getSelfInterfaceType();
|
||||
auto conformance = module->lookupConformance(selfType, protocol);
|
||||
if (conformance.isInvalid())
|
||||
return Type();
|
||||
|
||||
return conformance.getTypeWitnessByName(selfType, ctx.Id_SerializationRequirement);
|
||||
}
|
||||
|
||||
Type ASTContext::getAssociatedTypeOfDistributedSystemOfActor(
|
||||
NominalTypeDecl *actor, Identifier member) {
|
||||
auto &ctx = actor->getASTContext();
|
||||
|
||||
auto actorProtocol = ctx.getProtocol(KnownProtocolKind::DistributedActor);
|
||||
@@ -148,11 +149,11 @@ Type ASTContext::getAssociatedTypeOfDistributedSystem(NominalTypeDecl *actor,
|
||||
actorProtocol, selfType, conformance));
|
||||
}
|
||||
|
||||
Type ASTContext::getDistributedSerializationRequirementType(
|
||||
NominalTypeDecl *nominal) {
|
||||
return getAssociatedTypeOfDistributedSystem(nominal,
|
||||
Id_SerializationRequirement);
|
||||
}
|
||||
//Type ASTContext::getDistributedSerializationRequirementType(
|
||||
// NominalTypeDecl *nominal) {
|
||||
// return getAssociatedTypeOfDistributedSystemOfActor(nominal,
|
||||
// Id_SerializationRequirement);
|
||||
//}
|
||||
|
||||
FuncDecl*
|
||||
ASTContext::getDistributedActorArgumentDecodingMethod(NominalTypeDecl *actor) {
|
||||
@@ -174,10 +175,11 @@ ASTContext::getDistributedActorInvocationDecoder(NominalTypeDecl *actor) {
|
||||
|
||||
bool
|
||||
swift::getDistributedActorSystemSerializationRequirements(
|
||||
NominalTypeDecl *systemNominal,
|
||||
NominalTypeDecl *nominal,
|
||||
ProtocolDecl *protocol,
|
||||
llvm::SmallPtrSetImpl<ProtocolDecl *> &requirementProtos) {
|
||||
auto existentialRequirementTy =
|
||||
getDistributedActorSystemSerializationRequirementType(systemNominal);
|
||||
getDistributedSerializationRequirementType(nominal, protocol);
|
||||
if (existentialRequirementTy->hasError()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) if (SerializationRequirementTy->hasError())\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
@@ -265,7 +267,6 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// === Must be declared in a 'DistributedActorSystem' conforming type
|
||||
ProtocolDecl *systemProto =
|
||||
C.getProtocol(KnownProtocolKind::DistributedActorSystem);
|
||||
@@ -310,7 +311,7 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
|
||||
// === Get the SerializationRequirement
|
||||
SmallPtrSet<ProtocolDecl*, 2> requirementProtos;
|
||||
if (!getDistributedActorSystemSerializationRequirements(
|
||||
systemNominal, requirementProtos)) {
|
||||
systemNominal, systemProto, requirementProtos)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -473,6 +474,8 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -491,6 +494,345 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const {
|
||||
auto &C = getASTContext();
|
||||
auto module = getParentModule();
|
||||
|
||||
// === Check base name
|
||||
if (getBaseIdentifier() != C.Id_recordArgument)
|
||||
return false;
|
||||
|
||||
// === Must be declared in a 'DistributedTargetInvocationEncoder' conforming type
|
||||
ProtocolDecl *encoderProto =
|
||||
C.getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
|
||||
|
||||
auto encoderNominal = getDeclContext()->getSelfNominalTypeDecl();
|
||||
auto protocolConformance = module->lookupConformance(
|
||||
encoderNominal->getDeclaredInterfaceType(), encoderProto);
|
||||
|
||||
if (protocolConformance.isInvalid()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check modifiers
|
||||
// --- must not be async
|
||||
if (hasAsync()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- must be throwing
|
||||
if (!hasThrows()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check generics
|
||||
if (!isGeneric()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) if (!isGeneric())\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- Check number of generic parameters
|
||||
auto genericParams = getGenericParams();
|
||||
unsigned int expectedGenericParamNum = 1;
|
||||
|
||||
if (genericParams->size() != expectedGenericParamNum) {
|
||||
fprintf(stderr, "[%s:%d] (%s) if (genericParams->size() != expectedGenericParamNum)\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Get the SerializationRequirement
|
||||
SmallPtrSet<ProtocolDecl*, 2> requirementProtos;
|
||||
if (!getDistributedActorSystemSerializationRequirements(
|
||||
encoderNominal, encoderProto, requirementProtos)) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// -- Check number of generic requirements
|
||||
size_t serializationRequirementsNum = requirementProtos.size();
|
||||
size_t expectedRequirementsNum = serializationRequirementsNum;
|
||||
|
||||
// === Check all parameters
|
||||
auto params = getParameters();
|
||||
if (params->size() != 1) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- Check parameter: _ argument
|
||||
auto argumentParam = params->get(0);
|
||||
if (!argumentParam->getArgumentName().is("_")) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check generic parameters in detail
|
||||
// --- Check: Argument: SerializationRequirement
|
||||
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
|
||||
|
||||
auto sig = getGenericSignature();
|
||||
auto requirements = sig.getRequirements();
|
||||
|
||||
if (requirements.size() != expectedRequirementsNum) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- Check the expected requirements
|
||||
// --- all the Argument requirements ---
|
||||
// conforms_to: Argument Decodable
|
||||
// conforms_to: Argument Encodable
|
||||
// ...
|
||||
|
||||
auto func = dyn_cast<FuncDecl>(this);
|
||||
if (!func) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType())
|
||||
->getMetatypeInstanceType()
|
||||
->getDesugaredType();
|
||||
auto resultParamType = func->mapTypeIntoContext(
|
||||
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
|
||||
// The result of the function must be the `Res` generic argument.
|
||||
if (!resultType->isEqual(resultParamType)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto requirementProto : requirementProtos) {
|
||||
auto conformance = module->lookupConformance(resultType, requirementProto);
|
||||
if (conformance.isInvalid()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// === Check result type: Void
|
||||
if (!func->getResultInterfaceType()->isVoid()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordReturnType() const {
|
||||
auto &C = getASTContext();
|
||||
auto module = getParentModule();
|
||||
|
||||
// === Check base name
|
||||
if (getBaseIdentifier() != C.Id_recordReturnType)
|
||||
return false;
|
||||
|
||||
// === Must be declared in a 'DistributedTargetInvocationEncoder' conforming type
|
||||
ProtocolDecl *encoderProto =
|
||||
C.getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
|
||||
|
||||
auto encoderNominal = getDeclContext()->getSelfNominalTypeDecl();
|
||||
auto protocolConformance = module->lookupConformance(
|
||||
encoderNominal->getDeclaredInterfaceType(), encoderProto);
|
||||
|
||||
if (protocolConformance.isInvalid()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check modifiers
|
||||
// --- must not be async
|
||||
if (hasAsync()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- must be throwing
|
||||
if (!hasThrows()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check generics
|
||||
if (!isGeneric()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) if (!isGeneric())\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- Check number of generic parameters
|
||||
auto genericParams = getGenericParams();
|
||||
unsigned int expectedGenericParamNum = 1;
|
||||
|
||||
if (genericParams->size() != expectedGenericParamNum) {
|
||||
fprintf(stderr, "[%s:%d] (%s) if (genericParams->size() != expectedGenericParamNum)\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Get the SerializationRequirement
|
||||
SmallPtrSet<ProtocolDecl*, 2> requirementProtos;
|
||||
if (!getDistributedActorSystemSerializationRequirements(
|
||||
encoderNominal, encoderProto, requirementProtos)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// -- Check number of generic requirements
|
||||
size_t serializationRequirementsNum = requirementProtos.size();
|
||||
size_t expectedRequirementsNum = serializationRequirementsNum;
|
||||
|
||||
// === Check all parameters
|
||||
auto params = getParameters();
|
||||
if (params->size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- Check parameter: _ argument
|
||||
auto argumentParam = params->get(0);
|
||||
if (!argumentParam->getArgumentName().is("_")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check generic parameters in detail
|
||||
// --- Check: Argument: SerializationRequirement
|
||||
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
|
||||
|
||||
auto sig = getGenericSignature();
|
||||
auto requirements = sig.getRequirements();
|
||||
|
||||
if (requirements.size() != expectedRequirementsNum) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- Check the expected requirements
|
||||
// --- all the Argument requirements ---
|
||||
// conforms_to: Argument Decodable
|
||||
// conforms_to: Argument Encodable
|
||||
// ...
|
||||
|
||||
auto func = dyn_cast<FuncDecl>(this);
|
||||
if (!func)
|
||||
return false;
|
||||
|
||||
auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType())
|
||||
->getMetatypeInstanceType()
|
||||
->getDesugaredType();
|
||||
auto resultParamType = func->mapTypeIntoContext(
|
||||
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
|
||||
// The result of the function must be the `Res` generic argument.
|
||||
if (!resultType->isEqual(resultParamType)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto requirementProto : requirementProtos) {
|
||||
auto conformance = module->lookupConformance(resultType, requirementProto);
|
||||
if (conformance.isInvalid()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// === Check result type: Void
|
||||
if (!func->getResultInterfaceType()->isVoid()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() const {
|
||||
|
||||
auto &C = getASTContext();
|
||||
auto module = getParentModule();
|
||||
|
||||
// === Check base name
|
||||
if (getBaseIdentifier() != C.Id_recordReturnType)
|
||||
return false;
|
||||
|
||||
// === Must be declared in a 'DistributedTargetInvocationEncoder' conforming type
|
||||
ProtocolDecl *encoderProto =
|
||||
C.getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
|
||||
|
||||
auto encoderNominal = getDeclContext()->getSelfNominalTypeDecl();
|
||||
auto protocolConformance = module->lookupConformance(
|
||||
encoderNominal->getDeclaredInterfaceType(), encoderProto);
|
||||
|
||||
if (protocolConformance.isInvalid()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check modifiers
|
||||
// --- must not be async
|
||||
if (hasAsync()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- must be throwing
|
||||
if (!hasThrows()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check generics
|
||||
if (!isGeneric()) {
|
||||
fprintf(stderr, "[%s:%d] (%s) if (!isGeneric())\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- Check number of generic parameters
|
||||
auto genericParams = getGenericParams();
|
||||
unsigned int expectedGenericParamNum = 1;
|
||||
|
||||
if (genericParams->size() != expectedGenericParamNum) {
|
||||
fprintf(stderr, "[%s:%d] (%s) if (genericParams->size() != expectedGenericParamNum)\n", __FILE__, __LINE__, __FUNCTION__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check all parameters
|
||||
auto params = getParameters();
|
||||
if (params->size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- Check parameter: _ argument
|
||||
auto argumentParam = params->get(0);
|
||||
if (!argumentParam->getArgumentName().is("_")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// --- Check: Argument: SerializationRequirement
|
||||
auto sig = getGenericSignature();
|
||||
auto requirements = sig.getRequirements();
|
||||
if (requirements.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check generic parameters in detail
|
||||
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
|
||||
|
||||
// --- Check requirement: conforms_to: Err Error
|
||||
auto errorReq = requirements[0];
|
||||
auto errorTy = C.getProtocol(KnownProtocolKind::Error)
|
||||
->getInterfaceType()
|
||||
->getMetatypeInstanceType();
|
||||
if (errorReq.getKind() != RequirementKind::Conformance) {
|
||||
return false;
|
||||
}
|
||||
if (!errorReq.getSecondType()->isEqual(errorTy)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// === Check result type: Void
|
||||
auto func = dyn_cast<FuncDecl>(this);
|
||||
if (!func)
|
||||
return false;
|
||||
|
||||
if (!func->getResultInterfaceType()->isVoid()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::SmallPtrSet<ProtocolDecl *, 2>
|
||||
swift::extractDistributedSerializationRequirements(
|
||||
ASTContext &C, ArrayRef<Requirement> allRequirements) {
|
||||
|
||||
Reference in New Issue
Block a user