//===--- RequirementLowering.cpp - Building rules from requirements -------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2021 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements logic for lowering generic requirements to rewrite rules // in the requirement machine. // // This includes generic requirements from canonical generic signatures and // protocol requirement signatures, as well as user-written requirements in // protocols ("structural requirements") and the 'where' clauses of generic // declarations. // // There is some additional desugaring logic for user-written requirements. // //===----------------------------------------------------------------------===// #include "RequirementLowering.h" #include "swift/AST/ASTContext.h" #include "swift/AST/Decl.h" #include "swift/AST/DiagnosticsSema.h" #include "swift/AST/ExistentialLayout.h" #include "swift/AST/Requirement.h" #include "swift/AST/TypeCheckRequests.h" #include "swift/AST/TypeMatcher.h" #include "swift/AST/TypeRepr.h" #include "llvm/ADT/SmallVector.h" #include "RewriteContext.h" #include "RewriteSystem.h" #include "Symbol.h" #include "Term.h" using namespace swift; using namespace rewriting; // // Requirement desugaring -- used in two places: // // 1) AbstractGenericSignatureRequest, where the added requirements might have // substitutions applied. // // 2) StructuralRequirementsRequest, which performs further processing to wrap // desugared requirements with source location information. // /// Desugar a same-type requirement that possibly has concrete types on either /// side into a series of same-type and concrete-type requirements where the /// left hand side is always a type parameter. static void desugarSameTypeRequirement(Type lhs, Type rhs, SmallVectorImpl &result) { class Matcher : public TypeMatcher { SmallVectorImpl &result; public: explicit Matcher(SmallVectorImpl &result) : result(result) {} bool mismatch(TypeBase *firstType, TypeBase *secondType, Type sugaredFirstType) { if (firstType->isTypeParameter() && secondType->isTypeParameter()) { result.emplace_back(RequirementKind::SameType, firstType, secondType); return true; } if (firstType->isTypeParameter()) { result.emplace_back(RequirementKind::SameType, firstType, secondType); return true; } if (secondType->isTypeParameter()) { result.emplace_back(RequirementKind::SameType, secondType, firstType); return true; } // FIXME: Record concrete type conflict, diagnose upstream return true; } } matcher(result); if (lhs->hasError() || rhs->hasError()) return; // FIXME: Record redundancy and diagnose upstream (void) matcher.match(lhs, rhs); } static void desugarSuperclassRequirement(Type subjectType, Type constraintType, SmallVectorImpl &result) { if (!subjectType->isTypeParameter()) { // FIXME: Perform unification, diagnose redundancy or conflict upstream return; } result.emplace_back(RequirementKind::Superclass, subjectType, constraintType); } static void desugarLayoutRequirement(Type subjectType, LayoutConstraint layout, SmallVectorImpl &result) { if (!subjectType->isTypeParameter()) { // FIXME: Diagnose redundancy or conflict upstream return; } result.emplace_back(RequirementKind::Layout, subjectType, layout); } static Type lookupMemberType(Type subjectType, ProtocolDecl *protoDecl, AssociatedTypeDecl *assocType) { if (subjectType->isTypeParameter()) return DependentMemberType::get(subjectType, assocType); auto *M = protoDecl->getParentModule(); auto conformance = M->lookupConformance( subjectType, protoDecl); return conformance.getAssociatedType(subjectType, assocType->getDeclaredInterfaceType()); } /// Desugar a protocol conformance requirement by splitting up protocol /// compositions on the right hand side into conformance and superclass /// requirements. static void desugarConformanceRequirement(Type subjectType, Type constraintType, SmallVectorImpl &result) { // Fast path. if (constraintType->is()) { if (!subjectType->isTypeParameter()) { // Check if the subject type actually conforms. auto *protoDecl = constraintType->castTo()->getDecl(); auto *module = protoDecl->getParentModule(); auto conformance = module->lookupConformance(subjectType, protoDecl); if (conformance.isInvalid()) { // FIXME: Diagnose a conflict. return; } // FIXME: Diagnose a redundancy. assert(conformance.isConcrete()); auto *concrete = conformance.getConcrete(); // Introduce conditional requirements if the subject type is concrete. for (auto req : concrete->getConditionalRequirements()) { desugarRequirement(req, result); } return; } result.emplace_back(RequirementKind::Conformance, subjectType, constraintType); return; } if (auto *paramType = constraintType->getAs()) { auto *protoDecl = paramType->getBaseType()->getDecl(); desugarConformanceRequirement(subjectType, paramType->getBaseType(), result); auto *assocType = protoDecl->getPrimaryAssociatedType(); auto memberType = lookupMemberType(subjectType, protoDecl, assocType); desugarSameTypeRequirement(memberType, paramType->getArgumentType(), result); return; } auto *compositionType = constraintType->castTo(); if (compositionType->hasExplicitAnyObject()) { desugarLayoutRequirement(subjectType, LayoutConstraint::getLayoutConstraint( LayoutConstraintKind::Class), result); } for (auto memberType : compositionType->getMembers()) { if (memberType->isExistentialType()) desugarConformanceRequirement(subjectType, memberType, result); else desugarSuperclassRequirement(subjectType, memberType, result); } } /// Convert a requirement where the subject type might not be a type parameter, /// or the constraint type in the conformance requirement might be a protocol /// composition, into zero or more "proper" requirements which can then be /// converted into rewrite rules by the RuleBuilder. void swift::rewriting::desugarRequirement(Requirement req, SmallVectorImpl &result) { auto firstType = req.getFirstType(); switch (req.getKind()) { case RequirementKind::Conformance: desugarConformanceRequirement(firstType, req.getSecondType(), result); break; case RequirementKind::Superclass: desugarSuperclassRequirement(firstType, req.getSecondType(), result); break; case RequirementKind::Layout: desugarLayoutRequirement(firstType, req.getLayoutConstraint(), result); break; case RequirementKind::SameType: desugarSameTypeRequirement(firstType, req.getSecondType(), result); break; } } // // StructuralRequirementsRequest computation. // // This realizes RequirementReprs into Requirements, desugars them using the // above, performs requirement inference, and wraps them with source location // information. // static void realizeTypeRequirement(Type subjectType, Type constraintType, SourceLoc loc, SmallVectorImpl &result) { SmallVector reqs; if (constraintType->isConstraintType()) { // Handle conformance requirements. desugarConformanceRequirement(subjectType, constraintType, reqs); } else if (constraintType->getClassOrBoundGenericClass()) { // Handle superclass requirements. desugarSuperclassRequirement(subjectType, constraintType, reqs); } else { // FIXME: Diagnose return; } // Add source location information. for (auto req : reqs) result.push_back({req, loc, /*wasInferred=*/false}); } namespace { /// AST walker that infers requirements from type representations. struct InferRequirementsWalker : public TypeWalker { ModuleDecl *module; SmallVector reqs; explicit InferRequirementsWalker(ModuleDecl *module) : module(module) {} Action walkToTypePre(Type ty) override { // Unbound generic types are the result of recovered-but-invalid code, and // don't have enough info to do any useful substitutions. if (ty->is()) return Action::Stop; return Action::Continue; } Action walkToTypePost(Type ty) override { // Infer from generic typealiases. if (auto typeAlias = dyn_cast(ty.getPointer())) { auto decl = typeAlias->getDecl(); auto subMap = typeAlias->getSubstitutionMap(); for (const auto &rawReq : decl->getGenericSignature().getRequirements()) { if (auto req = rawReq.subst(subMap)) desugarRequirement(*req, reqs); } return Action::Continue; } // Infer requirements from `@differentiable` function types. // For all non-`@noDerivative` parameter and result types: // - `@differentiable`, `@differentiable(_forward)`, or // `@differentiable(reverse)`: add `T: Differentiable` requirement. // - `@differentiable(_linear)`: add // `T: Differentiable`, `T == T.TangentVector` requirements. if (auto *fnTy = ty->getAs()) { auto &ctx = module->getASTContext(); auto *differentiableProtocol = ctx.getProtocol(KnownProtocolKind::Differentiable); if (differentiableProtocol && fnTy->isDifferentiable()) { auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) { Requirement req(RequirementKind::Conformance, type, protocol->getDeclaredInterfaceType()); desugarRequirement(req, reqs); }; auto addSameTypeConstraint = [&](Type firstType, AssociatedTypeDecl *assocType) { auto *protocol = assocType->getProtocol(); auto secondType = lookupMemberType(firstType, protocol, assocType); Requirement req(RequirementKind::SameType, firstType, secondType); desugarRequirement(req, reqs); }; auto *tangentVectorAssocType = differentiableProtocol->getAssociatedType(ctx.Id_TangentVector); auto addRequirements = [&](Type type, bool isLinear) { addConformanceConstraint(type, differentiableProtocol); if (isLinear) addSameTypeConstraint(type, tangentVectorAssocType); }; auto constrainParametersAndResult = [&](bool isLinear) { for (auto ¶m : fnTy->getParams()) if (!param.isNoDerivative()) addRequirements(param.getPlainType(), isLinear); addRequirements(fnTy->getResult(), isLinear); }; // Add requirements. constrainParametersAndResult(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear); } } if (!ty->isSpecialized()) return Action::Continue; // Infer from generic nominal types. auto decl = ty->getAnyNominal(); if (!decl) return Action::Continue; auto genericSig = decl->getGenericSignature(); if (!genericSig) return Action::Continue; /// Retrieve the substitution. auto subMap = ty->getContextSubstitutionMap(module, decl); // Handle the requirements. // FIXME: Inaccurate TypeReprs. for (const auto &rawReq : genericSig.getRequirements()) { if (auto req = rawReq.subst(subMap)) desugarRequirement(*req, reqs); } return Action::Continue; } }; } /// Infer requirements from applications of BoundGenericTypes to type /// parameters. For example, given a function declaration /// /// func union(_ x: Set, _ y: Set) /// /// We automatically infer 'T : Hashable' from the fact that 'struct Set' /// declares a Hashable requirement on its generic parameter. void swift::rewriting::inferRequirements( Type type, SourceLoc loc, ModuleDecl *module, SmallVectorImpl &result) { if (!type) return; InferRequirementsWalker walker(module); type.walk(walker); for (const auto &req : walker.reqs) result.push_back({req, loc, /*wasInferred=*/true}); } /// Desugar a requirement and perform requirement inference if requested /// to obtain zero or more structural requirements. void swift::rewriting::realizeRequirement( Requirement req, RequirementRepr *reqRepr, ModuleDecl *moduleForInference, SmallVectorImpl &result) { auto firstType = req.getFirstType(); auto loc = (reqRepr ? reqRepr->getSeparatorLoc() : SourceLoc()); switch (req.getKind()) { case RequirementKind::Superclass: case RequirementKind::Conformance: { auto secondType = req.getSecondType(); if (moduleForInference) { auto firstLoc = (reqRepr ? reqRepr->getSubjectRepr()->getStartLoc() : SourceLoc()); inferRequirements(firstType, firstLoc, moduleForInference, result); auto secondLoc = (reqRepr ? reqRepr->getConstraintRepr()->getStartLoc() : SourceLoc()); inferRequirements(secondType, secondLoc, moduleForInference, result); } realizeTypeRequirement(firstType, secondType, loc, result); break; } case RequirementKind::Layout: { if (moduleForInference) { auto firstLoc = (reqRepr ? reqRepr->getSubjectRepr()->getStartLoc() : SourceLoc()); inferRequirements(firstType, firstLoc, moduleForInference, result); } SmallVector reqs; desugarLayoutRequirement(firstType, req.getLayoutConstraint(), reqs); for (auto req : reqs) result.push_back({req, loc, /*wasInferred=*/false}); break; } case RequirementKind::SameType: { auto secondType = req.getSecondType(); if (moduleForInference) { auto firstLoc = (reqRepr ? reqRepr->getFirstTypeRepr()->getStartLoc() : SourceLoc()); inferRequirements(firstType, firstLoc, moduleForInference, result); auto secondLoc = (reqRepr ? reqRepr->getSecondTypeRepr()->getStartLoc() : SourceLoc()); inferRequirements(secondType, secondLoc, moduleForInference, result); } SmallVector reqs; desugarSameTypeRequirement(req.getFirstType(), secondType, reqs); for (auto req : reqs) result.push_back({req, loc, /*wasInferred=*/false}); break; } } } /// Collect structural requirements written in the inheritance clause of an /// AssociatedTypeDecl or GenericTypeParamDecl. void swift::rewriting::realizeInheritedRequirements( TypeDecl *decl, Type type, ModuleDecl *moduleForInference, SmallVectorImpl &result) { auto &ctx = decl->getASTContext(); auto inheritedTypes = decl->getInherited(); for (unsigned index : indices(inheritedTypes)) { Type inheritedType = evaluateOrDefault(ctx.evaluator, InheritedTypeRequest{decl, index, TypeResolutionStage::Structural}, Type()); if (!inheritedType) continue; auto *typeRepr = inheritedTypes[index].getTypeRepr(); SourceLoc loc = (typeRepr ? typeRepr->getStartLoc() : SourceLoc()); if (moduleForInference) { inferRequirements(inheritedType, loc, moduleForInference, result); } realizeTypeRequirement(type, inheritedType, loc, result); } } ArrayRef StructuralRequirementsRequest::evaluate(Evaluator &evaluator, ProtocolDecl *proto) const { assert(!proto->hasLazyRequirementSignature()); SmallVector result; auto &ctx = proto->getASTContext(); auto selfTy = proto->getSelfInterfaceType(); realizeInheritedRequirements(proto, selfTy, /*moduleForInference=*/nullptr, result); // Add requirements from the protocol's own 'where' clause. WhereClauseOwner(proto).visitRequirements(TypeResolutionStage::Structural, [&](const Requirement &req, RequirementRepr *reqRepr) { realizeRequirement(req, reqRepr, /*moduleForInference=*/nullptr, result); return false; }); if (proto->isObjC()) { // @objc protocols have an implicit AnyObject requirement on Self. auto layout = LayoutConstraint::getLayoutConstraint( LayoutConstraintKind::Class, ctx); result.push_back({Requirement(RequirementKind::Layout, selfTy, layout), proto->getLoc(), /*inferred=*/true}); // Remaining logic is not relevant to @objc protocols. return ctx.AllocateCopy(result); } // Add requirements for each of the associated types. for (auto assocTypeDecl : proto->getAssociatedTypeMembers()) { // Add requirements placed directly on this associated type. auto assocType = assocTypeDecl->getDeclaredInterfaceType(); realizeInheritedRequirements(assocTypeDecl, assocType, /*moduleForInference=*/nullptr, result); // Add requirements from this associated type's where clause. WhereClauseOwner(assocTypeDecl).visitRequirements( TypeResolutionStage::Structural, [&](const Requirement &req, RequirementRepr *reqRepr) { realizeRequirement(req, reqRepr, /*moduleForInference=*/nullptr, result); return false; }); } return ctx.AllocateCopy(result); } ArrayRef TypeAliasRequirementsRequest::evaluate(Evaluator &evaluator, ProtocolDecl *proto) const { // @objc protocols don't have associated types, so all of the below // becomes a trivial no-op. if (proto->isObjC()) return ArrayRef(); assert(!proto->hasLazyRequirementSignature()); SmallVector result; auto &ctx = proto->getASTContext(); // In Verify mode, the GenericSignatureBuilder will emit the same diagnostics. bool emitDiagnostics = (ctx.LangOpts.RequirementMachineProtocolSignatures == RequirementMachineMode::Enabled); // Collect all typealiases from inherited protocols recursively. llvm::MapVector> inheritedTypeDecls; for (auto *inheritedProto : ctx.getRewriteContext().getInheritedProtocols(proto)) { for (auto req : inheritedProto->getMembers()) { if (auto *typeReq = dyn_cast(req)) { // Ignore generic types. if (auto genReq = dyn_cast(req)) if (genReq->getGenericParams()) continue; inheritedTypeDecls[typeReq->getName()].push_back(typeReq); } } } auto getStructuralType = [](TypeDecl *typeDecl) -> Type { if (auto typealias = dyn_cast(typeDecl)) { if (typealias->getUnderlyingTypeRepr() != nullptr) { auto type = typealias->getStructuralType(); if (auto *aliasTy = cast(type.getPointer())) return aliasTy->getSinglyDesugaredType(); return type; } return typealias->getUnderlyingType(); } return typeDecl->getDeclaredInterfaceType(); }; // An inferred same-type requirement between the two type declarations // within this protocol or a protocol it inherits. auto recordInheritedTypeRequirement = [&](TypeDecl *first, TypeDecl *second) { desugarSameTypeRequirement(getStructuralType(first), getStructuralType(second), result); }; // Local function to find the insertion point for the protocol's "where" // clause, as well as the string to start the insertion ("where" or ","); auto getProtocolWhereLoc = [&]() -> Located { // Already has a trailing where clause. if (auto trailing = proto->getTrailingWhereClause()) return { ", ", trailing->getRequirements().back().getSourceRange().End }; // Inheritance clause. return { " where ", proto->getInherited().back().getSourceRange().End }; }; // Retrieve the set of requirements that a given associated type declaration // produces, in the form that would be seen in the where clause. const auto getAssociatedTypeReqs = [&](const AssociatedTypeDecl *assocType, const char *start) { std::string result; { llvm::raw_string_ostream out(result); out << start; interleave(assocType->getInherited(), [&](TypeLoc inheritedType) { out << assocType->getName() << ": "; if (auto inheritedTypeRepr = inheritedType.getTypeRepr()) inheritedTypeRepr->print(out); else inheritedType.getType().print(out); }, [&] { out << ", "; }); if (const auto whereClause = assocType->getTrailingWhereClause()) { if (!assocType->getInherited().empty()) out << ", "; whereClause->print(out, /*printWhereKeyword*/false); } } return result; }; // Retrieve the requirement that a given typealias introduces when it // overrides an inherited associated type with the same name, as a string // suitable for use in a where clause. auto getConcreteTypeReq = [&](TypeDecl *type, const char *start) { std::string result; { llvm::raw_string_ostream out(result); out << start; out << type->getName() << " == "; if (auto typealias = dyn_cast(type)) { if (auto underlyingTypeRepr = typealias->getUnderlyingTypeRepr()) underlyingTypeRepr->print(out); else typealias->getUnderlyingType().print(out); } else { type->print(out); } } return result; }; for (auto assocTypeDecl : proto->getAssociatedTypeMembers()) { // Check whether we inherited any types with the same name. auto knownInherited = inheritedTypeDecls.find(assocTypeDecl->getName()); if (knownInherited == inheritedTypeDecls.end()) continue; bool shouldWarnAboutRedeclaration = emitDiagnostics && !assocTypeDecl->getAttrs().hasAttribute() && !assocTypeDecl->getAttrs().hasAttribute() && !assocTypeDecl->hasDefaultDefinitionType() && (!assocTypeDecl->getInherited().empty() || assocTypeDecl->getTrailingWhereClause() || ctx.LangOpts.WarnImplicitOverrides); for (auto inheritedType : knownInherited->second) { // If we have inherited associated type... if (auto inheritedAssocTypeDecl = dyn_cast(inheritedType)) { // Complain about the first redeclaration. if (shouldWarnAboutRedeclaration) { auto inheritedFromProto = inheritedAssocTypeDecl->getProtocol(); auto fixItWhere = getProtocolWhereLoc(); ctx.Diags.diagnose(assocTypeDecl, diag::inherited_associated_type_redecl, assocTypeDecl->getName(), inheritedFromProto->getDeclaredInterfaceType()) .fixItInsertAfter( fixItWhere.Loc, getAssociatedTypeReqs(assocTypeDecl, fixItWhere.Item)) .fixItRemove(assocTypeDecl->getSourceRange()); ctx.Diags.diagnose(inheritedAssocTypeDecl, diag::decl_declared_here, inheritedAssocTypeDecl->getName()); shouldWarnAboutRedeclaration = false; } continue; } if (emitDiagnostics) { // We inherited a type; this associated type will be identical // to that typealias. auto inheritedOwningDecl = inheritedType->getDeclContext()->getSelfNominalTypeDecl(); ctx.Diags.diagnose(assocTypeDecl, diag::associated_type_override_typealias, assocTypeDecl->getName(), inheritedOwningDecl->getDescriptiveKind(), inheritedOwningDecl->getDeclaredInterfaceType()); } recordInheritedTypeRequirement(assocTypeDecl, inheritedType); } inheritedTypeDecls.erase(knownInherited); } // Check all remaining inherited type declarations to determine if // this protocol has a non-associated-type type with the same name. inheritedTypeDecls.remove_if( [&](const std::pair> &inherited) { const auto name = inherited.first; for (auto found : proto->lookupDirect(name)) { // We only want concrete type declarations. auto type = dyn_cast(found); if (!type || isa(type)) continue; // Ignore nominal types. They're always invalid declarations. if (isa(type)) continue; // ... from the same module as the protocol. if (type->getModuleContext() != proto->getModuleContext()) continue; // Ignore types defined in constrained extensions; their equivalence // to the associated type would have to be conditional, which we cannot // model. if (auto ext = dyn_cast(type->getDeclContext())) { if (ext->isConstrainedExtension()) continue; } // We found something. bool shouldWarnAboutRedeclaration = emitDiagnostics; for (auto inheritedType : inherited.second) { // If we have inherited associated type... if (auto inheritedAssocTypeDecl = dyn_cast(inheritedType)) { // Infer a same-type requirement between the typealias' underlying // type and the inherited associated type. recordInheritedTypeRequirement(inheritedAssocTypeDecl, type); // Warn that one should use where clauses for this. if (shouldWarnAboutRedeclaration) { auto inheritedFromProto = inheritedAssocTypeDecl->getProtocol(); auto fixItWhere = getProtocolWhereLoc(); ctx.Diags.diagnose(type, diag::typealias_override_associated_type, name, inheritedFromProto->getDeclaredInterfaceType()) .fixItInsertAfter(fixItWhere.Loc, getConcreteTypeReq(type, fixItWhere.Item)) .fixItRemove(type->getSourceRange()); ctx.Diags.diagnose(inheritedAssocTypeDecl, diag::decl_declared_here, inheritedAssocTypeDecl->getName()); shouldWarnAboutRedeclaration = false; } continue; } // Two typealiases that should be the same. recordInheritedTypeRequirement(inheritedType, type); } // We can remove this entry. return true; } return false; }); // Infer same-type requirements among inherited type declarations. for (auto &entry : inheritedTypeDecls) { if (entry.second.size() < 2) continue; auto firstDecl = entry.second.front(); for (auto otherDecl : ArrayRef(entry.second).slice(1)) { recordInheritedTypeRequirement(firstDecl, otherDecl); } } return ctx.AllocateCopy(result); } ArrayRef ProtocolDependenciesRequest::evaluate(Evaluator &evaluator, ProtocolDecl *proto) const { auto &ctx = proto->getASTContext(); SmallVector result; // If we have a serialized requirement signature, deserialize it and // look at conformance requirements. // // FIXME: For now we just fall back to the GSB for all protocols // unless -requirement-machine-protocol-signatures=on is passed. if (proto->hasLazyRequirementSignature() || (ctx.LangOpts.RequirementMachineProtocolSignatures == RequirementMachineMode::Disabled)) { for (auto req : proto->getRequirementSignature()) { if (req.getKind() == RequirementKind::Conformance) { result.push_back(req.getProtocolDecl()); } } return ctx.AllocateCopy(result); } // Otherwise, we can't ask for the requirement signature, because // this request is used as part of *building* the requirement // signature. Look at the structural requirements instead. for (auto req : proto->getStructuralRequirements()) { if (req.req.getKind() == RequirementKind::Conformance) result.push_back(req.req.getProtocolDecl()); } return ctx.AllocateCopy(result); } // // Building rewrite rules from desugared requirements. // void RuleBuilder::addRequirements(ArrayRef requirements) { // Collect all protocols transitively referenced from these requirements. for (auto req : requirements) { if (req.getKind() == RequirementKind::Conformance) { addProtocol(req.getProtocolDecl(), /*initialComponent=*/false); } } collectRulesFromReferencedProtocols(); // Add rewrite rules for all top-level requirements. for (const auto &req : requirements) addRequirement(req, /*proto=*/nullptr); } void RuleBuilder::addRequirements(ArrayRef requirements) { // Collect all protocols transitively referenced from these requirements. for (auto req : requirements) { if (req.req.getKind() == RequirementKind::Conformance) { addProtocol(req.req.getProtocolDecl(), /*initialComponent=*/false); } } collectRulesFromReferencedProtocols(); // Add rewrite rules for all top-level requirements. for (const auto &req : requirements) addRequirement(req, /*proto=*/nullptr); } void RuleBuilder::addProtocols(ArrayRef protos) { // Collect all protocols transitively referenced from this connected component // of the protocol dependency graph. for (auto proto : protos) { addProtocol(proto, /*initialComponent=*/true); } collectRulesFromReferencedProtocols(); } /// For an associated type T in a protocol P, we add a rewrite rule: /// /// [P].T => [P:T] /// /// Intuitively, this means "if a type conforms to P, it has a nested type /// named T". void RuleBuilder::addAssociatedType(const AssociatedTypeDecl *type, const ProtocolDecl *proto) { MutableTerm lhs; lhs.add(Symbol::forProtocol(proto, Context)); lhs.add(Symbol::forName(type->getName(), Context)); MutableTerm rhs; rhs.add(Symbol::forAssociatedType(proto, type->getName(), Context)); PermanentRules.emplace_back(lhs, rhs); } /// Lowers a desugared generic requirement to a rewrite rule. /// /// If \p proto is null, this is a generic requirement from the top-level /// generic signature. The added rewrite rule will be rooted in a generic /// parameter symbol. /// /// If \p proto is non-null, this is a generic requirement in the protocol's /// requirement signature. The added rewrite rule will be rooted in a /// protocol symbol. std::pair swift::rewriting::getRuleForRequirement(const Requirement &req, const ProtocolDecl *proto, Optional> substitutions, RewriteContext &ctx) { assert(!substitutions.hasValue() || proto == nullptr && "Can't have both"); // Compute the left hand side. auto subjectType = CanType(req.getFirstType()); auto subjectTerm = (substitutions ? ctx.getRelativeTermForType( subjectType, *substitutions) : ctx.getMutableTermForType( subjectType, proto)); // Compute the right hand side. MutableTerm constraintTerm; switch (req.getKind()) { case RequirementKind::Conformance: { // A conformance requirement T : P becomes a rewrite rule // // T.[P] == T // // Intuitively, this means "any type ending with T conforms to P". auto *proto = req.getProtocolDecl(); constraintTerm = subjectTerm; constraintTerm.add(Symbol::forProtocol(proto, ctx)); break; } case RequirementKind::Superclass: { // A superclass requirement T : C becomes a rewrite rule // // T.[superclass: C] => T auto otherType = CanType(req.getSecondType()); // Build the symbol [superclass: C]. SmallVector result; otherType = (substitutions ? ctx.getRelativeSubstitutionSchemaFromType( otherType, *substitutions, result) : ctx.getSubstitutionSchemaFromType( otherType, proto, result)); auto superclassSymbol = Symbol::forSuperclass(otherType, result, ctx); // Build the term T.[superclass: C]. constraintTerm = subjectTerm; constraintTerm.add(superclassSymbol); break; } case RequirementKind::Layout: { // A layout requirement T : L becomes a rewrite rule // // T.[layout: L] == T constraintTerm = subjectTerm; constraintTerm.add(Symbol::forLayout(req.getLayoutConstraint(), ctx)); break; } case RequirementKind::SameType: { auto otherType = CanType(req.getSecondType()); if (!otherType->isTypeParameter()) { // A concrete same-type requirement T == C becomes a // rewrite rule // // T.[concrete: C] => T SmallVector result; otherType = (substitutions ? ctx.getRelativeSubstitutionSchemaFromType( otherType, *substitutions, result) : ctx.getSubstitutionSchemaFromType( otherType, proto, result)); constraintTerm = subjectTerm; constraintTerm.add(Symbol::forConcreteType(otherType, result, ctx)); break; } constraintTerm = (substitutions ? ctx.getRelativeTermForType( otherType, *substitutions) : ctx.getMutableTermForType( otherType, proto)); break; } } return std::make_pair(subjectTerm, constraintTerm); } void RuleBuilder::addRequirement(const Requirement &req, const ProtocolDecl *proto) { if (Dump) { llvm::dbgs() << "+ "; req.dump(llvm::dbgs()); llvm::dbgs() << "\n"; } RequirementRules.push_back( getRuleForRequirement(req, proto, /*substitutions=*/None, Context)); } void RuleBuilder::addRequirement(const StructuralRequirement &req, const ProtocolDecl *proto) { // FIXME: Preserve source location information for diagnostics. addRequirement(req.req.getCanonical(), proto); } /// Record information about a protocol if we have no seen it yet. void RuleBuilder::addProtocol(const ProtocolDecl *proto, bool initialComponent) { if (ProtocolMap.count(proto) > 0) return; ProtocolMap[proto] = initialComponent; Protocols.push_back(proto); } /// Compute the transitive closure of the set of all protocols referenced from /// the right hand sides of conformance requirements, and convert their /// requirements to rewrite rules. void RuleBuilder::collectRulesFromReferencedProtocols() { unsigned i = 0; while (i < Protocols.size()) { auto *proto = Protocols[i++]; for (auto *depProto : proto->getProtocolDependencies()) { addProtocol(depProto, /*initialComponent=*/false); } } // Add rewrite rules for each protocol. for (auto *proto : Protocols) { if (Dump) { llvm::dbgs() << "protocol " << proto->getName() << " {\n"; } // Add the identity conformance rule [P].[P] => [P]. MutableTerm lhs; lhs.add(Symbol::forProtocol(proto, Context)); lhs.add(Symbol::forProtocol(proto, Context)); MutableTerm rhs; rhs.add(Symbol::forProtocol(proto, Context)); PermanentRules.emplace_back(lhs, rhs); for (auto *assocType : proto->getAssociatedTypeMembers()) addAssociatedType(assocType, proto); for (auto *inheritedProto : Context.getInheritedProtocols(proto)) { for (auto *assocType : inheritedProto->getAssociatedTypeMembers()) addAssociatedType(assocType, proto); } // If this protocol is part of the initial connected component, we're // building requirement signatures for all protocols in this component, // and so we must start with the structural requirements. // // Otherwise, we should either already have a requirement signature, or // we can trigger the computation of the requirement signatures of the // next component recursively. if (ProtocolMap[proto]) { for (auto req : proto->getStructuralRequirements()) addRequirement(req, proto); for (auto req : proto->getTypeAliasRequirements()) addRequirement(req.getCanonical(), proto); } else { for (auto req : proto->getRequirementSignature()) addRequirement(req.getCanonical(), proto); } if (Dump) { llvm::dbgs() << "}\n"; } } }