[Distributed] implement adhoc requirements properly for Encoder

This commit is contained in:
Konrad `ktoso` Malawski
2022-02-14 16:14:41 +09:00
parent 0fec1b31da
commit 005743c92c
11 changed files with 596 additions and 147 deletions

View File

@@ -693,15 +693,15 @@ public:
// \param nominal optionally provide a 'NominalTypeDecl' from which the
// function decl shall be extracted. This is useful to avoid witness calls
// through the protocol which is looked up when nominal is null.
FuncDecl *getRecordArgumentOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const;
// Retrieve the declaration of DistributedInvocationEncoder.recordErrorType(_:).
FuncDecl *getRecordErrorTypeOnDistributedInvocationEncoder(
AbstractFunctionDecl *getRecordArgumentOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const;
// Retrieve the declaration of DistributedInvocationEncoder.recordReturnType(_:).
FuncDecl *getRecordReturnTypeOnDistributedInvocationEncoder(
AbstractFunctionDecl *getRecordReturnTypeOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const;
// Retrieve the declaration of DistributedInvocationEncoder.recordErrorType(_:).
AbstractFunctionDecl *getRecordErrorTypeOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const;
// Retrieve the declaration of DistributedInvocationEncoder.doneRecording().
@@ -1351,14 +1351,15 @@ public:
/// alternative specified via the -entry-point-function-name frontend flag.
std::string getEntryPointFunctionName() const;
Type getAssociatedTypeOfDistributedSystem(NominalTypeDecl *actor,
Type getAssociatedTypeOfDistributedSystemOfActor(NominalTypeDecl *actor,
Identifier member);
/// Find the type of SerializationRequirement on the passed nominal.
///
/// This type exists as a typealias/associatedtype on all distributed actors,
/// actor systems, and related serialization types.
Type getDistributedSerializationRequirementType(NominalTypeDecl *);
// /// Find the type of SerializationRequirement on the passed nominal.
// ///
// /// This type exists as a typealias/associatedtype on all distributed actors,
// /// actor systems, and related serialization types.
// Type getDistributedSerializationRequirementType(
// NominalTypeDecl *, ProtocolDecl *protocolDecl);
/// Find the concrete invocation decoder associated with the given actor.
NominalTypeDecl *

View File

@@ -6422,6 +6422,21 @@ public:
/// 'DistributedActorSystem' protocol.
bool isDistributedActorSystemRemoteCall(bool isVoidReturn) const;
/// Determines if this function is a 'recordArgument' function,
/// which is used as ad-hoc protocol requirement by the
/// 'DistributedTargetInvocationEncoder' protocol.
bool isDistributedTargetInvocationEncoderRecordArgument() const;
/// Determines if this function is a 'recordReturnType' function,
/// which is used as ad-hoc protocol requirement by the
/// 'DistributedTargetInvocationEncoder' protocol.
bool isDistributedTargetInvocationEncoderRecordReturnType() const;
/// Determines if this function is a 'recordErrorType' function,
/// which is used as ad-hoc protocol requirement by the
/// 'DistributedTargetInvocationEncoder' protocol.
bool isDistributedTargetInvocationEncoderRecordErrorType() const;
/// For a method of a class, checks whether it will require a new entry in the
/// vtable.
bool needsNewVTableEntry() const;

View File

@@ -37,15 +37,20 @@ Type getDistributedActorSystemType(NominalTypeDecl *actor);
/// Determine the `ID` type for the given actor.
Type getDistributedActorIDType(NominalTypeDecl *actor);
Type getDistributedActorSystemSerializationRequirementType(
NominalTypeDecl *system);
/// Get specific 'SerializationRequirement' as defined in 'nominal'
/// type, which must conform to the passed 'protocol' which is expected
/// to require the 'SerializationRequirement'.
Type getDistributedSerializationRequirementType(
NominalTypeDecl *nominal, ProtocolDecl *protocol);
///// Determine the serialization requirement for the given actor, actor system
///// or other type that has the SerializationRequirement associated type.
//Type getDistributedSerializationRequirementType(
// NominalTypeDecl *nominal, ProtocolDecl *protocol);
Type getDistributedActorSystemActorIDRequirementType(
NominalTypeDecl *system);
/// Determine the serialization requirement for the given actor, actor system
/// or other type that has the SerializationRequirement associated type.
Type getDistributedSerializationRequirementType(NominalTypeDecl *actor);
/// Get the specific protocols that the `SerializationRequirement` specifies,
/// and all parameters / return types of distributed targets must conform to.
@@ -55,7 +60,8 @@ Type getDistributedSerializationRequirementType(NominalTypeDecl *actor);
///
/// Returns an empty set if the requirement was `Any`.
llvm::SmallPtrSet<ProtocolDecl *, 2>
getDistributedSerializationRequirementProtocols(NominalTypeDecl *decl);
getDistributedSerializationRequirementProtocols(
NominalTypeDecl *decl, ProtocolDecl* protocol);
/// Desugar and flatten the `SerializationRequirement` type into a set of
/// specific protocol declarations.
@@ -78,6 +84,7 @@ bool checkDistributedSerializationRequirementIsExactlyCodable(
bool
getDistributedActorSystemSerializationRequirements(
NominalTypeDecl *systemDecl,
ProtocolDecl *protocol,
llvm::SmallPtrSetImpl<ProtocolDecl *> &requirementProtos);
/// Given any set of generic requirements, locate those which are about the

View File

@@ -1055,6 +1055,60 @@ public:
bool isCached() const { return true; }
};
/// Obtain the 'recordArgument' function of a 'DistributedTargetInvocationEncoder'.
class GetDistributedTargetInvocationEncoderRecordArgumentFunctionRequest :
public SimpleRequest<GetDistributedTargetInvocationEncoderRecordArgumentFunctionRequest,
AbstractFunctionDecl *(NominalTypeDecl *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;
private:
friend SimpleRequest;
AbstractFunctionDecl *evaluate(Evaluator &evaluator, NominalTypeDecl *encoder) const;
public:
// Caching
bool isCached() const { return true; }
};
/// Obtain the 'recordReturnType' function of a 'DistributedTargetInvocationEncoder'.
class GetDistributedTargetInvocationEncoderRecordReturnTypeFunctionRequest :
public SimpleRequest<GetDistributedTargetInvocationEncoderRecordReturnTypeFunctionRequest,
AbstractFunctionDecl *(NominalTypeDecl *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;
private:
friend SimpleRequest;
AbstractFunctionDecl *evaluate(Evaluator &evaluator, NominalTypeDecl *encoder) const;
public:
// Caching
bool isCached() const { return true; }
};
/// Obtain the 'recordErrorType' function of a 'DistributedTargetInvocationEncoder'.
class GetDistributedTargetInvocationEncoderRecordErrorTypeFunctionRequest :
public SimpleRequest<GetDistributedTargetInvocationEncoderRecordErrorTypeFunctionRequest,
AbstractFunctionDecl *(NominalTypeDecl *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;
private:
friend SimpleRequest;
AbstractFunctionDecl *evaluate(Evaluator &evaluator, NominalTypeDecl *encoder) const;
public:
// Caching
bool isCached() const { return true; }
};
/// Obtain the 'actorSystem' property of a 'distributed actor'.
class GetDistributedActorSystemPropertyRequest :
public SimpleRequest<GetDistributedActorSystemPropertyRequest,

View File

@@ -108,6 +108,15 @@ SWIFT_REQUEST(TypeChecker, IsDistributedActorRequest, bool(NominalTypeDecl *),
SWIFT_REQUEST(TypeChecker, GetDistributedActorSystemRemoteCallFunctionRequest,
AbstractFunctionDecl *(NominalTypeDecl *, bool),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, GetDistributedTargetInvocationEncoderRecordArgumentFunctionRequest,
AbstractFunctionDecl *(NominalTypeDecl *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, GetDistributedTargetInvocationEncoderRecordReturnTypeFunctionRequest,
AbstractFunctionDecl *(NominalTypeDecl *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, GetDistributedTargetInvocationEncoderRecordErrorTypeFunctionRequest,
AbstractFunctionDecl *(NominalTypeDecl *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, GetDistributedActorIDPropertyRequest,
VarDecl *(NominalTypeDecl *),
Cached, NoLocationInfo)

View File

@@ -1358,67 +1358,28 @@ ASTContext::getRecordGenericSubstitutionOnDistributedInvocationEncoder(
return nullptr;
}
FuncDecl *ASTContext::getRecordArgumentOnDistributedInvocationEncoder(
AbstractFunctionDecl *ASTContext::getRecordArgumentOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const {
for (auto result : nominal->lookupDirect(Id_recordArgument)) {
auto *fd = dyn_cast<FuncDecl>(result);
if (!fd)
continue;
if (fd->getParameters()->size() != 1)
continue;
if (fd->hasAsync())
continue;
if (!fd->hasThrows())
continue;
// TODO(distributed): more checks
if (fd->getResultInterfaceType()->isVoid())
return fd;
}
return nullptr;
return evaluateOrDefault(
nominal->getASTContext().evaluator,
GetDistributedTargetInvocationEncoderRecordArgumentFunctionRequest{nominal},
nullptr);
}
FuncDecl *ASTContext::getRecordErrorTypeOnDistributedInvocationEncoder(
AbstractFunctionDecl *ASTContext::getRecordReturnTypeOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const {
for (auto result : nominal->lookupDirect(Id_recordErrorType)) {
auto *fd = dyn_cast<FuncDecl>(result);
if (!fd)
continue;
if (fd->getParameters()->size() != 1)
continue;
if (fd->hasAsync())
continue;
if (!fd->hasThrows())
continue;
// TODO(distributed): more checks
if (fd->getResultInterfaceType()->isVoid())
return fd;
}
return nullptr;
return evaluateOrDefault(
nominal->getASTContext().evaluator,
GetDistributedTargetInvocationEncoderRecordReturnTypeFunctionRequest{nominal},
nullptr);
}
FuncDecl *ASTContext::getRecordReturnTypeOnDistributedInvocationEncoder(
AbstractFunctionDecl *ASTContext::getRecordErrorTypeOnDistributedInvocationEncoder(
NominalTypeDecl *nominal) const {
for (auto result : nominal->lookupDirect(Id_recordReturnType)) {
auto *fd = dyn_cast<FuncDecl>(result);
if (!fd)
continue;
if (fd->getParameters()->size() != 1)
continue;
if (fd->hasAsync())
continue;
if (!fd->hasThrows())
continue;
// TODO(distributed): more checks
if (fd->getResultInterfaceType()->isVoid())
return fd;
}
return nullptr;
return evaluateOrDefault(
nominal->getASTContext().evaluator,
GetDistributedTargetInvocationEncoderRecordErrorTypeFunctionRequest{nominal},
nullptr);
}
FuncDecl *ASTContext::getDoneRecordingOnDistributedInvocationEncoder(

View File

@@ -82,22 +82,7 @@ Type swift::getDistributedActorSystemType(NominalTypeDecl *actor) {
Type swift::getDistributedActorIDType(NominalTypeDecl *actor) {
auto &C = actor->getASTContext();
return C.getAssociatedTypeOfDistributedSystem(actor, C.Id_ActorID);
}
Type swift::getDistributedActorSystemSerializationRequirementType(NominalTypeDecl *system) {
assert(!system->isDistributedActor());
auto &ctx = system->getASTContext();
auto protocol = ctx.getProtocol(KnownProtocolKind::DistributedActorSystem);
if (!protocol)
return Type();
// Dig out the serialization requirement type.
auto module = system->getParentModule();
Type selfType = system->getSelfInterfaceType();
auto conformance = module->lookupConformance(selfType, protocol);
return conformance.getTypeWitnessByName(selfType, ctx.Id_SerializationRequirement);
return C.getAssociatedTypeOfDistributedSystemOfActor(actor, C.Id_ActorID);
}
Type swift::getDistributedActorSystemActorIDRequirementType(NominalTypeDecl *system) {
@@ -115,8 +100,24 @@ Type swift::getDistributedActorSystemActorIDRequirementType(NominalTypeDecl *sys
return conformance.getTypeWitnessByName(selfType, ctx.Id_ActorID);
}
Type ASTContext::getAssociatedTypeOfDistributedSystem(NominalTypeDecl *actor,
Identifier member) {
Type swift::getDistributedSerializationRequirementType(
NominalTypeDecl *nominal, ProtocolDecl *protocol) {
assert(!nominal->isDistributedActor());
assert(protocol);
auto &ctx = nominal->getASTContext();
// Dig out the serialization requirement type.
auto module = nominal->getParentModule();
Type selfType = nominal->getSelfInterfaceType();
auto conformance = module->lookupConformance(selfType, protocol);
if (conformance.isInvalid())
return Type();
return conformance.getTypeWitnessByName(selfType, ctx.Id_SerializationRequirement);
}
Type ASTContext::getAssociatedTypeOfDistributedSystemOfActor(
NominalTypeDecl *actor, Identifier member) {
auto &ctx = actor->getASTContext();
auto actorProtocol = ctx.getProtocol(KnownProtocolKind::DistributedActor);
@@ -148,11 +149,11 @@ Type ASTContext::getAssociatedTypeOfDistributedSystem(NominalTypeDecl *actor,
actorProtocol, selfType, conformance));
}
Type ASTContext::getDistributedSerializationRequirementType(
NominalTypeDecl *nominal) {
return getAssociatedTypeOfDistributedSystem(nominal,
Id_SerializationRequirement);
}
//Type ASTContext::getDistributedSerializationRequirementType(
// NominalTypeDecl *nominal) {
// return getAssociatedTypeOfDistributedSystemOfActor(nominal,
// Id_SerializationRequirement);
//}
FuncDecl*
ASTContext::getDistributedActorArgumentDecodingMethod(NominalTypeDecl *actor) {
@@ -174,10 +175,11 @@ ASTContext::getDistributedActorInvocationDecoder(NominalTypeDecl *actor) {
bool
swift::getDistributedActorSystemSerializationRequirements(
NominalTypeDecl *systemNominal,
NominalTypeDecl *nominal,
ProtocolDecl *protocol,
llvm::SmallPtrSetImpl<ProtocolDecl *> &requirementProtos) {
auto existentialRequirementTy =
getDistributedActorSystemSerializationRequirementType(systemNominal);
getDistributedSerializationRequirementType(nominal, protocol);
if (existentialRequirementTy->hasError()) {
fprintf(stderr, "[%s:%d] (%s) if (SerializationRequirementTy->hasError())\n", __FILE__, __LINE__, __FUNCTION__);
return false;
@@ -265,7 +267,6 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
return false;
}
// === Must be declared in a 'DistributedActorSystem' conforming type
ProtocolDecl *systemProto =
C.getProtocol(KnownProtocolKind::DistributedActorSystem);
@@ -310,7 +311,7 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
// === Get the SerializationRequirement
SmallPtrSet<ProtocolDecl*, 2> requirementProtos;
if (!getDistributedActorSystemSerializationRequirements(
systemNominal, requirementProtos)) {
systemNominal, systemProto, requirementProtos)) {
return false;
}
@@ -473,6 +474,8 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
return false;
}
}
} else {
return false;
}
}
@@ -491,6 +494,345 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
return true;
}
bool
AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const {
auto &C = getASTContext();
auto module = getParentModule();
// === Check base name
if (getBaseIdentifier() != C.Id_recordArgument)
return false;
// === Must be declared in a 'DistributedTargetInvocationEncoder' conforming type
ProtocolDecl *encoderProto =
C.getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
auto encoderNominal = getDeclContext()->getSelfNominalTypeDecl();
auto protocolConformance = module->lookupConformance(
encoderNominal->getDeclaredInterfaceType(), encoderProto);
if (protocolConformance.isInvalid()) {
return false;
}
// === Check modifiers
// --- must not be async
if (hasAsync()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- must be throwing
if (!hasThrows()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Check generics
if (!isGeneric()) {
fprintf(stderr, "[%s:%d] (%s) if (!isGeneric())\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- Check number of generic parameters
auto genericParams = getGenericParams();
unsigned int expectedGenericParamNum = 1;
if (genericParams->size() != expectedGenericParamNum) {
fprintf(stderr, "[%s:%d] (%s) if (genericParams->size() != expectedGenericParamNum)\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Get the SerializationRequirement
SmallPtrSet<ProtocolDecl*, 2> requirementProtos;
if (!getDistributedActorSystemSerializationRequirements(
encoderNominal, encoderProto, requirementProtos)) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// -- Check number of generic requirements
size_t serializationRequirementsNum = requirementProtos.size();
size_t expectedRequirementsNum = serializationRequirementsNum;
// === Check all parameters
auto params = getParameters();
if (params->size() != 1) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- Check parameter: _ argument
auto argumentParam = params->get(0);
if (!argumentParam->getArgumentName().is("_")) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Check generic parameters in detail
// --- Check: Argument: SerializationRequirement
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
auto sig = getGenericSignature();
auto requirements = sig.getRequirements();
if (requirements.size() != expectedRequirementsNum) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- Check the expected requirements
// --- all the Argument requirements ---
// conforms_to: Argument Decodable
// conforms_to: Argument Encodable
// ...
auto func = dyn_cast<FuncDecl>(this);
if (!func) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType())
->getMetatypeInstanceType()
->getDesugaredType();
auto resultParamType = func->mapTypeIntoContext(
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
// The result of the function must be the `Res` generic argument.
if (!resultType->isEqual(resultParamType)) {
return false;
}
for (auto requirementProto : requirementProtos) {
auto conformance = module->lookupConformance(resultType, requirementProto);
if (conformance.isInvalid()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
}
// === Check result type: Void
if (!func->getResultInterfaceType()->isVoid()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
return true;
}
bool
AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordReturnType() const {
auto &C = getASTContext();
auto module = getParentModule();
// === Check base name
if (getBaseIdentifier() != C.Id_recordReturnType)
return false;
// === Must be declared in a 'DistributedTargetInvocationEncoder' conforming type
ProtocolDecl *encoderProto =
C.getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
auto encoderNominal = getDeclContext()->getSelfNominalTypeDecl();
auto protocolConformance = module->lookupConformance(
encoderNominal->getDeclaredInterfaceType(), encoderProto);
if (protocolConformance.isInvalid()) {
return false;
}
// === Check modifiers
// --- must not be async
if (hasAsync()) {
return false;
}
// --- must be throwing
if (!hasThrows()) {
return false;
}
// === Check generics
if (!isGeneric()) {
fprintf(stderr, "[%s:%d] (%s) if (!isGeneric())\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- Check number of generic parameters
auto genericParams = getGenericParams();
unsigned int expectedGenericParamNum = 1;
if (genericParams->size() != expectedGenericParamNum) {
fprintf(stderr, "[%s:%d] (%s) if (genericParams->size() != expectedGenericParamNum)\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Get the SerializationRequirement
SmallPtrSet<ProtocolDecl*, 2> requirementProtos;
if (!getDistributedActorSystemSerializationRequirements(
encoderNominal, encoderProto, requirementProtos)) {
return false;
}
// -- Check number of generic requirements
size_t serializationRequirementsNum = requirementProtos.size();
size_t expectedRequirementsNum = serializationRequirementsNum;
// === Check all parameters
auto params = getParameters();
if (params->size() != 1) {
return false;
}
// --- Check parameter: _ argument
auto argumentParam = params->get(0);
if (!argumentParam->getArgumentName().is("_")) {
return false;
}
// === Check generic parameters in detail
// --- Check: Argument: SerializationRequirement
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
auto sig = getGenericSignature();
auto requirements = sig.getRequirements();
if (requirements.size() != expectedRequirementsNum) {
return false;
}
// --- Check the expected requirements
// --- all the Argument requirements ---
// conforms_to: Argument Decodable
// conforms_to: Argument Encodable
// ...
auto func = dyn_cast<FuncDecl>(this);
if (!func)
return false;
auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType())
->getMetatypeInstanceType()
->getDesugaredType();
auto resultParamType = func->mapTypeIntoContext(
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
// The result of the function must be the `Res` generic argument.
if (!resultType->isEqual(resultParamType)) {
return false;
}
for (auto requirementProto : requirementProtos) {
auto conformance = module->lookupConformance(resultType, requirementProto);
if (conformance.isInvalid()) {
return false;
}
}
// === Check result type: Void
if (!func->getResultInterfaceType()->isVoid()) {
return false;
}
return true;
}
bool
AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() const {
auto &C = getASTContext();
auto module = getParentModule();
// === Check base name
if (getBaseIdentifier() != C.Id_recordReturnType)
return false;
// === Must be declared in a 'DistributedTargetInvocationEncoder' conforming type
ProtocolDecl *encoderProto =
C.getProtocol(KnownProtocolKind::DistributedTargetInvocationEncoder);
auto encoderNominal = getDeclContext()->getSelfNominalTypeDecl();
auto protocolConformance = module->lookupConformance(
encoderNominal->getDeclaredInterfaceType(), encoderProto);
if (protocolConformance.isInvalid()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Check modifiers
// --- must not be async
if (hasAsync()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- must be throwing
if (!hasThrows()) {
fprintf(stderr, "[%s:%d] (%s) return false\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Check generics
if (!isGeneric()) {
fprintf(stderr, "[%s:%d] (%s) if (!isGeneric())\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// --- Check number of generic parameters
auto genericParams = getGenericParams();
unsigned int expectedGenericParamNum = 1;
if (genericParams->size() != expectedGenericParamNum) {
fprintf(stderr, "[%s:%d] (%s) if (genericParams->size() != expectedGenericParamNum)\n", __FILE__, __LINE__, __FUNCTION__);
return false;
}
// === Check all parameters
auto params = getParameters();
if (params->size() != 1) {
return false;
}
// --- Check parameter: _ argument
auto argumentParam = params->get(0);
if (!argumentParam->getArgumentName().is("_")) {
return false;
}
// --- Check: Argument: SerializationRequirement
auto sig = getGenericSignature();
auto requirements = sig.getRequirements();
if (requirements.size() != 1) {
return false;
}
// === Check generic parameters in detail
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
// --- Check requirement: conforms_to: Err Error
auto errorReq = requirements[0];
auto errorTy = C.getProtocol(KnownProtocolKind::Error)
->getInterfaceType()
->getMetatypeInstanceType();
if (errorReq.getKind() != RequirementKind::Conformance) {
return false;
}
if (!errorReq.getSecondType()->isEqual(errorTy)) {
return false;
}
// === Check result type: Void
auto func = dyn_cast<FuncDecl>(this);
if (!func)
return false;
if (!func->getResultInterfaceType()->isVoid()) {
return false;
}
return true;
}
llvm::SmallPtrSet<ProtocolDecl *, 2>
swift::extractDistributedSerializationRequirements(
ASTContext &C, ArrayRef<Requirement> allRequirements) {

View File

@@ -1058,7 +1058,7 @@ void SILGenFunction::emitDistributedThunk(SILDeclRef thunk) {
}
nextNormalBB = nullptr;
FuncDecl *recordArgumentFnDecl =
AbstractFunctionDecl *recordArgumentFnDecl =
ctx.getRecordArgumentOnDistributedInvocationEncoder(
invocationEncoderNominal);
auto recordArgumentFnRef = SILDeclRef(recordArgumentFnDecl);
@@ -1210,7 +1210,7 @@ void SILGenFunction::emitDistributedThunk(SILDeclRef thunk) {
auto errorMetatypeValue = B.createMetatype(loc, errorMetatype);
// Get the function
FuncDecl *recordErrorTypeFnDecl =
AbstractFunctionDecl *recordErrorTypeFnDecl =
ctx.getRecordErrorTypeOnDistributedInvocationEncoder(
invocationEncoderNominal);
assert(recordErrorTypeFnDecl);
@@ -1258,7 +1258,7 @@ void SILGenFunction::emitDistributedThunk(SILDeclRef thunk) {
auto returnMetatypeValue = B.createMetatype(loc, returnMetatype);
// Get the function
FuncDecl *recordReturnTypeFnDecl =
AbstractFunctionDecl *recordReturnTypeFnDecl =
ctx.getRecordReturnTypeOnDistributedInvocationEncoder(
invocationEncoderNominal);
assert(recordReturnTypeFnDecl);

View File

@@ -80,40 +80,6 @@ static VarDecl *addImplicitDistributedActorIDProperty(
return propDecl;
}
/******************************************************************************/
/************ LOCATING AD-HOC PROTOCOL REQUIREMENT IMPLS **********************/
/******************************************************************************/
AbstractFunctionDecl*
GetDistributedActorSystemRemoteCallFunctionRequest::evaluate(
Evaluator &evaluator, NominalTypeDecl *decl, bool isVoidReturn) const {
auto &C = decl->getASTContext();
// It would be nice to check if this is a DistributedActorSystem
// "conforming" type, but we can't do this as we invoke this function WHILE
// deciding if the type conforms or not;
// Not via `ensureDistributedModuleLoaded` to avoid generating a warning,
// we won't be emitting the offending decl after all.
if (!C.getLoadedModule(C.Id_Distributed)) {
return nullptr;
}
auto callId = isVoidReturn ? C.Id_remoteCallVoid : C.Id_remoteCall;
AbstractFunctionDecl *remoteCallFunc = nullptr;
for (auto value : decl->lookupDirect(callId)) {
auto func = dyn_cast<AbstractFunctionDecl>(value);
if (func && func->isDistributedActorSystemRemoteCall(isVoidReturn)) {
remoteCallFunc = func;
break;
}
}
return remoteCallFunc;
}
/******************************************************************************/
/************************ SYNTHESIS ENTRY POINT *******************************/
/******************************************************************************/

View File

@@ -52,6 +52,78 @@ DistributedModuleIsAvailableRequest::evaluate(Evaluator &evaluator,
return false;
}
/******************************************************************************/
/************ LOCATING AD-HOC PROTOCOL REQUIREMENT IMPLS **********************/
/******************************************************************************/
static AbstractFunctionDecl *findDistributedAdHocRequirement(
NominalTypeDecl *decl, Identifier identifier,
std::function<bool(AbstractFunctionDecl *)> matchFn) {
auto &C = decl->getASTContext();
// It would be nice to check if this is a DistributedActorSystem
// "conforming" type, but we can't do this as we invoke this function WHILE
// deciding if the type conforms or not;
// Not via `ensureDistributedModuleLoaded` to avoid generating a warning,
// we won't be emitting the offending decl after all.
if (!C.getLoadedModule(C.Id_Distributed)) {
return nullptr;
}
for (auto value : decl->lookupDirect(identifier)) {
auto func = dyn_cast<AbstractFunctionDecl>(value);
if (func && matchFn(func))
return func;
}
return nullptr;
}
AbstractFunctionDecl *
GetDistributedActorSystemRemoteCallFunctionRequest::evaluate(
Evaluator &evaluator, NominalTypeDecl *decl, bool isVoidReturn) const {
auto &C = decl->getASTContext();
auto callId = isVoidReturn ? C.Id_remoteCallVoid : C.Id_remoteCall;
return findDistributedAdHocRequirement(
decl, callId, [isVoidReturn](AbstractFunctionDecl *func) {
return func->isDistributedActorSystemRemoteCall(isVoidReturn);
});
}
AbstractFunctionDecl *
GetDistributedTargetInvocationEncoderRecordArgumentFunctionRequest::evaluate(
Evaluator &evaluator, NominalTypeDecl *decl) const {
auto &C = decl->getASTContext();
return findDistributedAdHocRequirement(
decl, C.Id_recordArgument, [](AbstractFunctionDecl *func) {
return func->isDistributedTargetInvocationEncoderRecordArgument();
});
}
AbstractFunctionDecl *
GetDistributedTargetInvocationEncoderRecordReturnTypeFunctionRequest::evaluate(
Evaluator &evaluator, NominalTypeDecl *decl) const {
auto &C = decl->getASTContext();
return findDistributedAdHocRequirement(
decl, C.Id_recordReturnType, [](AbstractFunctionDecl *func) {
return func->isDistributedTargetInvocationEncoderRecordReturnType();
});
}
AbstractFunctionDecl *
GetDistributedTargetInvocationEncoderRecordErrorTypeFunctionRequest::evaluate(
Evaluator &evaluator, NominalTypeDecl *decl) const {
auto &C = decl->getASTContext();
return findDistributedAdHocRequirement(
decl, C.Id_recordErrorType, [](AbstractFunctionDecl *func) {
return func->isDistributedTargetInvocationEncoderRecordErrorType();
});
}
// ==== ------------------------------------------------------------------------
/// Add Fix-It text for the given protocol type to inherit DistributedActor.
@@ -213,7 +285,7 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
decl->getDescriptiveKind(), decl->getName(), C.Id_recordArgument);
decl->diagnose(diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
decl->getName(), C.Id_recordArgument,
"mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws");
"mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws\n");
anyMissingAdHocRequirements = true;
}
@@ -224,7 +296,7 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
decl->getDescriptiveKind(), decl->getName(), C.Id_recordErrorType);
decl->diagnose(diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
decl->getName(), C.Id_recordErrorType,
"mutating func recordErrorType<Err: Error>(_ errorType: Err.Type) throws");
"mutating func recordErrorType<Err: Error>(_ errorType: Err.Type) throws\n");
anyMissingAdHocRequirements = true;
}
@@ -235,7 +307,7 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
decl->getDescriptiveKind(), decl->getName(), C.Id_recordReturnType);
decl->diagnose(diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
decl->getName(), C.Id_recordReturnType,
"mutating func recordReturnType<Res: SerializationRequirement>(_ resultType: Res.Type) throws");
"mutating func recordReturnType<Res: SerializationRequirement>(_ resultType: Res.Type) throws\n");
anyMissingAdHocRequirements = true;
}
@@ -340,8 +412,8 @@ bool swift::checkDistributedFunction(FuncDecl *func, bool diagnose) {
serializationRequirements = extractDistributedSerializationRequirements(
C, extension->getGenericRequirements());
} else if (auto actor = dyn_cast<ClassDecl>(declContext)) {
serializationRequirements =
getDistributedSerializationRequirementProtocols(actor);
serializationRequirements = getDistributedSerializationRequirementProtocols(
actor, C.getProtocol(KnownProtocolKind::DistributedActor));
} // TODO(distributed): need to handle ProtocolDecl too?
// If the requirement is exactly `Codable` we diagnose it ia bit nicer.
@@ -406,6 +478,8 @@ bool swift::checkDistributedFunction(FuncDecl *func, bool diagnose) {
/// \returns \c true if there was a problem with adding the attribute, \c false
/// otherwise.
bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
auto &C = var->getASTContext();
auto DC = var->getDeclContext();
/// === Check if the declaration is a valid combination of attributes
if (var->isStatic()) {
var->diagnose(diag::distributed_property_cannot_be_static,
@@ -431,7 +505,8 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
// === Check the type of the property
auto serializationRequirements =
getDistributedSerializationRequirementProtocols(
var->getDeclContext()->getSelfNominalTypeDecl());
DC->getSelfNominalTypeDecl(),
C.getProtocol(KnownProtocolKind::DistributedActor));
auto module = var->getModuleContext();
if (checkDistributedTargetResultType(module, var, serializationRequirements, diagnose)) {
@@ -531,10 +606,16 @@ void TypeChecker::checkDistributedActor(ClassDecl *decl) {
}
llvm::SmallPtrSet<ProtocolDecl *, 2>
swift::getDistributedSerializationRequirementProtocols(NominalTypeDecl *nominal) {
auto &ctx = nominal->getASTContext();
swift::getDistributedSerializationRequirementProtocols(
NominalTypeDecl *nominal, ProtocolDecl *protocol) {
if (!protocol)
return {};
auto ty = ctx.getDistributedSerializationRequirementType(nominal);
// auto ty = ctx.getDistributedSerializationRequirementType(nominal);
// if (ty->hasError())
// return {};
auto ty = getDistributedSerializationRequirementType(
nominal, protocol);
if (ty->hasError())
return {};
@@ -584,7 +665,7 @@ GetDistributedActorInvocationDecoderRequest::evaluate(Evaluator &evaluator,
NominalTypeDecl *actor) const {
auto &ctx = actor->getASTContext();
auto decoderTy =
ctx.getAssociatedTypeOfDistributedSystem(actor, ctx.Id_InvocationDecoder);
ctx.getAssociatedTypeOfDistributedSystemOfActor(actor, ctx.Id_InvocationDecoder);
return decoderTy->hasError() ? nullptr : decoderTy->getAnyNominal();
}
@@ -603,7 +684,8 @@ GetDistributedActorArgumentDecodingMethodRequest::evaluate(Evaluator &evaluator,
// typealias SerializationRequirement = any ...
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationReqs =
getDistributedSerializationRequirementProtocols(actor);
getDistributedSerializationRequirementProtocols(
actor, ctx.getProtocol(KnownProtocolKind::DistributedActor));
SmallVector<FuncDecl *, 2> candidates;
// Looking for `decodeNextArgument<Arg: <SerializationReq>>() throws -> Arg`

View File

@@ -493,6 +493,18 @@ struct FakeInvocationEncoder_missing_recordArgument: DistributedTargetInvocation
mutating func doneRecording() throws {}
}
struct FakeInvocationEncoder_missing_recordArgument2: DistributedTargetInvocationEncoder {
//expected-error@-1{{struct 'FakeInvocationEncoder_missing_recordArgument2' is missing witness for protocol requirement 'recordArgument'}}
//expected-note@-2{{protocol 'FakeInvocationEncoder_missing_recordArgument2' requires function 'recordArgument' with signature:}}
typealias SerializationRequirement = Codable
mutating func recordGenericSubstitution<T>(_ type: T.Type) throws {}
mutating func recordArgument<Argument>(_ argument: Argument) throws {} // BAD
mutating func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {}
mutating func recordErrorType<E: Error>(_ type: E.Type) throws {}
mutating func doneRecording() throws {}
}
struct FakeInvocationEncoder_missing_recordReturnType: DistributedTargetInvocationEncoder {
//expected-error@-1{{struct 'FakeInvocationEncoder_missing_recordReturnType' is missing witness for protocol requirement 'recordReturnType'}}
//expected-note@-2{{protocol 'FakeInvocationEncoder_missing_recordReturnType' requires function 'recordReturnType' with signature:}}