diff --git a/include/swift/AST/ActorIsolation.h b/include/swift/AST/ActorIsolation.h index 898438e8629..3d26c419269 100644 --- a/include/swift/AST/ActorIsolation.h +++ b/include/swift/AST/ActorIsolation.h @@ -25,6 +25,7 @@ class raw_ostream; namespace swift { class DeclContext; +class ModuleDecl; class NominalTypeDecl; class SubstitutionMap; @@ -33,7 +34,7 @@ class SubstitutionMap; bool areTypesEqual(Type type1, Type type2); /// Determine whether the given type is suitable as a concurrent value type. -bool isSendableType(const DeclContext *dc, Type type); +bool isSendableType(ModuleDecl *module, Type type); /// Describes the actor isolation of a given declaration, which determines /// the actors with which it can interact. diff --git a/include/swift/AST/Requirement.h b/include/swift/AST/Requirement.h index d0c1eef287a..bc81d9ced49 100644 --- a/include/swift/AST/Requirement.h +++ b/include/swift/AST/Requirement.h @@ -76,6 +76,21 @@ public: ProtocolDecl *getProtocolDecl() const; + /// Determines if this substituted requirement is satisfied. + /// + /// \param conditionalRequirements An out parameter initialized to an + /// array of requirements that the caller must check to ensure this + /// requirement is completely satisfied. + bool isSatisfied(ArrayRef &conditionalRequirements) const; + + /// Determines if this substituted requirement can ever be satisfied, + /// possibly with additional substitutions. + /// + /// For example, if 'T' is unconstrained, then a superclass requirement + /// 'T : C' can be satisfied; however, if 'T' already has an unrelated + /// superclass requirement, 'T : C' cannot be satisfied. + bool canBeSatisfied() const; + SWIFT_DEBUG_DUMP; void dump(raw_ostream &out) const; void print(raw_ostream &os, const PrintOptions &opts) const; diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index 8156395be8a..3b443ed2db6 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -5269,17 +5269,9 @@ bool hasAppliedSelf(ConstraintSystem &cs, const OverloadChoice &choice); bool hasAppliedSelf(const OverloadChoice &choice, llvm::function_ref getFixedType); -/// Check whether type conforms to a given known protocol. -bool conformsToKnownProtocol(DeclContext *dc, Type type, - KnownProtocolKind protocol); - -/// Check whether given type conforms to `RawPepresentable` protocol +/// Check whether given type conforms to `RawRepresentable` protocol /// and return witness type. Type isRawRepresentable(ConstraintSystem &cs, Type type); -/// Check whether given type conforms to a specific known kind -/// `RawPepresentable` protocol and return witness type. -Type isRawRepresentable(ConstraintSystem &cs, Type type, - KnownProtocolKind rawRepresentableProtocol); /// Compute the type that shall stand in for dynamic 'Self' in a member /// reference with a base of the given object type. diff --git a/lib/AST/GenericSignature.cpp b/lib/AST/GenericSignature.cpp index 516b0945f14..c8febf9b505 100644 --- a/lib/AST/GenericSignature.cpp +++ b/lib/AST/GenericSignature.cpp @@ -428,62 +428,28 @@ bool GenericSignatureImpl::areSameTypeParameterInContext(Type type1, bool GenericSignatureImpl::isRequirementSatisfied( Requirement requirement) const { - auto GSB = getGenericSignatureBuilder(); + if (requirement.getFirstType()->hasTypeParameter()) { + auto *genericEnv = getGenericEnvironment(); - auto firstType = requirement.getFirstType(); - auto canFirstType = getCanonicalTypeInContext(firstType); + auto substituted = requirement.subst( + [&](SubstitutableType *type) -> Type { + if (auto *paramType = type->getAs()) + return genericEnv->mapTypeIntoContext(paramType); - switch (requirement.getKind()) { - case RequirementKind::Conformance: { - auto *protocol = requirement.getProtocolDecl(); - - if (canFirstType->isTypeParameter()) - return requiresProtocol(canFirstType, protocol); - else - return (bool)GSB->lookupConformance(canFirstType, protocol); - } - - case RequirementKind::SameType: { - auto canSecondType = getCanonicalTypeInContext(requirement.getSecondType()); - return canFirstType->isEqual(canSecondType); - } - - case RequirementKind::Superclass: { - auto requiredSuperclass = - getCanonicalTypeInContext(requirement.getSecondType()); - - // The requirement could be in terms of type parameters like a user-written - // requirement, but it could also be in terms of concrete types if it has - // been substituted/otherwise 'resolved', so we need to handle both. - auto baseType = canFirstType; - if (baseType->isTypeParameter()) { - auto directSuperclass = getSuperclassBound(baseType); - if (!directSuperclass) - return false; - - baseType = getCanonicalTypeInContext(directSuperclass); - } - - return requiredSuperclass->isExactSuperclassOf(baseType); - } - - case RequirementKind::Layout: { - auto requiredLayout = requirement.getLayoutConstraint(); - - if (canFirstType->isTypeParameter()) { - if (auto layout = getLayoutConstraint(canFirstType)) - return static_cast(layout.merge(requiredLayout)); + return type; + }, + LookUpConformanceInSignature(this)); + if (!substituted) return false; - } - // The requirement is on a concrete type, so it's either globally correct - // or globally incorrect, independent of this generic context. The latter - // case should be diagnosed elsewhere, so let's assume it's correct. - return true; + requirement = *substituted; } - } - llvm_unreachable("unhandled kind"); + + // FIXME: Need to check conditional requirements here. + ArrayRef conditionalRequirements; + + return requirement.isSatisfied(conditionalRequirements); } SmallVector GenericSignatureImpl::requirementsNotSatisfiedBy( @@ -494,7 +460,7 @@ SmallVector GenericSignatureImpl::requirementsNotSatisfiedBy( if (otherSig.getPointer() == this) return result; // If there is no other signature, no requirements are satisfied. - if (!otherSig){ + if (!otherSig) { const auto reqs = getRequirements(); result.append(reqs.begin(), reqs.end()); return result; @@ -722,3 +688,67 @@ ProtocolDecl *Requirement::getProtocolDecl() const { assert(getKind() == RequirementKind::Conformance); return getSecondType()->castTo()->getDecl(); } + +bool +Requirement::isSatisfied(ArrayRef &conditionalRequirements) const { + switch (getKind()) { + case RequirementKind::Conformance: { + auto *proto = getProtocolDecl(); + auto *module = proto->getParentModule(); + auto conformance = module->lookupConformance(getFirstType(), proto); + if (!conformance) + return false; + + conditionalRequirements = conformance.getConditionalRequirements(); + return true; + } + + case RequirementKind::Layout: { + if (auto *archetypeType = getFirstType()->getAs()) { + auto layout = archetypeType->getLayoutConstraint(); + return (layout && layout.merge(getLayoutConstraint())); + } + + if (getLayoutConstraint()->isClass()) + return getFirstType()->satisfiesClassConstraint(); + + // TODO: Statically check other layout constraints, once they can + // be spelled in Swift. + return true; + } + + case RequirementKind::Superclass: + return getSecondType()->isExactSuperclassOf(getFirstType()); + + case RequirementKind::SameType: + return getFirstType()->isEqual(getSecondType()); + } + + llvm_unreachable("Bad requirement kind"); +} + +bool Requirement::canBeSatisfied() const { + switch (getKind()) { + case RequirementKind::Conformance: + return getFirstType()->is(); + + case RequirementKind::Layout: { + if (auto *archetypeType = getFirstType()->getAs()) { + auto layout = archetypeType->getLayoutConstraint(); + return (!layout || layout.merge(getLayoutConstraint())); + } + + return false; + } + + case RequirementKind::Superclass: + return (getFirstType()->isBindableTo(getSecondType()) || + getSecondType()->isBindableTo(getFirstType())); + + case RequirementKind::SameType: + return (getFirstType()->isBindableTo(getSecondType()) || + getSecondType()->isBindableTo(getFirstType())); + } + + llvm_unreachable("Bad requirement kind"); +} \ No newline at end of file diff --git a/lib/IDE/CodeCompletion.cpp b/lib/IDE/CodeCompletion.cpp index 4fdfa78f53c..dbdfe5f6373 100644 --- a/lib/IDE/CodeCompletion.cpp +++ b/lib/IDE/CodeCompletion.cpp @@ -2539,20 +2539,21 @@ public: // If the reference is 'async', all types must be 'Sendable'. if (implicitlyAsync && T) { + auto *M = CurrDeclContext->getParentModule(); if (isa(VD)) { - if (!isSendableType(CurrDeclContext, T)) { + if (!isSendableType(M, T)) { NotRecommended = NotRecommendedReason::CrossActorReference; } } else { assert(isa(VD) || isa(VD)); // Check if the result and the param types are all 'Sendable'. auto *AFT = T->castTo(); - if (!isSendableType(CurrDeclContext, AFT->getResult())) { + if (!isSendableType(M, AFT->getResult())) { NotRecommended = NotRecommendedReason::CrossActorReference; } else { for (auto ¶m : AFT->getParams()) { Type paramType = param.getPlainType(); - if (!isSendableType(CurrDeclContext, paramType)) { + if (!isSendableType(M, paramType)) { NotRecommended = NotRecommendedReason::CrossActorReference; break; } diff --git a/lib/IDE/IDETypeChecking.cpp b/lib/IDE/IDETypeChecking.cpp index 8b8a80befaa..f2e1057c90d 100644 --- a/lib/IDE/IDETypeChecking.cpp +++ b/lib/IDE/IDETypeChecking.cpp @@ -22,6 +22,7 @@ #include "swift/AST/Module.h" #include "swift/AST/NameLookup.h" #include "swift/AST/ProtocolConformance.h" +#include "swift/AST/Requirement.h" #include "swift/AST/SourceFile.h" #include "swift/AST/Types.h" #include "swift/Sema/IDETypeChecking.h" @@ -169,13 +170,13 @@ struct SynthesizedExtensionAnalyzer::Implementation { bool Unmergable; unsigned InheritsCount; std::set Requirements; - void addRequirement(GenericSignature GenericSig, - Type First, Type Second, RequirementKind Kind) { - CanType CanFirst = GenericSig->getCanonicalTypeInContext(First); - CanType CanSecond; - if (Second) CanSecond = GenericSig->getCanonicalTypeInContext(Second); + void addRequirement(GenericSignature GenericSig, swift::Requirement Req) { + auto First = Req.getFirstType(); + auto CanFirst = GenericSig->getCanonicalTypeInContext(First); + auto Second = Req.getSecondType(); + auto CanSecond = GenericSig->getCanonicalTypeInContext(Second); - Requirements.insert({First, Second, Kind, CanFirst, CanSecond}); + Requirements.insert({First, Second, Req.getKind(), CanFirst, CanSecond}); } bool operator== (const ExtensionMergeInfo& Another) const { // Trivially unmergeable. @@ -289,84 +290,53 @@ struct SynthesizedExtensionAnalyzer::Implementation { ProtocolDecl *BaseProto = OwningExt->getInnermostDeclContext() ->getSelfProtocolDecl(); for (auto Req : Reqs) { - auto Kind = Req.getKind(); - - // FIXME: Could do something here - if (Kind == RequirementKind::Layout) + // FIXME: Don't skip layout requirements. + if (Req.getKind() == RequirementKind::Layout) continue; - Type First = Req.getFirstType(); - Type Second = Req.getSecondType(); - // Skip protocol's Self : requirement. if (BaseProto && Req.getKind() == RequirementKind::Conformance && - First->isEqual(BaseProto->getSelfInterfaceType()) && - Second->getAnyNominal() == BaseProto) + Req.getFirstType()->isEqual(BaseProto->getSelfInterfaceType()) && + Req.getProtocolDecl() == BaseProto) continue; if (!BaseType->isExistentialType()) { // Apply any substitutions we need to map the requirements from a // a protocol extension to an extension on the conforming type. - First = First.subst(subMap); - Second = Second.subst(subMap); - - if (First->hasError() || Second->hasError()) { + auto SubstReq = Req.subst(subMap); + if (!SubstReq) { // Substitution with interface type bases can only fail // if a concrete type fails to conform to a protocol. // In this case, just give up on the extension altogether. return true; } + + Req = *SubstReq; } - assert(!First->hasArchetype() && !Second->hasArchetype()); - switch (Kind) { - case RequirementKind::Conformance: { - auto *M = DC->getParentModule(); - auto *Proto = Second->castTo()->getDecl(); - if (!First->isTypeParameter() && - M->conformsToProtocol(First, Proto).isInvalid()) - return true; - if (M->conformsToProtocol(First, Proto).isInvalid()) - MergeInfo.addRequirement(GenericSig, First, Second, Kind); - break; - } + assert(!Req.getFirstType()->hasArchetype()); + assert(!Req.getSecondType()->hasArchetype()); - case RequirementKind::Superclass: - // If the subject type of the requirement is still a type parameter, - // we need to check if the contextual type could possibly be bound to - // the superclass. If not, this extension isn't applicable. - if (First->isTypeParameter()) { - if (!Target->mapTypeIntoContext(First)->isBindableTo( - Target->mapTypeIntoContext(Second))) { - return true; - } - MergeInfo.addRequirement(GenericSig, First, Second, Kind); - break; - } + auto *M = DC->getParentModule(); + auto SubstReq = Req.subst( + [&](Type type) -> Type { + if (type->isTypeParameter()) + return Target->mapTypeIntoContext(type); - // If we've substituted in a concrete type for the subject, we can - // check for an exact superclass match, and disregard the extension if - // it missed. - // FIXME: What if it ends being something like `C : C`? - // Arguably we should allow that to be mirrored with a U == Int - // constraint. - if (!Second->isExactSuperclassOf(First)) + return type; + }, + LookUpConformanceInModule(M)); + if (!SubstReq) + return true; + + // FIXME: Need to handle conditional requirements here! + ArrayRef conditionalRequirements; + if (!SubstReq->isSatisfied(conditionalRequirements)) { + if (!SubstReq->canBeSatisfied()) return true; - break; - - case RequirementKind::SameType: - if (!First->isBindableTo(Second) && - !Second->isBindableTo(First)) { - return true; - } else if (!First->isEqual(Second)) { - MergeInfo.addRequirement(GenericSig, First, Second, Kind); - } - break; - - case RequirementKind::Layout: - llvm_unreachable("Handled above"); + MergeInfo.addRequirement(GenericSig, Req); } } return false; diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 073d527b7ff..00951b71d9e 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -101,7 +101,7 @@ Solution::computeSubstitutions(GenericSignature sig, // FIXME: Retrieve the conformance from the solution itself. return TypeChecker::conformsToProtocol(replacement, protoType, - getConstraintSystem().DC); + getConstraintSystem().DC->getParentModule()); }; return SubstitutionMap::get(sig, @@ -620,7 +620,7 @@ namespace { // the protocol requirement with Self == the concrete type, and SILGen // (or later) can devirtualize as appropriate. auto conformance = - TypeChecker::conformsToProtocol(baseTy, proto, cs.DC); + TypeChecker::conformsToProtocol(baseTy, proto, cs.DC->getParentModule()); if (conformance.isConcrete()) { if (auto witness = conformance.getConcrete()->getWitnessDecl(decl)) { bool isMemberOperator = witness->getDeclContext()->isTypeContext(); @@ -2418,7 +2418,7 @@ namespace { auto bridgedToObjectiveCConformance = TypeChecker::conformsToProtocol(valueType, bridgedProto, - cs.DC); + cs.DC->getParentModule()); FuncDecl *fn = nullptr; @@ -2678,7 +2678,7 @@ namespace { ProtocolDecl *protocol = TypeChecker::getProtocol( ctx, expr->getLoc(), KnownProtocolKind::ExpressibleByStringLiteral); - if (!TypeChecker::conformsToProtocol(type, protocol, cs.DC)) { + if (!TypeChecker::conformsToProtocol(type, protocol, cs.DC->getParentModule())) { // If the type does not conform to ExpressibleByStringLiteral, it should // be ExpressibleByExtendedGraphemeClusterLiteral. protocol = TypeChecker::getProtocol( @@ -2687,7 +2687,7 @@ namespace { isStringLiteral = false; isGraphemeClusterLiteral = true; } - if (!TypeChecker::conformsToProtocol(type, protocol, cs.DC)) { + if (!TypeChecker::conformsToProtocol(type, protocol, cs.DC->getParentModule())) { // ... or it should be ExpressibleByUnicodeScalarLiteral. protocol = TypeChecker::getProtocol( cs.getASTContext(), expr->getLoc(), @@ -2802,7 +2802,7 @@ namespace { assert(proto && "Missing string interpolation protocol?"); auto conformance = - TypeChecker::conformsToProtocol(type, proto, cs.DC); + TypeChecker::conformsToProtocol(type, proto, cs.DC->getParentModule()); assert(conformance && "string interpolation type conforms to protocol"); DeclName constrName(ctx, DeclBaseName::createConstructor(), argLabels); @@ -2908,7 +2908,8 @@ namespace { auto proto = TypeChecker::getLiteralProtocol(ctx, expr); assert(proto && "Missing object literal protocol?"); auto conformance = - TypeChecker::conformsToProtocol(conformingType, proto, cs.DC); + TypeChecker::conformsToProtocol(conformingType, proto, + cs.DC->getParentModule()); assert(conformance && "object literal type conforms to protocol"); auto constrName = TypeChecker::getObjectLiteralConstructorName(ctx, expr); @@ -3511,7 +3512,8 @@ namespace { assert(arrayProto && "type-checked array literal w/o protocol?!"); auto conformance = - TypeChecker::conformsToProtocol(arrayTy, arrayProto, cs.DC); + TypeChecker::conformsToProtocol(arrayTy, arrayProto, + cs.DC->getParentModule()); assert(conformance && "Type does not conform to protocol?"); DeclName name(ctx, DeclBaseName::createConstructor(), @@ -3555,7 +3557,8 @@ namespace { KnownProtocolKind::ExpressibleByDictionaryLiteral); auto conformance = - TypeChecker::conformsToProtocol(dictionaryTy, dictionaryProto, cs.DC); + TypeChecker::conformsToProtocol(dictionaryTy, dictionaryProto, + cs.DC->getParentModule()); if (conformance.isInvalid()) return nullptr; @@ -4298,7 +4301,7 @@ namespace { // Special handle for literals conditional checked cast when they can // be statically coerced to the cast type. if (protocol && TypeChecker::conformsToProtocol( - toType, protocol, cs.DC)) { + toType, protocol, cs.DC->getParentModule())) { ctx.Diags .diagnose(expr->getLoc(), diag::literal_conditional_downcast_to_coercion, @@ -5193,7 +5196,8 @@ namespace { // verified by the solver, we just need to get it again // with all of the generic parameters resolved. auto hashableConformance = - TypeChecker::conformsToProtocol(indexType, hashable, cs.DC); + TypeChecker::conformsToProtocol(indexType, hashable, + cs.DC->getParentModule()); assert(hashableConformance); conformances.push_back(hashableConformance); @@ -5503,13 +5507,13 @@ Expr *ExprRewriter::coerceSuperclass(Expr *expr, Type toType) { /// conformances. static ArrayRef collectExistentialConformances(Type fromType, Type toType, - DeclContext *DC) { + ModuleDecl *module) { auto layout = toType->getExistentialLayout(); SmallVector conformances; for (auto proto : layout.getProtocols()) { conformances.push_back(TypeChecker::containsProtocol( - fromType, proto->getDecl(), DC)); + fromType, proto->getDecl(), module)); } return toType->getASTContext().AllocateCopy(conformances); @@ -5532,7 +5536,8 @@ Expr *ExprRewriter::coerceExistential(Expr *expr, Type toType) { ASTContext &ctx = cs.getASTContext(); auto conformances = - collectExistentialConformances(fromInstanceType, toInstanceType, cs.DC); + collectExistentialConformances(fromInstanceType, toInstanceType, + cs.DC->getParentModule()); // For existential-to-existential coercions, open the source existential. if (fromType->isAnyExistentialType()) { @@ -6689,7 +6694,7 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType, auto hashable = ctx.getProtocol(KnownProtocolKind::Hashable); auto conformance = TypeChecker::conformsToProtocol( - cs.getType(expr), hashable, cs.DC); + cs.getType(expr), hashable, cs.DC->getParentModule()); assert(conformance && "must conform to Hashable"); return cs.cacheType( @@ -7359,7 +7364,7 @@ Expr *ExprRewriter::convertLiteralInPlace( // initialize via the builtin protocol. if (builtinProtocol) { auto builtinConformance = TypeChecker::conformsToProtocol( - type, builtinProtocol, cs.DC); + type, builtinProtocol, cs.DC->getParentModule()); if (builtinConformance) { // Find the witness that we'll use to initialize the type via a builtin // literal. @@ -7382,7 +7387,8 @@ Expr *ExprRewriter::convertLiteralInPlace( // This literal type must conform to the (non-builtin) protocol. assert(protocol && "requirements should have stopped recursion"); - auto conformance = TypeChecker::conformsToProtocol(type, protocol, cs.DC); + auto conformance = TypeChecker::conformsToProtocol(type, protocol, + cs.DC->getParentModule()); assert(conformance && "must conform to literal protocol"); // Dig out the literal type and perform a builtin literal conversion to it. @@ -7510,7 +7516,8 @@ ExprRewriter::buildDynamicCallable(ApplyExpr *apply, SelectedOverload selected, auto dictLitProto = ctx.getProtocol(KnownProtocolKind::ExpressibleByDictionaryLiteral); auto conformance = - TypeChecker::conformsToProtocol(argumentType, dictLitProto, cs.DC); + TypeChecker::conformsToProtocol(argumentType, dictLitProto, + cs.DC->getParentModule()); auto keyType = conformance.getTypeWitnessByName(argumentType, ctx.Id_Key); auto valueType = conformance.getTypeWitnessByName(argumentType, ctx.Id_Value); @@ -8448,7 +8455,7 @@ static Optional applySolutionToForEachStmt( stmt->getAwaitLoc().isValid() ? KnownProtocolKind::AsyncSequence : KnownProtocolKind::Sequence); auto sequenceConformance = TypeChecker::conformsToProtocol( - forEachStmtInfo.sequenceType, sequenceProto, cs.DC); + forEachStmtInfo.sequenceType, sequenceProto, cs.DC->getParentModule()); assert(!sequenceConformance.isInvalid() && "Couldn't find sequence conformance"); diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 18c15e9aabd..15a78e2c330 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -859,7 +859,8 @@ bool LiteralRequirement::isCoveredBy(Type type, DeclContext *useDC) const { if (hasDefaultType() && coversDefaultType(type, getDefaultType())) return true; - return bool(TypeChecker::conformsToProtocol(type, getProtocol(), useDC)); + return (bool)TypeChecker::conformsToProtocol(type, getProtocol(), + useDC->getParentModule()); } std::pair diff --git a/lib/Sema/CSDiagnostics.cpp b/lib/Sema/CSDiagnostics.cpp index 1db9a8467bd..3057e82eaf3 100644 --- a/lib/Sema/CSDiagnostics.cpp +++ b/lib/Sema/CSDiagnostics.cpp @@ -165,7 +165,8 @@ Type FailureDiagnostic::restoreGenericParameters( bool FailureDiagnostic::conformsToKnownProtocol( Type type, KnownProtocolKind protocol) const { auto &cs = getConstraintSystem(); - return constraints::conformsToKnownProtocol(cs.DC, type, protocol); + return TypeChecker::conformsToKnownProtocol(type, protocol, + cs.DC->getParentModule()); } Type RequirementFailure::getOwnerType() const { @@ -2073,10 +2074,10 @@ AssignmentFailure::getMemberRef(ConstraintLocator *locator) const { if (!member->choice.isDecl()) return member->choice; - auto *DC = getDC(); auto *decl = member->choice.getDecl(); if (isa(decl) && - isValidDynamicMemberLookupSubscript(cast(decl), DC)) { + isValidDynamicMemberLookupSubscript(cast(decl), + getParentModule())) { auto *subscript = cast(decl); // If this is a keypath dynamic member lookup, we have to // adjust the locator to find member referred by it. @@ -2754,7 +2755,7 @@ bool ContextualFailure::diagnoseThrowsTypeMismatch() const { Ctx.getProtocol(KnownProtocolKind::ErrorCodeProtocol)) { Type errorCodeType = getFromType(); auto conformance = TypeChecker::conformsToProtocol( - errorCodeType, errorCodeProtocol, getDC()); + errorCodeType, errorCodeProtocol, getParentModule()); if (conformance) { Type errorType = conformance @@ -2962,7 +2963,8 @@ bool ContextualFailure::tryProtocolConformanceFixIt( SmallVector missingProtoTypeStrings; SmallVector missingProtocols; for (auto protocol : layout.getProtocols()) { - if (!TypeChecker::conformsToProtocol(fromType, protocol->getDecl(), getDC())) { + if (!TypeChecker::conformsToProtocol(fromType, protocol->getDecl(), + getParentModule())) { missingProtoTypeStrings.push_back(protocol->getString()); missingProtocols.push_back(protocol->getDecl()); } diff --git a/lib/Sema/CSDiagnostics.h b/lib/Sema/CSDiagnostics.h index b6b05484264..93955ee2a17 100644 --- a/lib/Sema/CSDiagnostics.h +++ b/lib/Sema/CSDiagnostics.h @@ -150,6 +150,10 @@ protected: return cs.DC; } + ModuleDecl *getParentModule() const { + return getDC()->getParentModule(); + } + ASTContext &getASTContext() const { auto &cs = getConstraintSystem(); return cs.getASTContext(); diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 47d70b1137c..140e93744cf 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -411,7 +411,7 @@ namespace { if (otherArgTy && otherArgTy->getAnyNominal()) { if (otherArgTy->isEqual(paramTy) && TypeChecker::conformsToProtocol( - otherArgTy, literalProto, CS.DC)) { + otherArgTy, literalProto, CS.DC->getParentModule())) { return true; } } else if (Type defaultType = @@ -1737,13 +1737,15 @@ namespace { if (!contextualType) return false; + auto *M = CS.DC->getParentModule(); + auto type = contextualType->lookThroughAllOptionalTypes(); - if (conformsToKnownProtocol( - CS.DC, type, KnownProtocolKind::ExpressibleByArrayLiteral)) + if (TypeChecker::conformsToKnownProtocol( + type, KnownProtocolKind::ExpressibleByArrayLiteral, M)) return false; - return conformsToKnownProtocol( - CS.DC, type, KnownProtocolKind::ExpressibleByDictionaryLiteral); + return TypeChecker::conformsToKnownProtocol( + type, KnownProtocolKind::ExpressibleByDictionaryLiteral, M); }; if (isDictionaryContextualType(contextualType)) { @@ -4135,31 +4137,6 @@ void ConstraintSystem::optimizeConstraints(Expr *e) { e->walk(optimizer); } -bool swift::areGenericRequirementsSatisfied( - const DeclContext *DC, GenericSignature sig, - SubstitutionMap Substitutions, bool isExtension) { - - ConstraintSystemOptions Options; - ConstraintSystem CS(const_cast(DC), Options); - auto Loc = CS.getConstraintLocator({}); - - // For every requirement, add a constraint. - for (auto Req : sig->getRequirements()) { - if (auto resolved = Req.subst( - QuerySubstitutionMap{Substitutions}, - LookUpConformanceInModule(DC->getParentModule()))) { - CS.addConstraint(*resolved, Loc); - } else if (isExtension) { - return false; - } - // Unresolved requirements are requirements of the function itself. This - // does not prevent it from being applied. E.g. func foo(x: T). - } - - // Having a solution implies the requirements have been fulfilled. - return CS.solveSingle().hasValue(); -} - struct ResolvedMemberResult::Implementation { llvm::SmallVector AllDecls; unsigned ViableStartIdx; diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 7d77073ee15..b73891efa31 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -3469,9 +3469,10 @@ static bool repairArrayLiteralUsedAsDictionary( if (unwrappedDict->isTypeVariableOrMember()) return false; - if (!conformsToKnownProtocol( - cs.DC, unwrappedDict, - KnownProtocolKind::ExpressibleByDictionaryLiteral)) + if (!TypeChecker::conformsToKnownProtocol( + unwrappedDict, + KnownProtocolKind::ExpressibleByDictionaryLiteral, + cs.DC->getParentModule())) return false; // Ignore any attempts at promoting the value to an optional as even after @@ -6108,7 +6109,8 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint( switch (kind) { case ConstraintKind::SelfObjectOfProtocol: { auto conformance = TypeChecker::containsProtocol( - type, protocol, DC, /*skipConditionalRequirements=*/true); + type, protocol, DC->getParentModule(), + /*skipConditionalRequirements=*/true); if (conformance) { return recordConformance(conformance); } @@ -6237,7 +6239,8 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint( if (auto rawValue = isRawRepresentable(*this, type)) { if (!rawValue->isTypeVariableOrMember() && - TypeChecker::conformsToProtocol(rawValue, protocol, DC)) { + TypeChecker::conformsToProtocol(rawValue, protocol, + DC->getParentModule())) { auto *fix = UseRawValue::create(*this, type, protocolTy, loc); return recordFix(fix) ? SolutionKind::Error : SolutionKind::Solved; } @@ -7575,7 +7578,8 @@ performMemberLookup(ConstraintKind constraintKind, DeclNameRef memberName, auto *SD = cast(candidate.getDecl()); bool isKeyPathBased = isValidKeyPathDynamicMemberLookup(SD); - if (isValidStringDynamicMemberLookup(SD, DC) || isKeyPathBased) + if (isValidStringDynamicMemberLookup(SD, DC->getParentModule()) || + isKeyPathBased) result.addViable(OverloadChoice::getDynamicMemberLookup( baseTy, SD, name, isKeyPathBased)); } @@ -10212,7 +10216,6 @@ lookupDynamicCallableMethods(Type type, ConstraintSystem &CS, const ConstraintLocatorBuilder &locator, Identifier argumentName, bool hasKeywordArgs) { auto &ctx = CS.getASTContext(); - auto decl = type->getAnyNominal(); DeclNameRef methodName({ ctx, ctx.Id_dynamicallyCall, { argumentName } }); auto matches = CS.performMemberLookup( ConstraintKind::ValueMember, methodName, type, @@ -10222,7 +10225,8 @@ lookupDynamicCallableMethods(Type type, ConstraintSystem &CS, auto candidates = matches.ViableCandidates; auto filter = [&](OverloadChoice choice) { auto cand = cast(choice.getDecl()); - return !isValidDynamicCallableMethod(cand, decl, hasKeywordArgs); + return !isValidDynamicCallableMethod(cand, CS.DC->getParentModule(), + hasKeywordArgs); }; candidates.erase( std::remove_if(candidates.begin(), candidates.end(), filter), diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index e1961547f52..5e19a1f75f5 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -1890,7 +1890,7 @@ void DisjunctionChoiceProducer::partitionGenericOperators( return refined->inheritsFrom(protocol); return (bool)TypeChecker::conformsToProtocol(nominal->getDeclaredType(), protocol, - nominal->getDeclContext()); + CS.DC->getParentModule()); }; // Gather Numeric and Sequence overloads into separate buckets. @@ -1938,15 +1938,18 @@ void DisjunctionChoiceProducer::partitionGenericOperators( if (argType->isTypeVariableOrMember()) continue; - if (conformsToKnownProtocol(CS.DC, argType, - KnownProtocolKind::AdditiveArithmetic)) { + if (TypeChecker::conformsToKnownProtocol( + argType, KnownProtocolKind::AdditiveArithmetic, + CS.DC->getParentModule())) { first = std::copy(numericOverloads.begin(), numericOverloads.end(), first); numericOverloads.clear(); break; } - if (conformsToKnownProtocol(CS.DC, argType, KnownProtocolKind::Sequence)) { + if (TypeChecker::conformsToKnownProtocol( + argType, KnownProtocolKind::Sequence, + CS.DC->getParentModule())) { first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first); sequenceOverloads.clear(); diff --git a/lib/Sema/CSStep.cpp b/lib/Sema/CSStep.cpp index 88ce76e44ff..14616bda4a8 100644 --- a/lib/Sema/CSStep.cpp +++ b/lib/Sema/CSStep.cpp @@ -538,8 +538,8 @@ StepResult DisjunctionStep::resume(bool prevFailed) { } bool IsDeclRefinementOfRequest::evaluate(Evaluator &evaluator, - ValueDecl *declA, - ValueDecl *declB) const { + ValueDecl *declA, + ValueDecl *declB) const { auto *typeA = declA->getInterfaceType()->getAs(); auto *typeB = declB->getInterfaceType()->getAs(); @@ -578,8 +578,7 @@ bool IsDeclRefinementOfRequest::evaluate(Evaluator &evaluator, return false; auto result = TypeChecker::checkGenericArguments( - declA->getDeclContext(), SourceLoc(), SourceLoc(), typeB, - genericSignatureB->getGenericParams(), + declA->getDeclContext()->getParentModule(), genericSignatureB->getRequirements(), QueryTypeSubstitutionMap{ substMap }); @@ -681,7 +680,8 @@ bool DisjunctionStep::shouldSkip(const DisjunctionChoice &choice) const { continue; for (auto *protocol : signature->getRequiredProtocols(paramType)) { - if (!TypeChecker::conformsToProtocol(argType, protocol, useDC)) + if (!TypeChecker::conformsToProtocol(argType, protocol, + useDC->getParentModule())) return skip("unsatisfied"); } } diff --git a/lib/Sema/CodeSynthesis.cpp b/lib/Sema/CodeSynthesis.cpp index 246665bfe1f..0a79e08c92c 100644 --- a/lib/Sema/CodeSynthesis.cpp +++ b/lib/Sema/CodeSynthesis.cpp @@ -745,8 +745,7 @@ createDesignatedInitOverride(ClassDecl *classDecl, // satisfied by the derived class. In this case, we don't want to inherit // this initializer; there's no way to call it on the derived class. auto checkResult = TypeChecker::checkGenericArguments( - superclassCtor, SourceLoc(), SourceLoc(), Type(), - superclassCtorSig->getGenericParams(), + classDecl->getParentModule(), superclassCtorSig->getRequirements(), [&](Type type) -> Type { auto substType = type.subst(overrideInfo.OverrideSubMap); diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index 9ccbe93e0e6..281f04d2940 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -4347,14 +4347,6 @@ bool constraints::hasAppliedSelf(const OverloadChoice &choice, doesMemberRefApplyCurriedSelf(baseType, decl); } -bool constraints::conformsToKnownProtocol(DeclContext *dc, Type type, - KnownProtocolKind protocol) { - if (auto *proto = - TypeChecker::getProtocol(dc->getASTContext(), SourceLoc(), protocol)) - return (bool)TypeChecker::conformsToProtocol(type, proto, dc); - return false; -} - /// Check whether given type conforms to `RawPepresentable` protocol /// and return the witness type. Type constraints::isRawRepresentable(ConstraintSystem &cs, Type type) { @@ -4365,24 +4357,14 @@ Type constraints::isRawRepresentable(ConstraintSystem &cs, Type type) { if (!rawReprType) return Type(); - auto conformance = TypeChecker::conformsToProtocol(type, rawReprType, DC); + auto conformance = TypeChecker::conformsToProtocol(type, rawReprType, + DC->getParentModule()); if (conformance.isInvalid()) return Type(); return conformance.getTypeWitnessByName(type, cs.getASTContext().Id_RawValue); } -Type constraints::isRawRepresentable( - ConstraintSystem &cs, Type type, - KnownProtocolKind rawRepresentableProtocol) { - Type rawTy = isRawRepresentable(cs, type); - if (!rawTy || - !conformsToKnownProtocol(cs.DC, rawTy, rawRepresentableProtocol)) - return Type(); - - return rawTy; -} - void ConstraintSystem::generateConstraints( SmallVectorImpl &constraints, Type type, ArrayRef choices, DeclContext *useDC, diff --git a/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp b/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp index f530c47b4ea..819d187a6dc 100644 --- a/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp +++ b/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp @@ -77,7 +77,8 @@ bool DerivedConformance::canDeriveAdditiveArithmetic(NominalTypeDecl *nominal, if (v->getInterfaceType()->hasError()) return false; auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType()); - return (bool)TypeChecker::conformsToProtocol(varType, proto, DC); + return (bool)TypeChecker::conformsToProtocol(varType, proto, + DC->getParentModule()); }); } diff --git a/lib/Sema/DerivedConformanceCodable.cpp b/lib/Sema/DerivedConformanceCodable.cpp index 67bd8332f3d..51a24ba9336 100644 --- a/lib/Sema/DerivedConformanceCodable.cpp +++ b/lib/Sema/DerivedConformanceCodable.cpp @@ -249,7 +249,7 @@ static EnumDecl *validateCodingKeysType(const DerivedConformance &derived, // Ensure that the type we found conforms to the CodingKey protocol. auto *codingKeyProto = C.getProtocol(KnownProtocolKind::CodingKey); if (!TypeChecker::conformsToProtocol(codingKeysType, codingKeyProto, - derived.getConformanceContext())) { + derived.getParentModule())) { // If CodingKeys is a typealias which doesn't point to a valid nominal type, // codingKeysTypeDecl will be nullptr here. In that case, we need to warn on // the location of the usage, since there isn't an underlying type to @@ -308,7 +308,7 @@ static bool validateCodingKeysEnum(const DerivedConformance &derived, auto target = derived.getConformanceContext()->mapTypeIntoContext( it->second->getValueInterfaceType()); if (TypeChecker::conformsToProtocol(target, derived.Protocol, - derived.getConformanceContext()) + derived.getParentModule()) .isInvalid()) { TypeLoc typeLoc = { it->second->getTypeReprOrParentPatternTypeRepr(), @@ -1828,7 +1828,8 @@ static bool canSynthesize(DerivedConformance &derived, ValueDecl *requirement) { if (auto *superclassDecl = classDecl->getSuperclassDecl()) { DeclName memberName; auto superType = superclassDecl->getDeclaredInterfaceType(); - if (TypeChecker::conformsToProtocol(superType, proto, superclassDecl)) { + if (TypeChecker::conformsToProtocol(superType, proto, + derived.getParentModule())) { // super.init(from:) must be accessible. memberName = cast(requirement)->getName(); } else { diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index e9e1aa8dd04..9fcd5f9610a 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -84,7 +84,7 @@ getStoredPropertiesForDifferentiation( continue; auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType()); auto conformance = TypeChecker::conformsToProtocol( - varType, diffableProto, nominal); + varType, diffableProto, DC->getParentModule()); if (!conformance) continue; // Skip `let` stored properties with a mutating `move(by:)` if requested. @@ -113,11 +113,12 @@ static StructDecl *convertToStructDecl(ValueDecl *v) { /// for the given interface type and declaration context. static Type getTangentVectorInterfaceType(Type contextualType, DeclContext *DC) { - auto &C = contextualType->getASTContext(); + auto &C = DC->getASTContext(); auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); assert(diffableProto && "`Differentiable` protocol not found"); auto conf = - TypeChecker::conformsToProtocol(contextualType, diffableProto, DC); + TypeChecker::conformsToProtocol(contextualType, diffableProto, + DC->getParentModule()); assert(conf && "Contextual type must conform to `Differentiable`"); if (!conf) return nullptr; @@ -139,7 +140,8 @@ static bool canDeriveTangentVectorAsSelf(NominalTypeDecl *nominal, auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); // `Self` must conform to `AdditiveArithmetic`. - if (!TypeChecker::conformsToProtocol(nominalTypeInContext, addArithProto, DC)) + if (!TypeChecker::conformsToProtocol(nominalTypeInContext, addArithProto, + DC->getParentModule())) return false; for (auto *field : nominal->getStoredProperties()) { // `Self` must not have any `@noDerivative` stored properties. @@ -147,7 +149,8 @@ static bool canDeriveTangentVectorAsSelf(NominalTypeDecl *nominal, return false; // `Self` must have all stored properties satisfy `Self == TangentVector`. auto fieldType = DC->mapTypeIntoContext(field->getValueInterfaceType()); - auto conf = TypeChecker::conformsToProtocol(fieldType, diffableProto, DC); + auto conf = TypeChecker::conformsToProtocol(fieldType, diffableProto, + DC->getParentModule()); if (!conf) return false; auto tangentType = conf.getTypeWitnessByName(fieldType, C.Id_TangentVector); @@ -210,7 +213,8 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, if (v->getInterfaceType()->hasError()) return false; auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType()); - return (bool)TypeChecker::conformsToProtocol(varType, diffableProto, DC); + return (bool)TypeChecker::conformsToProtocol(varType, diffableProto, + DC->getParentModule()); }); } @@ -551,7 +555,8 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context, // Check whether to diagnose stored property. auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType()); auto diffableConformance = - TypeChecker::conformsToProtocol(varType, diffableProto, nominal); + TypeChecker::conformsToProtocol(varType, diffableProto, + DC->getParentModule()); // If stored property should not be diagnosed, continue. if (diffableConformance && canInvokeMoveByOnProperty(vd, diffableConformance)) diff --git a/lib/Sema/DerivedConformanceEquatableHashable.cpp b/lib/Sema/DerivedConformanceEquatableHashable.cpp index ad1929ac3b5..dc8d9db0d2a 100644 --- a/lib/Sema/DerivedConformanceEquatableHashable.cpp +++ b/lib/Sema/DerivedConformanceEquatableHashable.cpp @@ -864,17 +864,16 @@ static ValueDecl *deriveHashable_hashValue(DerivedConformance &derived) { // We can't form a Hashable conformance if Int isn't Hashable or // ExpressibleByIntegerLiteral. - if (TypeChecker::conformsToProtocol( - intType, C.getProtocol(KnownProtocolKind::Hashable), parentDC) - .isInvalid()) { + if (!TypeChecker::conformsToKnownProtocol( + intType, KnownProtocolKind::Hashable, + derived.getParentModule())) { derived.ConformanceDecl->diagnose(diag::broken_int_hashable_conformance); return nullptr; } - ProtocolDecl *intLiteralProto = - C.getProtocol(KnownProtocolKind::ExpressibleByIntegerLiteral); - if (TypeChecker::conformsToProtocol(intType, intLiteralProto, parentDC) - .isInvalid()) { + if (!TypeChecker::conformsToKnownProtocol( + intType, KnownProtocolKind::ExpressibleByIntegerLiteral, + derived.getParentModule())) { derived.ConformanceDecl->diagnose( diag::broken_int_integer_literal_convertible_conformance); return nullptr; diff --git a/lib/Sema/DerivedConformanceRawRepresentable.cpp b/lib/Sema/DerivedConformanceRawRepresentable.cpp index b9d4282709b..694c0d83ceb 100644 --- a/lib/Sema/DerivedConformanceRawRepresentable.cpp +++ b/lib/Sema/DerivedConformanceRawRepresentable.cpp @@ -403,12 +403,9 @@ deriveRawRepresentable_init(DerivedConformance &derived) { assert([&]() -> bool { - auto equatableProto = TypeChecker::getProtocol(C, enumDecl->getLoc(), - KnownProtocolKind::Equatable); - if (!equatableProto) { - return false; - } - return !TypeChecker::conformsToProtocol(rawType, equatableProto, enumDecl).isInvalid(); + return TypeChecker::conformsToKnownProtocol( + rawType, KnownProtocolKind::Equatable, + derived.getParentModule()); }()); auto *rawDecl = new (C) @@ -464,14 +461,8 @@ bool DerivedConformance::canDeriveRawRepresentable(DeclContext *DC, // The raw type must be Equatable, so that we have a suitable ~= for // synthesized switch statements. - auto equatableProto = - TypeChecker::getProtocol(enumDecl->getASTContext(), enumDecl->getLoc(), - KnownProtocolKind::Equatable); - if (!equatableProto) - return false; - - if (TypeChecker::conformsToProtocol(rawType, equatableProto, DC) - .isInvalid()) + if (!TypeChecker::conformsToKnownProtocol(rawType, KnownProtocolKind::Equatable, + DC->getParentModule())) return false; auto &C = type->getASTContext(); diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 371ad8640d2..643c5f8125d 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -40,6 +40,10 @@ DeclContext *DerivedConformance::getConformanceContext() const { return cast(ConformanceDecl); } +ModuleDecl *DerivedConformance::getParentModule() const { + return cast(ConformanceDecl)->getParentModule(); +} + void DerivedConformance::addMembersToConformanceContext( ArrayRef children) { auto IDC = cast(ConformanceDecl); @@ -162,7 +166,7 @@ DerivedConformance::storedPropertiesNotConformingToProtocol( nonconformingProperties.push_back(propertyDecl); if (!TypeChecker::conformsToProtocol(DC->mapTypeIntoContext(type), protocol, - DC)) { + DC->getParentModule())) { nonconformingProperties.push_back(propertyDecl); } } @@ -715,8 +719,8 @@ DeclRefExpr *DerivedConformance::convertEnumToIndex(SmallVectorImpl &st /// \p protocol The protocol being requested. /// \return The ParamDecl of each associated value whose type does not conform. SmallVector -DerivedConformance::associatedValuesNotConformingToProtocol(DeclContext *DC, EnumDecl *theEnum, - ProtocolDecl *protocol) { +DerivedConformance::associatedValuesNotConformingToProtocol( + DeclContext *DC, EnumDecl *theEnum, ProtocolDecl *protocol) { SmallVector nonconformingAssociatedValues; for (auto elt : theEnum->getAllElements()) { auto PL = elt->getParameterList(); @@ -726,7 +730,7 @@ DerivedConformance::associatedValuesNotConformingToProtocol(DeclContext *DC, Enu for (auto param : *PL) { auto type = param->getInterfaceType(); if (TypeChecker::conformsToProtocol(DC->mapTypeIntoContext(type), - protocol, DC) + protocol, DC->getParentModule()) .isInvalid()) { nonconformingAssociatedValues.push_back(param); } diff --git a/lib/Sema/DerivedConformances.h b/lib/Sema/DerivedConformances.h index 4c962925ed9..a52e873f689 100644 --- a/lib/Sema/DerivedConformances.h +++ b/lib/Sema/DerivedConformances.h @@ -62,6 +62,9 @@ public: /// nominal type, or an extension of it) as a \c DeclContext. DeclContext *getConformanceContext() const; + /// Retrieve the module in which the conformance is declared. + ModuleDecl *getParentModule() const; + /// Add \c children as members of the context that declares the conformance. void addMembersToConformanceContext(ArrayRef children); diff --git a/lib/Sema/IDETypeCheckingRequests.cpp b/lib/Sema/IDETypeCheckingRequests.cpp index e83d35ba85f..c8e274f3f78 100644 --- a/lib/Sema/IDETypeCheckingRequests.cpp +++ b/lib/Sema/IDETypeCheckingRequests.cpp @@ -131,10 +131,12 @@ static bool isExtensionAppliedInternal(const DeclContext *DC, Type BaseTy, return true; GenericSignature genericSig = ED->getGenericSignature(); + auto *module = DC->getParentModule(); SubstitutionMap substMap = BaseTy->getContextSubstitutionMap( - DC->getParentModule(), ED->getExtendedNominal()); - return areGenericRequirementsSatisfied(DC, genericSig, substMap, - /*isExtension=*/true); + module, ED->getExtendedNominal()); + return TypeChecker::checkGenericArguments( + module, genericSig->getRequirements(), + QuerySubstitutionMap{substMap}) == RequirementCheckResult::Success; } static bool isMemberDeclAppliedInternal(const DeclContext *DC, Type BaseTy, @@ -155,10 +157,15 @@ static bool isMemberDeclAppliedInternal(const DeclContext *DC, Type BaseTy, if (!genericSig) return true; + auto *module = DC->getParentModule(); SubstitutionMap substMap = BaseTy->getContextSubstitutionMap( - DC->getParentModule(), VD->getDeclContext()); - return areGenericRequirementsSatisfied(DC, genericSig, substMap, - /*isExtension=*/false); + module, VD->getDeclContext()); + + // Note: we treat substitution failure as success, to avoid tripping + // up over generic parameters introduced by the declaration itself. + return TypeChecker::checkGenericArguments( + module, genericSig->getRequirements(), + QuerySubstitutionMap{substMap}) != RequirementCheckResult::Failure; } bool diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp index e223a480275..be09432b034 100644 --- a/lib/Sema/MiscDiagnostics.cpp +++ b/lib/Sema/MiscDiagnostics.cpp @@ -4561,12 +4561,12 @@ static void diagnoseComparisonWithNaN(const Expr *E, const DeclContext *DC) { auto secondArg = BE->getArg()->getElement(1); // Both arguments must conform to FloatingPoint protocol. - if (!conformsToKnownProtocol(const_cast(DC), - firstArg->getType(), - KnownProtocolKind::FloatingPoint) || - !conformsToKnownProtocol(const_cast(DC), - secondArg->getType(), - KnownProtocolKind::FloatingPoint)) { + if (!TypeChecker::conformsToKnownProtocol(firstArg->getType(), + KnownProtocolKind::FloatingPoint, + DC->getParentModule()) || + !TypeChecker::conformsToKnownProtocol(secondArg->getType(), + KnownProtocolKind::FloatingPoint, + DC->getParentModule())) { return; } diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 19a448b9bc3..7c2f841b0c9 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -1229,9 +1229,9 @@ void TypeChecker::checkDeclAttributes(Decl *D) { /// @dynamicCallable attribute requirement. The method is given to be defined /// as one of the following: `dynamicallyCall(withArguments:)` or /// `dynamicallyCall(withKeywordArguments:)`. -bool swift::isValidDynamicCallableMethod(FuncDecl *decl, DeclContext *DC, +bool swift::isValidDynamicCallableMethod(FuncDecl *decl, ModuleDecl *module, bool hasKeywordArguments) { - auto &ctx = decl->getASTContext(); + auto &ctx = module->getASTContext(); // There are two cases to check. // 1. `dynamicallyCall(withArguments:)`. // In this case, the method is valid if the argument has type `A` where @@ -1252,7 +1252,7 @@ bool swift::isValidDynamicCallableMethod(FuncDecl *decl, DeclContext *DC, if (!hasKeywordArguments) { auto arrayLitProto = ctx.getProtocol(KnownProtocolKind::ExpressibleByArrayLiteral); - return (bool)TypeChecker::conformsToProtocol(argType, arrayLitProto, DC); + return (bool)TypeChecker::conformsToProtocol(argType, arrayLitProto, module); } // If keyword arguments, check that argument type conforms to // `ExpressibleByDictionaryLiteral` and that the `Key` associated type @@ -1261,11 +1261,11 @@ bool swift::isValidDynamicCallableMethod(FuncDecl *decl, DeclContext *DC, ctx.getProtocol(KnownProtocolKind::ExpressibleByStringLiteral); auto dictLitProto = ctx.getProtocol(KnownProtocolKind::ExpressibleByDictionaryLiteral); - auto dictConf = TypeChecker::conformsToProtocol(argType, dictLitProto, DC); + auto dictConf = TypeChecker::conformsToProtocol(argType, dictLitProto, module); if (dictConf.isInvalid()) return false; auto keyType = dictConf.getTypeWitnessByName(argType, ctx.Id_Key); - return (bool)TypeChecker::conformsToProtocol(keyType, stringLitProtocol, DC); + return (bool)TypeChecker::conformsToProtocol(keyType, stringLitProtocol, module); } /// Returns true if the given nominal type has a valid implementation of a @@ -1280,9 +1280,10 @@ static bool hasValidDynamicCallableMethod(NominalTypeDecl *decl, if (candidates.empty()) return false; // Filter valid candidates. + auto *module = decl->getParentModule(); candidates.filter([&](LookupResultEntry entry, bool isOuter) { auto candidate = cast(entry.getValueDecl()); - return isValidDynamicCallableMethod(candidate, decl, hasKeywordArgs); + return isValidDynamicCallableMethod(candidate, module, hasKeywordArgs); }); // If there are no valid candidates, return false. @@ -1331,17 +1332,17 @@ static bool hasSingleNonVariadicParam(SubscriptDecl *decl, /// the `subscript(dynamicMember:)` requirement for @dynamicMemberLookup. /// The method is given to be defined as `subscript(dynamicMember:)`. bool swift::isValidDynamicMemberLookupSubscript(SubscriptDecl *decl, - DeclContext *DC, + ModuleDecl *module, bool ignoreLabel) { // It could be // - `subscript(dynamicMember: {Writable}KeyPath<...>)`; or // - `subscript(dynamicMember: String*)` return isValidKeyPathDynamicMemberLookup(decl, ignoreLabel) || - isValidStringDynamicMemberLookup(decl, DC, ignoreLabel); + isValidStringDynamicMemberLookup(decl, module, ignoreLabel); } bool swift::isValidStringDynamicMemberLookup(SubscriptDecl *decl, - DeclContext *DC, + ModuleDecl *module, bool ignoreLabel) { auto &ctx = decl->getASTContext(); // There are two requirements: @@ -1354,11 +1355,9 @@ bool swift::isValidStringDynamicMemberLookup(SubscriptDecl *decl, const auto *param = decl->getIndices()->get(0); auto paramType = param->getType(); - auto stringLitProto = - ctx.getProtocol(KnownProtocolKind::ExpressibleByStringLiteral); - // If this is `subscript(dynamicMember: String*)` - return (bool)TypeChecker::conformsToProtocol(paramType, stringLitProto, DC); + return TypeChecker::conformsToKnownProtocol( + paramType, KnownProtocolKind::ExpressibleByStringLiteral, module); } bool swift::isValidKeyPathDynamicMemberLookup(SubscriptDecl *decl, @@ -1389,6 +1388,8 @@ visitDynamicMemberLookupAttr(DynamicMemberLookupAttr *attr) { auto type = decl->getDeclaredType(); auto &ctx = decl->getASTContext(); + auto *module = decl->getParentModule(); + auto emitInvalidTypeDiagnostic = [&](const SourceLoc loc) { diagnose(loc, diag::invalid_dynamic_member_lookup_type, type); attr->setInvalid(); @@ -1404,7 +1405,7 @@ visitDynamicMemberLookupAttr(DynamicMemberLookupAttr *attr) { auto oneCandidate = candidates.front().getValueDecl(); candidates.filter([&](LookupResultEntry entry, bool isOuter) -> bool { auto cand = cast(entry.getValueDecl()); - return isValidDynamicMemberLookupSubscript(cand, decl); + return isValidDynamicMemberLookupSubscript(cand, module); }); if (candidates.empty()) { @@ -1426,7 +1427,7 @@ visitDynamicMemberLookupAttr(DynamicMemberLookupAttr *attr) { // Validate the candidates while ignoring the label. newCandidates.filter([&](const LookupResultEntry entry, bool isOuter) { auto cand = cast(entry.getValueDecl()); - return isValidDynamicMemberLookupSubscript(cand, decl, + return isValidDynamicMemberLookupSubscript(cand, module, /*ignoreLabel*/ true); }); @@ -1800,8 +1801,9 @@ void AttributeChecker::checkApplicationMainAttribute(DeclAttribute *attr, } if (!ApplicationDelegateProto || - !TypeChecker::conformsToProtocol(CD->getDeclaredType(), - ApplicationDelegateProto, CD)) { + !TypeChecker::conformsToProtocol(CD->getDeclaredInterfaceType(), + ApplicationDelegateProto, + CD->getParentModule())) { diagnose(attr->getLocation(), diag::attr_ApplicationMain_not_ApplicationDelegate, applicationMainKind); @@ -2721,9 +2723,10 @@ bool TypeEraserHasViableInitRequest::evaluate(Evaluator &evaluator, TypeEraserAttr *attr, ProtocolDecl *protocol) const { - auto &ctx = protocol->getASTContext(); - auto &diags = ctx.Diags; DeclContext *dc = protocol->getDeclContext(); + ModuleDecl *module = dc->getParentModule(); + auto &ctx = module->getASTContext(); + auto &diags = ctx.Diags; Type protocolType = protocol->getDeclaredInterfaceType(); // Get the NominalTypeDecl for the type eraser. @@ -2751,7 +2754,7 @@ TypeEraserHasViableInitRequest::evaluate(Evaluator &evaluator, } // The type eraser must conform to the annotated protocol - if (!TypeChecker::conformsToProtocol(typeEraser, protocol, dc)) { + if (!TypeChecker::conformsToProtocol(typeEraser, protocol, module)) { diags.diagnose(attr->getLoc(), diag::type_eraser_does_not_conform, typeEraser, protocolType); diags.diagnose(nominalTypeDecl->getLoc(), diag::type_eraser_declared_here); @@ -2794,26 +2797,21 @@ TypeEraserHasViableInitRequest::evaluate(Evaluator &evaluator, // type conforming to the annotated protocol. We will check this by // substituting the protocol's Self type for the generic arg and check that // the requirements in the generic signature are satisfied. + auto *module = nominalTypeDecl->getParentModule(); auto baseMap = - typeEraser->getContextSubstitutionMap(nominalTypeDecl->getParentModule(), + typeEraser->getContextSubstitutionMap(module, nominalTypeDecl); QuerySubstitutionMap getSubstitution{baseMap}; - auto subMap = SubstitutionMap::get( - genericSignature, - [&](SubstitutableType *type) -> Type { - if (type->isEqual(genericParamType)) - return protocol->getSelfTypeInContext(); - - return getSubstitution(type); - }, - LookUpConformanceInModule(dc->getParentModule())); // Use invalid 'SourceLoc's to suppress diagnostics. auto result = TypeChecker::checkGenericArguments( - dc, SourceLoc(), SourceLoc(), typeEraser, - genericSignature->getGenericParams(), - genericSignature->getRequirements(), - QuerySubstitutionMap{subMap}); + module, genericSignature->getRequirements(), + [&](SubstitutableType *type) -> Type { + if (type->isEqual(genericParamType)) + return protocol->getSelfTypeInContext(); + + return getSubstitution(type); + }); if (result != RequirementCheckResult::Success) { unviable.push_back( @@ -3548,12 +3546,12 @@ SpecializeAttrTargetDeclRequest::evaluate(Evaluator &evaluator, /// Returns true if the given type conforms to `Differentiable` in the given /// context. If `tangentVectorEqualsSelf` is true, also check whether the given /// type satisfies `TangentVector == Self`. -static bool conformsToDifferentiable(Type type, DeclContext *DC, +static bool conformsToDifferentiable(Type type, ModuleDecl *module, bool tangentVectorEqualsSelf = false) { - auto &ctx = type->getASTContext(); + auto &ctx = module->getASTContext(); auto *differentiableProto = ctx.getProtocol(KnownProtocolKind::Differentiable); - auto conf = TypeChecker::conformsToProtocol(type, differentiableProto, DC); + auto conf = TypeChecker::conformsToProtocol(type, differentiableProto, module); if (conf.isInvalid()) return false; if (!tangentVectorEqualsSelf) @@ -3564,7 +3562,8 @@ static bool conformsToDifferentiable(Type type, DeclContext *DC, IndexSubset *TypeChecker::inferDifferentiabilityParameters( AbstractFunctionDecl *AFD, GenericEnvironment *derivativeGenEnv) { - auto &ctx = AFD->getASTContext(); + auto *module = AFD->getParentModule(); + auto &ctx = module->getASTContext(); auto *functionType = AFD->getInterfaceType()->castTo(); auto numUncurriedParams = functionType->getNumParams(); if (auto *resultFnType = @@ -3587,7 +3586,7 @@ IndexSubset *TypeChecker::inferDifferentiabilityParameters( if (paramType->isExistentialType()) return false; // Return true if the type conforms to `Differentiable`. - return conformsToDifferentiable(paramType, AFD); + return conformsToDifferentiable(paramType, module); }; // Get all parameter types. @@ -3620,7 +3619,8 @@ static IndexSubset *computeDifferentiabilityParameters( ArrayRef parsedDiffParams, AbstractFunctionDecl *function, GenericEnvironment *derivativeGenEnv, StringRef attrName, SourceLoc attrLoc) { - auto &ctx = function->getASTContext(); + auto *module = function->getParentModule(); + auto &ctx = module->getASTContext(); auto &diags = ctx.Diags; // Get function type and parameters. @@ -3647,7 +3647,7 @@ static IndexSubset *computeDifferentiabilityParameters( selfType = derivativeGenEnv->mapTypeIntoContext(selfType); else selfType = function->mapTypeIntoContext(selfType); - if (!conformsToDifferentiable(selfType, function)) { + if (!conformsToDifferentiable(selfType, module)) { diags .diagnose(attrLoc, diag::diff_function_no_parameters, function->getName()) @@ -5140,7 +5140,7 @@ static bool checkLinearityParameters( SmallVector linearParams, GenericEnvironment *derivativeGenEnv, ModuleDecl *module, ArrayRef parsedLinearParams, SourceLoc attrLoc) { - auto &ctx = originalAFD->getASTContext(); + auto &ctx = module->getASTContext(); auto &diags = ctx.Diags; // Check that linearity parameters have allowed types. @@ -5156,7 +5156,7 @@ static bool checkLinearityParameters( parsedLinearParams.empty() ? attrLoc : parsedLinearParams[i].getLoc(); // Parameter must conform to `Differentiable` and satisfy // `Self == Self.TangentVector`. - if (!conformsToDifferentiable(linearParamType, originalAFD, + if (!conformsToDifferentiable(linearParamType, module, /*tangentVectorEqualsSelf*/ true)) { diags.diagnose(loc, diag::transpose_attr_invalid_linearity_parameter_or_result, @@ -5203,6 +5203,7 @@ doTransposeStaticAndInstanceSelfTypesMatch(AnyFunctionType *transposeType, void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) { auto *transpose = cast(D); + auto *module = transpose->getParentModule(); auto originalName = attr->getOriginalFunctionName(); auto *transposeInterfaceType = transpose->getInterfaceType()->castTo(); @@ -5268,7 +5269,7 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) { if (expectedOriginalResultType->hasTypeParameter()) expectedOriginalResultType = transpose->mapTypeIntoContext( expectedOriginalResultType); - if (!conformsToDifferentiable(expectedOriginalResultType, transpose, + if (!conformsToDifferentiable(expectedOriginalResultType, module, /*tangentVectorEqualsSelf*/ true)) { diagnoseAndRemoveAttr( attr, diag::transpose_attr_invalid_linearity_parameter_or_result, diff --git a/lib/Sema/TypeCheckAvailability.cpp b/lib/Sema/TypeCheckAvailability.cpp index 1a5bac2aada..ffc6b627409 100644 --- a/lib/Sema/TypeCheckAvailability.cpp +++ b/lib/Sema/TypeCheckAvailability.cpp @@ -3015,16 +3015,11 @@ swift::diagnoseDeclAvailability(const ValueDecl *D, /// Return true if the specified type looks like an integer of floating point /// type. -static bool isIntegerOrFloatingPointType(Type ty, DeclContext *DC, - ASTContext &Context) { - auto integerType = - Context.getProtocol(KnownProtocolKind::ExpressibleByIntegerLiteral); - auto floatingType = - Context.getProtocol(KnownProtocolKind::ExpressibleByFloatLiteral); - if (!integerType || !floatingType) return false; - - return TypeChecker::conformsToProtocol(ty, integerType, DC) || - TypeChecker::conformsToProtocol(ty, floatingType, DC); +static bool isIntegerOrFloatingPointType(Type ty, ModuleDecl *M) { + return (TypeChecker::conformsToKnownProtocol( + ty, KnownProtocolKind::ExpressibleByIntegerLiteral, M) || + TypeChecker::conformsToKnownProtocol( + ty, KnownProtocolKind::ExpressibleByFloatLiteral, M)); } @@ -3054,7 +3049,7 @@ ExprAvailabilityWalker::diagnoseIncDecRemoval(const ValueDecl *D, SourceRange R, // to "lvalue += 1". auto *DC = Where.getDeclContext(); std::string replacement; - if (isIntegerOrFloatingPointType(call->getType(), DC, Context)) + if (isIntegerOrFloatingPointType(call->getType(), DC->getParentModule())) replacement = isInc ? " += 1" : " -= 1"; else { // Otherwise, it must be an index type. Rewrite to: diff --git a/lib/Sema/TypeCheckConcurrency.cpp b/lib/Sema/TypeCheckConcurrency.cpp index bca31ad49d8..a5abd9b1efa 100644 --- a/lib/Sema/TypeCheckConcurrency.cpp +++ b/lib/Sema/TypeCheckConcurrency.cpp @@ -155,6 +155,8 @@ VarDecl *GlobalActorInstanceRequest::evaluate( return nullptr; } + auto *module = nominal->getParentModule(); + // Global actors have a static property "shared" that provides an actor // instance. The value must SmallVector decls; @@ -174,7 +176,7 @@ VarDecl *GlobalActorInstanceRequest::evaluate( cast(varDC)->isConstrainedExtension()) && TypeChecker::conformsToProtocol( varDC->mapTypeIntoContext(var->getValueInterfaceType()), - actorProto, nominal)) { + actorProto, module)) { sharedVar = var; break; } @@ -599,15 +601,14 @@ static bool isSendableClosure( } /// Determine whether the given type is suitable as a concurrent value type. -bool swift::isSendableType(const DeclContext *dc, Type type) { +bool swift::isSendableType(ModuleDecl *module, Type type) { class IsSendable : public TypeVisitor { - DeclContext *dc; + ModuleDecl *module; ProtocolDecl *SendableProto; public: - IsSendable(const DeclContext *dc) - : dc(const_cast(dc)) { - SendableProto = dc->getASTContext().getProtocol( + IsSendable(ModuleDecl *module) : module(module) { + SendableProto = module->getASTContext().getProtocol( KnownProtocolKind::Sendable); } @@ -648,7 +649,7 @@ bool swift::isSendableType(const DeclContext *dc, Type type) { return true; return !TypeChecker::conformsToProtocol( - Type(type), SendableProto, dc).isInvalid(); + Type(type), SendableProto, module).isInvalid(); } bool visitTupleType(TupleType *type) { @@ -681,7 +682,7 @@ bool swift::isSendableType(const DeclContext *dc, Type type) { return true; return !TypeChecker::containsProtocol( - Type(type), SendableProto, dc).isInvalid(); + Type(type), SendableProto, module).isInvalid(); } bool visitBoundGenericType(BoundGenericType *type) { @@ -718,7 +719,7 @@ bool swift::isSendableType(const DeclContext *dc, Type type) { if (!SendableProto) return true; - return !TypeChecker::containsProtocol(type, SendableProto, dc) + return !TypeChecker::containsProtocol(type, SendableProto, module) .isInvalid(); } @@ -729,7 +730,7 @@ bool swift::isSendableType(const DeclContext *dc, Type type) { bool visitInOutType(InOutType *type) { return visit(type->getObjectType()); } - } checker(dc); + } checker(module); return checker.visit(type); } @@ -771,10 +772,10 @@ static bool shouldDiagnoseNonSendableViolations( } bool swift::diagnoseNonConcurrentTypesInReference( - ConcreteDeclRef declRef, const DeclContext *dc, SourceLoc loc, + ConcreteDeclRef declRef, ModuleDecl *module, SourceLoc loc, ConcurrentReferenceKind refKind, DiagnosticBehavior behavior) { // Bail out immediately if we aren't supposed to do this checking. - if (!shouldDiagnoseNonSendableViolations(dc->getASTContext().LangOpts)) + if (!shouldDiagnoseNonSendableViolations(module->getASTContext().LangOpts)) return false; // For functions, check the parameter and result types. @@ -782,7 +783,7 @@ bool swift::diagnoseNonConcurrentTypesInReference( if (auto function = dyn_cast(declRef.getDecl())) { for (auto param : *function->getParameters()) { Type paramType = param->getInterfaceType().subst(subs); - if (!isSendableType(dc, paramType)) { + if (!isSendableType(module, paramType)) { return diagnoseNonConcurrentParameter( loc, refKind, declRef, param, paramType, behavior); } @@ -791,7 +792,7 @@ bool swift::diagnoseNonConcurrentTypesInReference( // Check the result type of a function. if (auto func = dyn_cast(function)) { Type resultType = func->getResultInterfaceType().subst(subs); - if (!isSendableType(dc, resultType)) { + if (!isSendableType(module, resultType)) { return diagnoseNonConcurrentResult(loc, refKind, declRef, resultType, behavior); } @@ -804,7 +805,7 @@ bool swift::diagnoseNonConcurrentTypesInReference( Type propertyType = var->isLocalCapture() ? var->getType() : var->getValueInterfaceType().subst(subs); - if (!isSendableType(dc, propertyType)) { + if (!isSendableType(module, propertyType)) { return diagnoseNonConcurrentProperty(loc, refKind, var, propertyType, behavior); } @@ -813,7 +814,7 @@ bool swift::diagnoseNonConcurrentTypesInReference( if (auto subscript = dyn_cast(declRef.getDecl())) { for (auto param : *subscript->getIndices()) { Type paramType = param->getInterfaceType().subst(subs); - if (!isSendableType(dc, paramType)) { + if (!isSendableType(module, paramType)) { return diagnoseNonConcurrentParameter( loc, refKind, declRef, param, paramType, behavior); } @@ -821,7 +822,7 @@ bool swift::diagnoseNonConcurrentTypesInReference( // Check the element type of a subscript. Type resultType = subscript->getElementInterfaceType().subst(subs); - if (!isSendableType(dc, resultType)) { + if (!isSendableType(module, resultType)) { return diagnoseNonConcurrentResult(loc, refKind, declRef, resultType, behavior); } @@ -918,6 +919,10 @@ namespace { return contextStack.back(); } + ModuleDecl *getParentModule() const { + return getDeclContext()->getParentModule(); + } + /// Determine whether code in the given use context might execute /// concurrently with code in the definition context. bool mayExecuteConcurrentlyWith( @@ -1509,7 +1514,7 @@ namespace { // Check for non-concurrent types. bool problemFound = diagnoseNonConcurrentTypesInReference( - concDeclRef, getDeclContext(), declLoc, + concDeclRef, getDeclContext()->getParentModule(), declLoc, ConcurrentReferenceKind::SynchronousAsAsyncCall); if (problemFound) result = AsyncMarkingResult::NotSendable; @@ -1580,7 +1585,7 @@ namespace { // Check for sendability of the parameter types. for (const auto ¶m : fnType->getParams()) { // FIXME: Dig out the locations of the corresponding arguments. - if (!isSendableType(getDeclContext(), param.getParameterType())) { + if (!isSendableType(getParentModule(), param.getParameterType())) { ctx.Diags.diagnose( apply->getLoc(), diag::non_concurrent_param_type, param.getParameterType()); @@ -1589,7 +1594,7 @@ namespace { } // Check for sendability of the result type. - if (!isSendableType(getDeclContext(), fnType->getResult())) { + if (!isSendableType(getParentModule(), fnType->getResult())) { ctx.Diags.diagnose( apply->getLoc(), diag::non_concurrent_result_type, fnType->getResult()); @@ -1618,7 +1623,7 @@ namespace { // A cross-actor access requires types to be concurrent-safe. if (isCrossActor) { return diagnoseNonConcurrentTypesInReference( - valueRef, getDeclContext(), loc, + valueRef, getParentModule(), loc, ConcurrentReferenceKind::CrossActor); } @@ -1744,7 +1749,7 @@ namespace { (ctx.LangOpts.EnableExperimentalFlowSensitiveConcurrentCaptures && parent.dyn_cast())) { return diagnoseNonConcurrentTypesInReference( - valueRef, getDeclContext(), loc, + valueRef, getParentModule(), loc, ConcurrentReferenceKind::LocalCapture); } @@ -1795,7 +1800,7 @@ namespace { if (varDecl->isLet()) { auto type = component.getComponentType(); if (shouldDiagnoseNonSendableViolations(ctx.LangOpts) - && !isSendableType(getDeclContext(), type)) { + && !isSendableType(getParentModule(), type)) { ctx.Diags.diagnose( component.getLoc(), diag::non_concurrent_keypath_access, type); @@ -1858,7 +1863,7 @@ namespace { if (auto indexExpr = component.getIndexExpr()) { auto type = indexExpr->getType(); if (type && shouldDiagnoseNonSendableViolations(ctx.LangOpts) - && !isSendableType(getDeclContext(), type)) { + && !isSendableType(getParentModule(), type)) { ctx.Diags.diagnose( component.getLoc(), diag::non_concurrent_keypath_capture, indexExpr->getType()); @@ -1969,7 +1974,7 @@ namespace { return false; return diagnoseNonConcurrentTypesInReference( - memberRef, getDeclContext(), memberLoc, + memberRef, getDeclContext()->getParentModule(), memberLoc, ConcurrentReferenceKind::CrossActor); } @@ -2963,7 +2968,7 @@ static bool checkSendableInstanceStorage( } auto propertyType = dc->mapTypeIntoContext(property->getInterfaceType()); - if (!isSendableType(dc, propertyType)) { + if (!isSendableType(dc->getParentModule(), propertyType)) { if (behavior == DiagnosticBehavior::Ignore) return true; property->diagnose(diag::non_concurrent_type_member, @@ -2989,7 +2994,7 @@ static bool checkSendableInstanceStorage( auto elementType = dc->mapTypeIntoContext( element->getArgumentInterfaceType()); - if (!isSendableType(dc, elementType)) { + if (!isSendableType(dc->getParentModule(), elementType)) { if (behavior == DiagnosticBehavior::Ignore) return true; element->diagnose(diag::non_concurrent_type_member, diff --git a/lib/Sema/TypeCheckConcurrency.h b/lib/Sema/TypeCheckConcurrency.h index c4f977f1e7c..81c13f6fd1f 100644 --- a/lib/Sema/TypeCheckConcurrency.h +++ b/lib/Sema/TypeCheckConcurrency.h @@ -198,7 +198,7 @@ bool contextUsesConcurrencyFeatures(const DeclContext *dc); /// domain, including the substitutions so that (e.g.) we can consider the /// specific types at the use site. /// -/// \param dc The declaration context from which the reference occurs. This is +/// \param module The module from which the reference occurs. This is /// used to perform lookup of conformances to the \c Sendable protocol. /// /// \param loc The location at which the reference occurs, which will be @@ -209,7 +209,7 @@ bool contextUsesConcurrencyFeatures(const DeclContext *dc); /// /// \returns true if an problem was detected, false otherwise. bool diagnoseNonConcurrentTypesInReference( - ConcreteDeclRef declRef, const DeclContext *dc, SourceLoc loc, + ConcreteDeclRef declRef, ModuleDecl *module, SourceLoc loc, ConcurrentReferenceKind refKind, DiagnosticBehavior behavior = DiagnosticBehavior::Unspecified); diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index ecf4f08af77..ad90bae5134 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -216,7 +216,7 @@ void ParentConditionalConformance::diagnoseConformanceStack( ArrayRef conformances) { for (auto history : llvm::reverse(conformances)) { diags.diagnose(loc, diag::requirement_implied_by_conditional_conformance, - history.ConformingType, history.Protocol); + history.ConformingType, history.Protocol->getDeclaredInterfaceType()); } } @@ -1325,6 +1325,8 @@ CheckedCastKind TypeChecker::typeCheckCheckedCast(Type fromType, Type origFromType = fromType; Type origToType = toType; + auto *module = dc->getParentModule(); + auto &diags = dc->getASTContext().Diags; bool optionalToOptionalCast = false; @@ -1713,7 +1715,7 @@ CheckedCastKind TypeChecker::typeCheckCheckedCast(Type fromType, // if (auto *protocolDecl = dyn_cast_or_null(fromType->getAnyNominal())) { - if (!couldDynamicallyConformToProtocol(toType, protocolDecl, dc)) { + if (!couldDynamicallyConformToProtocol(toType, protocolDecl, module)) { return failed(); } } else if (auto protocolComposition = @@ -1723,7 +1725,7 @@ CheckedCastKind TypeChecker::typeCheckCheckedCast(Type fromType, if (auto protocolDecl = dyn_cast_or_null( protocolType->getAnyNominal())) { return !couldDynamicallyConformToProtocol( - toType, protocolDecl, dc); + toType, protocolDecl, module); } return false; })) { @@ -1819,7 +1821,7 @@ CheckedCastKind TypeChecker::typeCheckCheckedCast(Type fromType, auto nsErrorTy = Context.getNSErrorType(); if (auto errorTypeProto = Context.getProtocol(KnownProtocolKind::Error)) { - if (!conformsToProtocol(toType, errorTypeProto, dc).isInvalid()) { + if (conformsToProtocol(toType, errorTypeProto, module)) { if (nsErrorTy) { if (isSubtypeOf(fromType, nsErrorTy, dc) // Don't mask "always true" warnings if NSError is cast to @@ -1829,7 +1831,7 @@ CheckedCastKind TypeChecker::typeCheckCheckedCast(Type fromType, } } - if (!conformsToProtocol(fromType, errorTypeProto, dc).isInvalid()) { + if (conformsToProtocol(fromType, errorTypeProto, module)) { // Cast of an error-conforming type to NSError or NSObject. if ((nsObject && toType->isEqual(nsObject)) || (nsErrorTy && toType->isEqual(nsErrorTy))) diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index a254fb9326b..6a1253e830d 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -1102,13 +1102,13 @@ swift::computeAutomaticEnumValueKind(EnumDecl *ED) { if (ED->getGenericEnvironmentOfContext() != nullptr) rawTy = ED->mapTypeIntoContext(rawTy); - + + auto *module = ED->getParentModule(); + // Swift enums require that the raw type is convertible from one of the // primitive literal protocols. auto conformsToProtocol = [&](KnownProtocolKind protoKind) { - ProtocolDecl *proto = ED->getASTContext().getProtocol(protoKind); - return proto && - TypeChecker::conformsToProtocol(rawTy, proto, ED->getDeclContext()); + return TypeChecker::conformsToKnownProtocol(rawTy, protoKind, module); }; static auto otherLiteralProtocolKinds = { diff --git a/lib/Sema/TypeCheckDeclObjC.cpp b/lib/Sema/TypeCheckDeclObjC.cpp index 42ad5f6af6d..2743e6145c3 100644 --- a/lib/Sema/TypeCheckDeclObjC.cpp +++ b/lib/Sema/TypeCheckDeclObjC.cpp @@ -407,8 +407,8 @@ static bool checkObjCActorIsolation(const ValueDecl *VD, case ActorIsolationRestriction::CrossActorSelf: // FIXME: Substitution map? diagnoseNonConcurrentTypesInReference( - const_cast(VD), VD->getDeclContext(), VD->getLoc(), - ConcurrentReferenceKind::CrossActor, behavior); + const_cast(VD), VD->getDeclContext()->getParentModule(), + VD->getLoc(), ConcurrentReferenceKind::CrossActor, behavior); return false; case ActorIsolationRestriction::ActorSelf: // Actor-isolated functions cannot be @objc. diff --git a/lib/Sema/TypeCheckDeclPrimary.cpp b/lib/Sema/TypeCheckDeclPrimary.cpp index 6484a415693..d46556c02a6 100644 --- a/lib/Sema/TypeCheckDeclPrimary.cpp +++ b/lib/Sema/TypeCheckDeclPrimary.cpp @@ -368,10 +368,10 @@ static void checkForEmptyOptionSet(const VarDecl *VD) { return; // Make sure this type conforms to OptionSet - auto *optionSetProto = VD->getASTContext().getProtocol(KnownProtocolKind::OptionSet); - bool conformsToOptionSet = (bool)TypeChecker::conformsToProtocol( - DC->getSelfTypeInContext(), - optionSetProto, DC); + bool conformsToOptionSet = + (bool)TypeChecker::conformsToKnownProtocol(DC->getSelfTypeInContext(), + KnownProtocolKind::OptionSet, + DC->getParentModule()); if (!conformsToOptionSet) return; @@ -1111,13 +1111,13 @@ static Optional buildDefaultInitializerString(DeclContext *dc, // Special-case the various types we might see here. auto type = pattern->getType(); + auto *module = dc->getParentModule(); + // For literal-convertible types, form the corresponding literal. -#define CHECK_LITERAL_PROTOCOL(Kind, String) \ - if (auto proto = TypeChecker::getProtocol( \ - type->getASTContext(), SourceLoc(), KnownProtocolKind::Kind)) { \ - if (TypeChecker::conformsToProtocol(type, proto, dc)) \ - return std::string(String); \ - } +#define CHECK_LITERAL_PROTOCOL(Kind, String) \ + if (TypeChecker::conformsToKnownProtocol(type, KnownProtocolKind::Kind, module)) \ + return std::string(String); + CHECK_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "[]") CHECK_LITERAL_PROTOCOL(ExpressibleByDictionaryLiteral, "[:]") CHECK_LITERAL_PROTOCOL(ExpressibleByUnicodeScalarLiteral, "\"\"") @@ -1236,7 +1236,7 @@ static void diagnoseClassWithoutInitializers(ClassDecl *classDecl) { auto *decodableProto = C.getProtocol(KnownProtocolKind::Decodable); auto superclassType = superclassDecl->getDeclaredInterfaceType(); auto ref = TypeChecker::conformsToProtocol( - superclassType, decodableProto, superclassDecl); + superclassType, decodableProto, classDecl->getParentModule()); if (ref) { // super conforms to Decodable, so we've failed to inherit init(from:). // Let's suggest overriding it here. @@ -1264,7 +1264,7 @@ static void diagnoseClassWithoutInitializers(ClassDecl *classDecl) { // we can produce a slightly different diagnostic to suggest doing so. auto *encodableProto = C.getProtocol(KnownProtocolKind::Encodable); auto ref = TypeChecker::conformsToProtocol( - superclassType, encodableProto, superclassDecl); + superclassType, encodableProto, classDecl->getParentModule()); if (ref) { // We only want to produce this version of the diagnostic if the // subclass doesn't directly implement encode(to:). diff --git a/lib/Sema/TypeCheckGeneric.cpp b/lib/Sema/TypeCheckGeneric.cpp index e7e78e7e15c..0e731650a7b 100644 --- a/lib/Sema/TypeCheckGeneric.cpp +++ b/lib/Sema/TypeCheckGeneric.cpp @@ -774,7 +774,12 @@ RequirementCheckResult TypeChecker::checkGenericArguments( auto req = rawReq; if (current.Parents.empty()) { auto substed = rawReq.subst( - substitutions, + [&](SubstitutableType *type) -> Type { + auto substType = substitutions(type); + if (substType->hasTypeParameter()) + return dc->mapTypeIntoContext(substType); + return substType; + }, LookUpConformanceInModule(module), options); if (!substed) { @@ -786,105 +791,70 @@ RequirementCheckResult TypeChecker::checkGenericArguments( req = *substed; } - auto kind = req.getKind(); - Type rawFirstType = rawReq.getFirstType(); - Type firstType = req.getFirstType(); - if (firstType->hasTypeParameter()) - firstType = dc->mapTypeIntoContext(firstType); + ArrayRef conditionalRequirements; + if (req.isSatisfied(conditionalRequirements)) { + if (!conditionalRequirements.empty()) { + assert(req.getKind() == RequirementKind::Conformance); - Type rawSecondType, secondType; - if (kind != RequirementKind::Layout) { - rawSecondType = rawReq.getSecondType(); - secondType = req.getSecondType(); - if (secondType->hasTypeParameter()) - secondType = dc->mapTypeIntoContext(secondType); - } - - // Don't do further checking on error types. - if (firstType->hasError() || (secondType && secondType->hasError())) { - // Another requirement will fail later; just continue. - valid = false; + auto history = current.Parents; + history.push_back({req.getFirstType(), req.getProtocolDecl()}); + pendingReqs.push_back({conditionalRequirements, std::move(history)}); + } continue; } - bool requirementFailure = false; - - Diag diagnostic; - Diag diagnosticNote; - - switch (kind) { - case RequirementKind::Conformance: { - // Protocol conformance requirements. - auto proto = secondType->castTo(); - auto conformance = module->lookupConformance(firstType, proto->getDecl()); - - if (conformance) { - auto conditionalReqs = conformance.getConditionalRequirements(); - if (!conditionalReqs.empty()) { - auto history = current.Parents; - history.push_back({firstType, proto}); - pendingReqs.push_back({conditionalReqs, std::move(history)}); - } - continue; - } - - if (loc.isValid()) - diagnoseConformanceFailure(firstType, proto->getDecl(), module, loc); - - if (current.Parents.empty()) - return RequirementCheckResult::Failure; - - // Failure needs to emit a diagnostic. - diagnostic = diag::type_does_not_conform_owner; - diagnosticNote = diag::type_does_not_inherit_or_conform_requirement; - requirementFailure = true; - break; - } - - case RequirementKind::Layout: - // TODO: Statically check other layout constraints, once they can - // be spelled in Swift. - if (req.getLayoutConstraint()->isClass() && - !firstType->satisfiesClassConstraint()) { - diagnostic = diag::type_is_not_a_class; - diagnosticNote = diag::anyobject_requirement; - requirementFailure = true; - } - break; - - case RequirementKind::Superclass: { - // Superclass requirements. - if (!secondType->isExactSuperclassOf(firstType)) { - diagnostic = diag::type_does_not_inherit; - diagnosticNote = diag::type_does_not_inherit_or_conform_requirement; - requirementFailure = true; - } - break; - } - - case RequirementKind::SameType: - if (!firstType->isEqual(secondType)) { - diagnostic = diag::types_not_equal; - diagnosticNote = diag::types_not_equal_requirement; - requirementFailure = true; - } - break; - } - - if (!requirementFailure) - continue; - if (loc.isValid()) { + Diag diagnostic; + Diag diagnosticNote; + + switch (req.getKind()) { + case RequirementKind::Conformance: { + diagnoseConformanceFailure(req.getFirstType(), req.getProtocolDecl(), + module, loc); + + if (current.Parents.empty()) + return RequirementCheckResult::Failure; + + diagnostic = diag::type_does_not_conform_owner; + diagnosticNote = diag::type_does_not_inherit_or_conform_requirement; + break; + } + + case RequirementKind::Layout: + diagnostic = diag::type_is_not_a_class; + diagnosticNote = diag::anyobject_requirement; + break; + + case RequirementKind::Superclass: + diagnostic = diag::type_does_not_inherit; + diagnosticNote = diag::type_does_not_inherit_or_conform_requirement; + break; + + case RequirementKind::SameType: + diagnostic = diag::types_not_equal; + diagnosticNote = diag::types_not_equal_requirement; + break; + } + + Type rawSecondType, secondType; + if (req.getKind() != RequirementKind::Layout) { + rawSecondType = rawReq.getSecondType(); + secondType = req.getSecondType(); + } + // FIXME: Poor source-location information. - ctx.Diags.diagnose(loc, diagnostic, owner, firstType, secondType); + ctx.Diags.diagnose(loc, diagnostic, owner, + req.getFirstType(), secondType); std::string genericParamBindingsText; if (!genericParams.empty()) { genericParamBindingsText = gatherGenericParamBindingsText( - {rawFirstType, rawSecondType}, genericParams, substitutions); + {rawReq.getFirstType(), rawSecondType}, + genericParams, substitutions); } - ctx.Diags.diagnose(noteLoc, diagnosticNote, rawFirstType, rawSecondType, + ctx.Diags.diagnose(noteLoc, diagnosticNote, + rawReq.getFirstType(), rawSecondType, genericParamBindingsText); ParentConditionalConformance::diagnoseConformanceStack( @@ -900,6 +870,37 @@ RequirementCheckResult TypeChecker::checkGenericArguments( return RequirementCheckResult::SubstitutionFailure; } +RequirementCheckResult +TypeChecker::checkGenericArguments(ModuleDecl *module, + ArrayRef requirements, + TypeSubstitutionFn substitutions) { + SmallVector worklist; + bool valid = true; + + for (auto req : requirements) { + if (auto resolved = req.subst(substitutions, + LookUpConformanceInModule(module))) { + worklist.push_back(*resolved); + } else { + valid = false; + } + } + + while (!worklist.empty()) { + auto req = worklist.pop_back_val(); + ArrayRef conditionalRequirements; + if (!req.isSatisfied(conditionalRequirements)) + return RequirementCheckResult::Failure; + + worklist.append(conditionalRequirements.begin(), + conditionalRequirements.end()); + } + + if (valid) + return RequirementCheckResult::Success; + return RequirementCheckResult::SubstitutionFailure; +} + Requirement RequirementRequest::evaluate(Evaluator &evaluator, WhereClauseOwner owner, diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 8ed008d8d04..97e19f6ad9d 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -2798,7 +2798,8 @@ bool ConformanceChecker::checkActorIsolation( case ActorIsolationRestriction::CrossActorSelf: return diagnoseNonConcurrentTypesInReference( - witness, DC, witness->getLoc(), ConcurrentReferenceKind::CrossActor); + witness, DC->getParentModule(), witness->getLoc(), + ConcurrentReferenceKind::CrossActor); case ActorIsolationRestriction::GlobalActorUnsafe: witnessIsUnsafe = true; @@ -2863,7 +2864,8 @@ bool ConformanceChecker::checkActorIsolation( return false; return diagnoseNonConcurrentTypesInReference( - witness, DC, witness->getLoc(), ConcurrentReferenceKind::CrossActor); + witness, DC->getParentModule(), witness->getLoc(), + ConcurrentReferenceKind::CrossActor); } // If the witness has a global actor but the requirement does not, we have @@ -2897,7 +2899,8 @@ bool ConformanceChecker::checkActorIsolation( if (isCrossActor) { return diagnoseNonConcurrentTypesInReference( - witness, DC, witness->getLoc(), ConcurrentReferenceKind::CrossActor); + witness, DC->getParentModule(), witness->getLoc(), + ConcurrentReferenceKind::CrossActor); } witness->diagnose( @@ -3898,7 +3901,8 @@ ConformanceChecker::resolveWitnessViaLookup(ValueDecl *requirement) { // a member that could in turn satisfy *this* requirement. auto derivableProto = cast(derivable->getDeclContext()); auto conformance = - TypeChecker::conformsToProtocol(Adoptee, derivableProto, DC); + TypeChecker::conformsToProtocol(Adoptee, derivableProto, + DC->getParentModule()); if (conformance.isConcrete()) { (void)conformance.getConcrete()->getWitnessDecl(derivable); } @@ -4918,7 +4922,7 @@ void swift::diagnoseConformanceFailure(Type T, // If we're checking conformance of an existential type to a protocol, // do a little bit of extra work to produce a better diagnostic. if (T->isExistentialType() && - TypeChecker::containsProtocol(T, Proto, DC)) { + TypeChecker::containsProtocol(T, Proto, DC->getParentModule())) { if (!T->isObjCExistentialType()) { diags.diagnose(ComplainLoc, diag::type_cannot_conform, true, @@ -4956,12 +4960,8 @@ void swift::diagnoseConformanceFailure(Type T, // If the reason is that the raw type does not conform to // Equatable, say so. - auto equatableProto = ctx.getProtocol(KnownProtocolKind::Equatable); - if (!equatableProto) - return; - - if (TypeChecker::conformsToProtocol(rawType, equatableProto, enumDecl) - .isInvalid()) { + if (!TypeChecker::conformsToKnownProtocol( + rawType, KnownProtocolKind::Equatable, DC->getParentModule())) { SourceLoc loc = enumDecl->getInherited()[0].getSourceRange().Start; diags.diagnose(loc, diag::enum_raw_type_not_equatable, rawType); return; @@ -5051,7 +5051,7 @@ void ConformanceChecker::emitDelayedDiags() { } ProtocolConformanceRef -TypeChecker::containsProtocol(Type T, ProtocolDecl *Proto, DeclContext *DC, +TypeChecker::containsProtocol(Type T, ProtocolDecl *Proto, ModuleDecl *M, bool skipConditionalRequirements) { // Existential types don't need to conform, i.e., they only need to // contain the protocol. @@ -5060,7 +5060,7 @@ TypeChecker::containsProtocol(Type T, ProtocolDecl *Proto, DeclContext *DC, // *and* has a witness table. if (T->isEqual(Proto->getDeclaredInterfaceType()) && Proto->requiresSelfConformanceWitnessTable()) { - auto &ctx = DC->getASTContext(); + auto &ctx = M->getASTContext(); return ProtocolConformanceRef(ctx.getSelfConformance(Proto)); } @@ -5071,8 +5071,8 @@ TypeChecker::containsProtocol(Type T, ProtocolDecl *Proto, DeclContext *DC, if (auto superclass = layout.getSuperclass()) { auto result = (skipConditionalRequirements - ? DC->getParentModule()->lookupConformance(superclass, Proto) - : TypeChecker::conformsToProtocol(superclass, Proto, DC)); + ? M->lookupConformance(superclass, Proto) + : TypeChecker::conformsToProtocol(superclass, Proto, M)); if (result) { return result; } @@ -5097,14 +5097,13 @@ TypeChecker::containsProtocol(Type T, ProtocolDecl *Proto, DeclContext *DC, // For non-existential types, this is equivalent to checking conformance. return (skipConditionalRequirements - ? DC->getParentModule()->lookupConformance(T, Proto) - : TypeChecker::conformsToProtocol(T, Proto, DC)); + ? M->lookupConformance(T, Proto) + : TypeChecker::conformsToProtocol(T, Proto, M)); } ProtocolConformanceRef -TypeChecker::conformsToProtocol(Type T, ProtocolDecl *Proto, DeclContext *DC) { +TypeChecker::conformsToProtocol(Type T, ProtocolDecl *Proto, ModuleDecl *M) { // Look up conformance in the module. - ModuleDecl *M = DC->getParentModule(); auto lookupResult = M->lookupConformance(T, Proto); if (lookupResult.isInvalid()) { return ProtocolConformanceRef::forInvalid(); @@ -5115,12 +5114,10 @@ TypeChecker::conformsToProtocol(Type T, ProtocolDecl *Proto, DeclContext *DC) { "unhandled recursion: missing conditional requirements when they're " "required"); - // If we have a conditional requirements that - // we need to check, do so now. + // If we have a conditional requirements that we need to check, do so now. if (!condReqs->empty()) { auto conditionalCheckResult = checkGenericArguments( - DC, SourceLoc(), SourceLoc(), T, - {lookupResult.getRequirement()->getSelfInterfaceType()}, *condReqs, + M, *condReqs, [](SubstitutableType *dependentType) { return Type(dependentType); }); switch (conditionalCheckResult) { case RequirementCheckResult::Success: @@ -5135,9 +5132,17 @@ TypeChecker::conformsToProtocol(Type T, ProtocolDecl *Proto, DeclContext *DC) { return lookupResult; } +bool TypeChecker::conformsToKnownProtocol(Type type, KnownProtocolKind protocol, + ModuleDecl *module) { + if (auto *proto = + TypeChecker::getProtocol(module->getASTContext(), SourceLoc(), protocol)) + return (bool)TypeChecker::conformsToProtocol(type, proto, module); + return false; +} + bool TypeChecker::couldDynamicallyConformToProtocol(Type type, ProtocolDecl *Proto, - DeclContext *DC) { + ModuleDecl *M) { // An existential may have a concrete underlying type with protocol conformances // we cannot know statically. if (type->isExistentialType()) @@ -5154,7 +5159,6 @@ TypeChecker::couldDynamicallyConformToProtocol(Type type, ProtocolDecl *Proto, return true; } - ModuleDecl *M = DC->getParentModule(); // For standard library collection types such as Array, Set or Dictionary // which have custom casting machinery implemented for situations like: // @@ -5166,7 +5170,7 @@ TypeChecker::couldDynamicallyConformToProtocol(Type type, ProtocolDecl *Proto, // are met or not. if (type->isKnownStdlibCollectionType()) return !M->lookupConformance(type, Proto).isInvalid(); - return !conformsToProtocol(type, Proto, DC).isInvalid(); + return !conformsToProtocol(type, Proto, M).isInvalid(); } /// Exposes TypeChecker functionality for querying protocol conformance. @@ -6453,12 +6457,14 @@ void TypeChecker::inferDefaultWitnesses(ProtocolDecl *proto) { if (!assocType) continue; + auto *module = proto->getParentModule(); + // Find the associated type nearest our own protocol, which might have // a default not available in the associated type referenced by the // (canonicalized) requirement. if (assocType->getProtocol() != proto) { SmallVector found; - proto->getModuleContext()->lookupQualified( + module->lookupQualified( proto, DeclNameRef(assocType->getName()), NL_QualifiedDefault|NL_ProtocolMembers|NL_OnlyTypes, found); @@ -6478,7 +6484,7 @@ void TypeChecker::inferDefaultWitnesses(ProtocolDecl *proto) { proto->mapTypeIntoContext(defaultAssocType); auto requirementProto = req.getProtocolDecl(); auto conformance = conformsToProtocol(defaultAssocTypeInContext, - requirementProto, proto); + requirementProto, module); if (conformance.isInvalid()) { // Diagnose the lack of a conformance. This is potentially an ABI // incompatibility. diff --git a/lib/Sema/TypeCheckStorage.cpp b/lib/Sema/TypeCheckStorage.cpp index 7332c84ac16..e6690f9eaee 100644 --- a/lib/Sema/TypeCheckStorage.cpp +++ b/lib/Sema/TypeCheckStorage.cpp @@ -978,7 +978,8 @@ static ProtocolConformanceRef checkConformanceToNSCopying(VarDecl *var, auto proto = ctx.getNSCopyingDecl(); if (proto) { - if (auto result = TypeChecker::conformsToProtocol(type, proto, dc)) + if (auto result = TypeChecker::conformsToProtocol(type, proto, + dc->getParentModule())) return result; } diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index 162c2caf951..633a7c16b55 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -3197,7 +3197,8 @@ NeverNullType TypeResolver::resolveSILFunctionType( } witnessMethodConformance = TypeChecker::conformsToProtocol( - selfType, protocolType->getDecl(), getDeclContext()); + selfType, protocolType->getDecl(), + getDeclContext()->getParentModule()); assert(witnessMethodConformance && "found witness_method without matching conformance"); } diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 07bfa8adb63..a4653095c59 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -188,7 +188,7 @@ enum class Comparison { /// formatted with \c diagnoseConformanceStack. struct ParentConditionalConformance { Type ConformingType; - ProtocolType *Protocol; + ProtocolDecl *Protocol; /// Format the stack \c conformances as a series of notes that trace a path of /// conditional conformances that lead to some other failing requirement (that @@ -498,6 +498,12 @@ RequirementCheckResult checkGenericArguments( ArrayRef requirements, TypeSubstitutionFn substitutions, SubstOptions options = None); +/// A lower-level version of the above without diagnostic emission. +RequirementCheckResult checkGenericArguments( + ModuleDecl *module, + ArrayRef requirements, + TypeSubstitutionFn substitutions); + bool checkContextualRequirements(GenericTypeDecl *decl, Type parentTy, SourceLoc loc, @@ -710,13 +716,10 @@ Expr *addImplicitLoadExpr( /// Determine whether the given type contains the given protocol. /// -/// \param DC The context in which to check conformance. This affects, for -/// example, extension visibility. -/// /// \returns the conformance, if \c T conforms to the protocol \c Proto, or /// an empty optional. ProtocolConformanceRef containsProtocol(Type T, ProtocolDecl *Proto, - DeclContext *DC, + ModuleDecl *M, bool skipConditionalRequirements=false); /// Determine whether the given type conforms to the given protocol. @@ -724,25 +727,22 @@ ProtocolConformanceRef containsProtocol(Type T, ProtocolDecl *Proto, /// Unlike subTypeOfProtocol(), this will return false for existentials of /// non-self conforming protocols. /// -/// \param DC The context in which to check conformance. This affects, for -/// example, extension visibility. -/// /// \returns The protocol conformance, if \c T conforms to the /// protocol \c Proto, or \c None. ProtocolConformanceRef conformsToProtocol(Type T, ProtocolDecl *Proto, - DeclContext *DC); + ModuleDecl *M); + +/// Check whether the type conforms to a given known protocol. +bool conformsToKnownProtocol(Type type, KnownProtocolKind protocol, + ModuleDecl *module); /// This is similar to \c conformsToProtocol, but returns \c true for cases where /// the type \p T could be dynamically cast to \p Proto protocol, such as a non-final /// class where a subclass conforms to \p Proto. /// -/// \param DC The context in which to check conformance. This affects, for -/// example, extension visibility. -/// -/// /// \returns True if \p T conforms to the protocol \p Proto, false otherwise. bool couldDynamicallyConformToProtocol(Type T, ProtocolDecl *Proto, - DeclContext *DC); + ModuleDecl *M); /// Completely check the given conformance. void checkConformance(NormalProtocolConformance *conformance); @@ -1181,13 +1181,13 @@ diag::RequirementKind getProtocolRequirementKind(ValueDecl *Requirement); /// @dynamicCallable attribute requirement. The method is given to be defined /// as one of the following: `dynamicallyCall(withArguments:)` or /// `dynamicallyCall(withKeywordArguments:)`. -bool isValidDynamicCallableMethod(FuncDecl *decl, DeclContext *DC, +bool isValidDynamicCallableMethod(FuncDecl *decl, ModuleDecl *module, bool hasKeywordArguments); /// Returns true if the given subscript method is an valid implementation of /// the `subscript(dynamicMember:)` requirement for @dynamicMemberLookup. /// The method is given to be defined as `subscript(dynamicMember:)`. -bool isValidDynamicMemberLookupSubscript(SubscriptDecl *decl, DeclContext *DC, +bool isValidDynamicMemberLookupSubscript(SubscriptDecl *decl, ModuleDecl *module, bool ignoreLabel = false); /// Returns true if the given subscript method is an valid implementation of @@ -1195,7 +1195,7 @@ bool isValidDynamicMemberLookupSubscript(SubscriptDecl *decl, DeclContext *DC, /// The method is given to be defined as `subscript(dynamicMember:)` which /// takes a single non-variadic parameter that conforms to /// `ExpressibleByStringLiteral` protocol. -bool isValidStringDynamicMemberLookup(SubscriptDecl *decl, DeclContext *DC, +bool isValidStringDynamicMemberLookup(SubscriptDecl *decl, ModuleDecl *module, bool ignoreLabel = false); /// Returns true if the given subscript method is an valid implementation of @@ -1285,11 +1285,6 @@ bool diagnoseObjCUnsatisfiedOptReqConflicts(SourceFile &sf); std::pair getObjCMethodDiagInfo( AbstractFunctionDecl *method); -bool areGenericRequirementsSatisfied(const DeclContext *DC, - GenericSignature sig, - SubstitutionMap Substitutions, - bool isExtension); - /// Check for restrictions on the use of the @unknown attribute on a /// case statement. void checkUnknownAttrRestrictions( diff --git a/test/IDE/print_synthesized_extensions.swift b/test/IDE/print_synthesized_extensions.swift index f1786c2396e..01585a220a8 100644 --- a/test/IDE/print_synthesized_extensions.swift +++ b/test/IDE/print_synthesized_extensions.swift @@ -393,7 +393,7 @@ extension C : P8 {} public class F {} extension F : P8 {} -// CHECK16: extension F where T : D { +// CHECK16: public class F<T> where T : print_synthesized_extensions.D { // CHECK16-NEXT: public func bar() // CHECK16-NEXT: } diff --git a/test/IDE/print_synthesized_extensions_superclass.swift b/test/IDE/print_synthesized_extensions_superclass.swift new file mode 100644 index 00000000000..e83de3f9e84 --- /dev/null +++ b/test/IDE/print_synthesized_extensions_superclass.swift @@ -0,0 +1,132 @@ +// RUN: %empty-directory(%t) +// RUN: %target-swift-frontend -emit-module-path %t/print_synthesized_extensions_superclass.swiftmodule -emit-module-doc -emit-module-doc-path %t/print_synthesized_extensions_superclass.swiftdoc %s +// RUN: %target-swift-ide-test -print-module -synthesize-extension -print-interface -no-empty-line-between-members -module-to-print=print_synthesized_extensions_superclass -I %t -source-filename=%s | %FileCheck %s + +public class Base {} +public class Middle : Base {} +public class Most : Middle {} + +public protocol P { + associatedtype T + associatedtype U +} + +public extension P where T : Base { + func withBase() {} +} + +public extension P where T : Middle { + func withMiddleAbstract() {} +} + +public extension P where T : Middle { + func withMiddleConcrete() {} +} + +public extension P where T : Most { + func withMost() {} +} + +// CHECK-LABEL: public struct S1 : print_synthesized_extensions_superclass.P { +// CHECK-NEXT: public typealias T = print_synthesized_extensions_superclass.Base +// CHECK-NEXT: public typealias U = Int +// CHECK-NEXT: public func withBase() +// CHECk-NEXT: } + +public struct S1 : P { + public typealias T = Base + public typealias U = Int +} + +// CHECK-LABEL: public struct S2 : print_synthesized_extensions_superclass.P { +// CHECK-NEXT: public typealias T = print_synthesized_extensions_superclass.Middle +// CHECK-NEXT: public typealias U = Int +// CHECK-NEXT: public func withBase() +// CHECk-NEXT: public func withMiddleAbstract() +// CHECk-NEXT: public func withMiddleConcrete() +// CHECk-NEXT: } + +public struct S2 : P { + public typealias T = Middle + public typealias U = Int +} + +// CHECK-LABEL: public struct S3 : print_synthesized_extensions_superclass.P { +// CHECK-NEXT: public typealias T = print_synthesized_extensions_superclass.Middle +// CHECK-NEXT: public typealias U = String +// CHECK-NEXT: public func withBase() +// CHECK-NEXT: public func withMiddleAbstract() +// CHECK-NEXT: } + +public struct S3 : P { + public typealias T = Middle + public typealias U = String +} + +// CHECK-LABEL: public struct S4 : print_synthesized_extensions_superclass.P { +// CHECK-NEXT: public typealias T = print_synthesized_extensions_superclass.Most +// CHECK-NEXT: public typealias U = Int +// CHECK-NEXT: public func withBase() +// CHECK-NEXT: public func withMiddleAbstract() +// CHECK-NEXT: public func withMiddleConcrete() +// CHECK-NEXT: public func withMost() +// CHECK-NEXT: } + +public struct S4 : P { + public typealias T = Most + public typealias U = Int +} + +// CHECK-LABEL: public struct S5 : print_synthesized_extensions_superclass.P { +// CHECK-NEXT: public typealias T = print_synthesized_extensions_superclass.Most +// CHECK-NEXT: public typealias U = String +// CHECK-NEXT: public func withBase() +// CHECK-NEXT: public func withMiddleConcrete() +// CHECK-NEXT: public func withMost() +// CHECK-NEXT: } + +public struct S5 : P { + public typealias T = Most + public typealias U = String +} + +// CHECK-LABEL: public struct S6 : print_synthesized_extensions_superclass.P where T : print_synthesized_extensions_superclass.Base { +// CHECK-NEXT: public func withBase() +// CHECK-NEXT: } + +// CHECK-LABEL: extension S6 where T : Middle { +// CHECK-NEXT: public func withMiddleAbstract() +// CHECK-NEXT: } + +// CHECK-LABEL: extension S6 where T : Middle { +// CHECK-NEXT: public func withMiddleConcrete() +// CHECK-NEXT: } + +// CHECK-LABEL: extension S6 where T : Most { +// CHECK-NEXT: public func withMost() +// CHECK-NEXT: } + +public struct S6 : P where T : Base {} + +// CHECK-LABEL: public struct S7 : print_synthesized_extensions_superclass.P where T : print_synthesized_extensions_superclass.Middle { +// CHECK-NEXT: public func withBase() +// CHECK-NEXT: public func withMiddleAbstract() +// CHECK-NEXT: } + +// CHECK-LABEL: extension S7 where T : Middle { +// CHECK-NEXT: public func withMiddleConcrete() +// CHECK-NEXT: } + +// CHECK-LABEL: extension S7 where T : Most { +// CHECK-NEXT: public func withMost() +// CHECK-NEXT: } + +public struct S7 : P where T : Middle {} + +// CHECK-LABEL: public struct S8 : print_synthesized_extensions_superclass.P where T : print_synthesized_extensions_superclass.Most { +// CHECK-NEXT: public func withBase() +// CHECK-NEXT: public func withMiddleConcrete() +// CHECK-NEXT: public func withMost() +// CHECK-NEXT: } + +public struct S8 : P where T : Most {} \ No newline at end of file diff --git a/test/SourceKit/DocSupport/doc_swift_module_class_extension.swift.response b/test/SourceKit/DocSupport/doc_swift_module_class_extension.swift.response index a8a3a3056cc..4935b7174dc 100644 --- a/test/SourceKit/DocSupport/doc_swift_module_class_extension.swift.response +++ b/test/SourceKit/DocSupport/doc_swift_module_class_extension.swift.response @@ -26,16 +26,13 @@ class E { } class F where T : module_with_class_extension.D { + + func bar() } extension F : module_with_class_extension.P8 { } -extension F where T : D { - - func bar() -} - protocol P8 { associatedtype T @@ -270,193 +267,162 @@ extension P8 where Self.T : module_with_class_extension.E { }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 328, + key.offset: 330, + key.length: 4 + }, + { + key.kind: source.lang.swift.syntaxtype.identifier, + key.offset: 335, + key.length: 3 + }, + { + key.kind: source.lang.swift.syntaxtype.keyword, + key.offset: 344, key.length: 9 }, { key.kind: source.lang.swift.ref.class, key.name: "F", key.usr: "s:27module_with_class_extension1FC", - key.offset: 338, + key.offset: 354, key.length: 1 }, { key.kind: source.lang.swift.syntaxtype.typeidentifier, - key.offset: 342, + key.offset: 358, key.length: 27 }, { key.kind: source.lang.swift.ref.protocol, key.name: "P8", key.usr: "s:27module_with_class_extension2P8P", - key.offset: 370, + key.offset: 386, key.length: 2 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 378, - key.length: 9 - }, - { - key.kind: source.lang.swift.ref.class, - key.name: "F", - key.usr: "s:27module_with_class_extension1FC", - key.offset: 388, - key.length: 1 - }, - { - key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 390, - key.length: 5 - }, - { - key.kind: source.lang.swift.ref.generic_type_param, - key.name: "T", - key.usr: "s:27module_with_class_extension1FC1Txmfp", - key.offset: 396, - key.length: 1 - }, - { - key.kind: source.lang.swift.ref.class, - key.name: "D", - key.usr: "s:27module_with_class_extension1DC", - key.offset: 400, - key.length: 1 - }, - { - key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 409, - key.length: 4 - }, - { - key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 414, - key.length: 3 - }, - { - key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 423, + key.offset: 394, key.length: 8 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 432, + key.offset: 403, key.length: 2 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 442, + key.offset: 413, key.length: 14 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 457, + key.offset: 428, key.length: 1 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 462, + key.offset: 433, key.length: 9 }, { key.kind: source.lang.swift.ref.protocol, key.name: "P8", key.usr: "s:27module_with_class_extension2P8P", - key.offset: 472, + key.offset: 443, key.length: 2 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 475, + key.offset: 446, key.length: 5 }, { key.kind: source.lang.swift.ref.generic_type_param, key.name: "Self", key.usr: "s:27module_with_class_extension2P8PA2A1DC1TRczrlE4Selfxmfp", - key.offset: 481, + key.offset: 452, key.length: 4 }, { key.kind: source.lang.swift.ref.associatedtype, key.name: "T", key.usr: "s:27module_with_class_extension2P8P1TQa", - key.offset: 486, + key.offset: 457, key.length: 1 }, { key.kind: source.lang.swift.syntaxtype.typeidentifier, - key.offset: 490, + key.offset: 461, key.length: 27 }, { key.kind: source.lang.swift.ref.class, key.name: "D", key.usr: "s:27module_with_class_extension1DC", - key.offset: 518, + key.offset: 489, key.length: 1 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 527, + key.offset: 498, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 532, + key.offset: 503, key.length: 3 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 541, + key.offset: 512, key.length: 9 }, { key.kind: source.lang.swift.ref.protocol, key.name: "P8", key.usr: "s:27module_with_class_extension2P8P", - key.offset: 551, + key.offset: 522, key.length: 2 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 554, + key.offset: 525, key.length: 5 }, { key.kind: source.lang.swift.ref.generic_type_param, key.name: "Self", key.usr: "s:27module_with_class_extension2P8PA2A1EC1TRczrlE4Selfxmfp", - key.offset: 560, + key.offset: 531, key.length: 4 }, { key.kind: source.lang.swift.ref.associatedtype, key.name: "T", key.usr: "s:27module_with_class_extension2P8P1TQa", - key.offset: 565, + key.offset: 536, key.length: 1 }, { key.kind: source.lang.swift.syntaxtype.typeidentifier, - key.offset: 569, + key.offset: 540, key.length: 27 }, { key.kind: source.lang.swift.ref.class, key.name: "E", key.usr: "s:27module_with_class_extension1EC", - key.offset: 597, + key.offset: 568, key.length: 1 }, { key.kind: source.lang.swift.syntaxtype.keyword, - key.offset: 606, + key.offset: 577, key.length: 4 }, { key.kind: source.lang.swift.syntaxtype.identifier, - key.offset: 611, + key.offset: 582, key.length: 3 } ] @@ -590,8 +556,19 @@ extension P8 where Self.T : module_with_class_extension.E { } ], key.offset: 272, - key.length: 54, - key.fully_annotated_decl: "class F<T> where T : D" + key.length: 70, + key.fully_annotated_decl: "class F<T> where T : D", + key.entities: [ + { + key.kind: source.lang.swift.decl.function.method.instance, + key.name: "bar()", + key.usr: "s:27module_with_class_extension2P8PA2A1DC1TRczrlE3baryyF::SYNTHESIZED::s:27module_with_class_extension1FC", + key.original_usr: "s:27module_with_class_extension2P8PA2A1DC1TRczrlE3baryyF", + key.offset: 330, + key.length: 10, + key.fully_annotated_decl: "func bar()" + } + ] }, { key.kind: source.lang.swift.decl.extension.class, @@ -605,7 +582,7 @@ extension P8 where Self.T : module_with_class_extension.E { key.description: "T : D" } ], - key.offset: 328, + key.offset: 344, key.length: 48, key.fully_annotated_decl: "extension F : P8", key.conforms: [ @@ -621,38 +598,11 @@ extension P8 where Self.T : module_with_class_extension.E { key.usr: "s:27module_with_class_extension1FC" } }, - { - key.kind: source.lang.swift.decl.extension.class, - key.generic_requirements: [ - { - key.description: "T : D" - } - ], - key.offset: 378, - key.length: 43, - key.fully_annotated_decl: "extension F where T : D", - key.extends: { - key.kind: source.lang.swift.ref.class, - key.name: "F", - key.usr: "s:27module_with_class_extension1FC" - }, - key.entities: [ - { - key.kind: source.lang.swift.decl.function.method.instance, - key.name: "bar()", - key.usr: "s:27module_with_class_extension2P8PA2A1DC1TRczrlE3baryyF::SYNTHESIZED::s:27module_with_class_extension1FC", - key.original_usr: "s:27module_with_class_extension2P8PA2A1DC1TRczrlE3baryyF", - key.offset: 409, - key.length: 10, - key.fully_annotated_decl: "func bar()" - } - ] - }, { key.kind: source.lang.swift.decl.protocol, key.name: "P8", key.usr: "s:27module_with_class_extension2P8P", - key.offset: 423, + key.offset: 394, key.length: 37, key.fully_annotated_decl: "protocol P8", key.entities: [ @@ -660,7 +610,7 @@ extension P8 where Self.T : module_with_class_extension.E { key.kind: source.lang.swift.decl.associatedtype, key.name: "T", key.usr: "s:27module_with_class_extension2P8P1TQa", - key.offset: 442, + key.offset: 413, key.length: 16, key.fully_annotated_decl: "associatedtype T" } @@ -673,7 +623,7 @@ extension P8 where Self.T : module_with_class_extension.E { key.description: "Self.T : D" } ], - key.offset: 462, + key.offset: 433, key.length: 77, key.fully_annotated_decl: "extension P8 where Self.T : D", key.extends: { @@ -686,7 +636,7 @@ extension P8 where Self.T : module_with_class_extension.E { key.kind: source.lang.swift.decl.function.method.instance, key.name: "bar()", key.usr: "s:27module_with_class_extension2P8PA2A1DC1TRczrlE3baryyF", - key.offset: 527, + key.offset: 498, key.length: 10, key.fully_annotated_decl: "func bar()" } @@ -699,7 +649,7 @@ extension P8 where Self.T : module_with_class_extension.E { key.description: "Self.T : E" } ], - key.offset: 541, + key.offset: 512, key.length: 77, key.fully_annotated_decl: "extension P8 where Self.T : E", key.extends: { @@ -712,7 +662,7 @@ extension P8 where Self.T : module_with_class_extension.E { key.kind: source.lang.swift.decl.function.method.instance, key.name: "baz()", key.usr: "s:27module_with_class_extension2P8PA2A1EC1TRczrlE3bazyyF", - key.offset: 606, + key.offset: 577, key.length: 10, key.fully_annotated_decl: "func baz()" }