//===--- PackConformance.cpp - Variadic Protocol Conformance --------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2022 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements the PackConformance structure, which describes the // conformance of a type pack parameter to a protocol. // //===----------------------------------------------------------------------===// #include "swift/AST/PackConformance.h" #include "swift/AST/ASTContext.h" #include "swift/AST/Decl.h" #include "swift/AST/Module.h" #include "swift/AST/Types.h" #define DEBUG_TYPE "AST" using namespace swift; void PackConformance::Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, ConformingType, Protocol, getPatternConformances()); } void PackConformance::Profile(llvm::FoldingSetNodeID &ID, PackType *conformingType, ProtocolDecl *protocol, ArrayRef conformances) { ID.AddPointer(conformingType); ID.AddPointer(protocol); for (auto conformance : conformances) ID.AddPointer(conformance.getOpaqueValue()); } PackConformance::PackConformance(PackType *conformingType, ProtocolDecl *protocol, ArrayRef conformances) : ConformingType(conformingType), Protocol(protocol) { assert(ConformingType->getNumElements() == conformances.size()); std::uninitialized_copy(conformances.begin(), conformances.end(), getTrailingObjects()); } size_t PackConformance::numTrailingObjects( OverloadToken) const { return ConformingType->getNumElements(); } ArrayRef PackConformance::getPatternConformances() const { return {getTrailingObjects(), ConformingType->getNumElements()}; } bool PackConformance::isCanonical() const { if (!ConformingType->isCanonical()) return false; for (auto conformance : getPatternConformances()) if (!conformance.isCanonical()) return false; return true; } PackConformance *PackConformance::getCanonicalConformance() const { if (isCanonical()) return const_cast(this); SmallVector conformances; for (auto conformance : getPatternConformances()) conformances.push_back(conformance.getCanonicalConformanceRef()); auto canonical = PackConformance::get( cast(ConformingType->getCanonicalType()), Protocol, conformances); assert(canonical->isCanonical()); return canonical; } /// Project the corresponding associated type from each pack element /// of the conforming type, collecting the results into a new pack type /// that has the same pack expansion structure as the conforming type. PackType *PackConformance::getAssociatedType(Type assocType) const { SmallVector packElements; auto conformances = getPatternConformances(); for (unsigned i : indices(conformances)) { auto packElement = ConformingType->getElementType(i); // If the pack element is a pack expansion, build a new pack expansion // with the same count type as the original element, and the pattern // type replaced with the associated type witness from the pattern // conformance. if (auto *packExpansion = packElement->getAs()) { auto assocTypePattern = conformances[i].getAssociatedType(packExpansion->getPatternType(), assocType); packElements.push_back(PackExpansionType::get( assocTypePattern, packExpansion->getCountType())); // If the pack element is a scalar type, replace the scalar type with // the associated type witness from the pattern conformance. } else { auto assocTypeScalar = conformances[i].getAssociatedType(packElement, assocType); packElements.push_back(assocTypeScalar); } } return PackType::get(Protocol->getASTContext(), packElements); } /// Project the corresponding associated conformance from each pack element /// of the conforming type, collecting the results into a new pack conformnace /// whose conforming type has the same pack expansion structure as our /// conforming type. PackConformance *PackConformance::getAssociatedConformance( Type assocType, ProtocolDecl *protocol) const { SmallVector packElements; SmallVector packConformances; auto conformances = getPatternConformances(); for (unsigned i : indices(conformances)) { auto packElement = ConformingType->getElementType(i); if (auto *packExpansion = packElement->getAs()) { auto assocTypePattern = conformances[i].getAssociatedType(packExpansion->getPatternType(), assocType); packElements.push_back(PackExpansionType::get( assocTypePattern, packExpansion->getCountType())); auto assocConformancePattern = conformances[i].getAssociatedConformance(packExpansion->getPatternType(), assocType, protocol); packConformances.push_back(assocConformancePattern); } else { auto assocTypeScalar = conformances[i].getAssociatedType(packElement, assocType); packElements.push_back(assocTypeScalar); auto assocConformanceScalar = conformances[i].getAssociatedConformance(packElement, assocType, protocol); packConformances.push_back(assocConformanceScalar); } } auto conformingType = PackType::get(Protocol->getASTContext(), packElements); return PackConformance::get(conformingType, protocol, packConformances); } ProtocolConformanceRef PackConformance::subst(SubstitutionMap subMap, SubstOptions options) const { return subst(QuerySubstitutionMap{subMap}, LookUpConformanceInSubstitutionMap(subMap), options); } // TODO: Move this elsewhere since it's generally useful static bool arePackShapesEqual(PackType *lhs, PackType *rhs) { if (lhs->getNumElements() != rhs->getNumElements()) return false; for (unsigned i = 0, e = lhs->getNumElements(); i < e; ++i) { auto lhsElt = lhs->getElementType(i); auto rhsElt = rhs->getElementType(i); if (lhsElt->is() != rhsElt->is()) return false; } return true; } static bool isRootParameterPack(Type t) { if (auto *paramTy = t->getAs()) { return paramTy->isParameterPack(); } else if (auto *archetypeTy = t->getAs()) { return archetypeTy->isRoot(); } return false; } static bool isRootedInParameterPack(Type t) { if (auto *archetypeTy = t->getAs()) { return true; } return t->getRootGenericParam()->isParameterPack(); } namespace { template class PackExpander { protected: TypeSubstitutionFn subs; LookupConformanceFn conformances; SubstOptions options; PackExpander(TypeSubstitutionFn subs, LookupConformanceFn conformances, SubstOptions options) : subs(subs), conformances(conformances), options(options) {} ImplClass *asImpl() { return static_cast(this); } /// We're replacing a pack expansion type with a pack -- flatten the pack /// using the pack expansion's pattern. void addExpandedExpansion(Type origPatternType, PackType *expandedCountType, unsigned i) { // Get all pack parameters referenced from the pattern. SmallVector rootParameterPacks; origPatternType->getTypeParameterPacks(rootParameterPacks); // Each pack parameter referenced from the pattern must be replaced // with a pack type, and all pack types must have the same shape as // the expanded count pack type. llvm::SmallDenseMap expandedPacks; for (auto origParamType : rootParameterPacks) { auto substParamType = origParamType.subst(subs, conformances, options); if (auto expandedParamType = substParamType->template getAs()) { assert(arePackShapesEqual(expandedParamType, expandedCountType) && "TODO: Return an invalid conformance if this fails"); auto inserted = expandedPacks.insert( std::make_pair(origParamType->getCanonicalType(), expandedParamType)).second; assert(inserted && "getTypeParameterPacks() should not return duplicates"); } else { assert(false && "TODO: Return an invalid conformance if this fails"); } } // For each element of the expanded count, compute the substituted // pattern type. for (unsigned j = 0, ee = expandedCountType->getNumElements(); j < ee; ++j) { auto projectedSubs = [&](SubstitutableType *type) -> Type { // Nested sequence archetypes get passed in here, but we must // handle them via the standard nested type path. if (auto *archetypeType = dyn_cast(type)) { if (!archetypeType->isRoot()) return Type(); } // Compute the substituted type using our parent substitutions. auto substType = Type(type).subst(subs, conformances, options); // If the substituted type is a pack, project the jth element. if (isRootParameterPack(type)) { // FIXME: What if you have something like G... where G<> is // variadic? assert(substType->template is() && "TODO: Return an invalid conformance if this fails"); auto *packType = substType->template castTo(); assert(arePackShapesEqual(packType, expandedCountType) && "TODO: Return an invalid conformance if this fails"); return packType->getElementType(j); } return subs(type); }; auto projectedConformances = [&](CanType origType, Type substType, ProtocolDecl *proto) -> ProtocolConformanceRef { auto substConformance = conformances(origType, substType, proto); // If the substituted conformance is a pack, project the jth element. if (isRootedInParameterPack(origType)) { return substConformance.getPack()->getPatternConformances()[j]; } return substConformance; }; auto origCountElement = expandedCountType->getElementType(j); auto substCountElement = origCountElement.subst( projectedSubs, projectedConformances, options); asImpl()->add(origCountElement, substCountElement, i); } } /// A pack expansion remains unexpanded, so we substitute the pattern and /// form a new pack expansion. void addUnexpandedExpansion(Type origPatternType, Type substCountType, unsigned i) { auto substPatternType = origPatternType.subst(subs, conformances, options); auto substExpansion = PackExpansionType::get(substPatternType, substCountType); asImpl()->add(origPatternType, substExpansion, i); } /// Scalar elements of the original pack are substituted and added to the /// flattened pack. void addScalar(Type origElement, unsigned i) { auto substElement = origElement.subst(subs, conformances, options); asImpl()->add(origElement, substElement, i); } /// Potentially expand an element of the original pack. void maybeExpandExpansion(PackExpansionType *origExpansion, unsigned i) { auto origPatternType = origExpansion->getPatternType(); auto origCountType = origExpansion->getCountType(); auto substCountType = origCountType.subst(subs, conformances, options); // If the substituted count type is a pack, we're expanding the // original element. if (auto *expandedCountType = substCountType->template getAs()) { addExpandedExpansion(origPatternType, expandedCountType, i); return; } addUnexpandedExpansion(origPatternType, substCountType, i); } public: void expand(PackType *origPackType) { for (unsigned i = 0, e = origPackType->getNumElements(); i < e; ++i) { auto origElement = origPackType->getElementType(i); // Check if the original element is potentially being expanded. if (auto *origExpansion = origElement->getAs()) { maybeExpandExpansion(origExpansion, i); continue; } addScalar(origElement, i); } } }; class PackConformanceExpander : public PackExpander { public: SmallVector substElements; SmallVector substConformances; ArrayRef origConformances; PackConformanceExpander(TypeSubstitutionFn subs, LookupConformanceFn conformances, SubstOptions options, ArrayRef origConformances) : PackExpander(subs, conformances, options), origConformances(origConformances) {} void add(Type origType, Type substType, unsigned i) { substElements.push_back(substType); // FIXME: Pass down projection callbacks substConformances.push_back(origConformances[i].subst( origType, subs, conformances, options)); } }; } ProtocolConformanceRef PackConformance::subst(TypeSubstitutionFn subs, LookupConformanceFn conformances, SubstOptions options) const { PackConformanceExpander expander(subs, conformances, options, getPatternConformances()); expander.expand(ConformingType); auto &ctx = Protocol->getASTContext(); auto *substConformingType = PackType::get(ctx, expander.substElements); auto substConformance = PackConformance::get(substConformingType, Protocol, expander.substConformances); return ProtocolConformanceRef(substConformance); } void swift::simple_display(llvm::raw_ostream &out, PackConformance *conformance) { out << conformance->getType() << " : {"; bool first = true; for (auto patternConformance : conformance->getPatternConformances()) { if (first) { out << ", "; first = false; } simple_display(out, patternConformance); } out << "}"; }