//===--- ParameterPack.cpp - Utilities for variadic generics --------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2022 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements utilities for substituting type parameter packs // appearing in pack expansion types. // //===----------------------------------------------------------------------===// #include "swift/AST/ASTContext.h" #include "swift/AST/Decl.h" #include "swift/AST/GenericEnvironment.h" #include "swift/AST/GenericParamList.h" #include "swift/AST/ParameterList.h" #include "swift/AST/Type.h" #include "swift/AST/Types.h" #include "swift/Basic/Assertions.h" #include "llvm/ADT/SmallVector.h" using namespace swift; /// FV(PackExpansionType(Pattern, Count), N) = FV(Pattern, N+1) /// FV(PackElementType(Param, M), N) = FV(Param, 0) if M >= N, {} otherwise /// FV(Param, N) = {Param} static Type transformTypeParameterPacksRec( Type t, llvm::function_ref(SubstitutableType *)> fn, unsigned expansionLevel) { return t.transformWithPosition( TypePosition::Invariant, [&](TypeBase *t, TypePosition p) -> std::optional { // If we're already inside N levels of PackExpansionType, and we're // walking into another PackExpansionType, a type parameter pack // reference now needs level (N+1) to be free. if (auto *expansionType = dyn_cast(t)) { auto countType = expansionType->getCountType(); auto patternType = expansionType->getPatternType(); auto newPatternType = transformTypeParameterPacksRec( patternType, fn, expansionLevel + 1); if (patternType.getPointer() != newPatternType.getPointer()) return Type(PackExpansionType::get(patternType, countType)); return Type(expansionType); } // A PackElementType with level N reaches past N levels of // nested PackExpansionType. So a type parameter pack reference // therein is free if N is greater than or equal to our current // expansion level. if (auto *eltType = dyn_cast(t)) { if (eltType->getLevel() >= expansionLevel) { return transformTypeParameterPacksRec(eltType->getPackType(), fn, /*expansionLevel=*/0); } return Type(eltType); } // A bare type parameter pack is like a PackElementType with level 0. if (auto *paramType = dyn_cast(t)) { if (expansionLevel == 0 && (isa(paramType) || paramType->isRootParameterPack())) { return fn(paramType); } return Type(paramType); } return std::nullopt; }); } Type Type::transformTypeParameterPacks( llvm::function_ref(SubstitutableType *)> fn) const { return transformTypeParameterPacksRec(*this, fn, /*expansionLevel=*/0); } namespace { /// Collects all unique pack type parameters referenced from the pattern type, /// skipping those captured by nested pack expansion types. struct PackReferenceCollector: TypeWalker { llvm::function_ref fn; unsigned expansionLevel; SmallVector elementLevel; PackReferenceCollector(llvm::function_ref fn) : fn(fn), expansionLevel(0) { elementLevel.push_back(0); } Action walkToTypePre(Type t) override { if (t->is()) { ++expansionLevel; return Action::Continue; } if (auto *boundGenericType = dyn_cast(t.getPointer())) { if (auto parentType = boundGenericType->getParent()) parentType.walk(*this); for (auto type : boundGenericType->getExpandedGenericArgs()) type.walk(*this); return Action::SkipNode; } if (auto *typeAliasType = dyn_cast(t.getPointer())) { if (typeAliasType->getDecl()->isGeneric()) { if (auto parentType = typeAliasType->getParent()) parentType.walk(*this); for (auto type : typeAliasType->getExpandedGenericArgs()) type.walk(*this); return Action::SkipNode; } } if (auto *eltType = t->getAs()) { elementLevel.push_back(eltType->getLevel()); return Action::Continue; } if (elementLevel.back() == expansionLevel) { if (fn(t)) return Action::Stop; } return Action::Continue; } Action walkToTypePost(Type t) override { if (t->is()) --expansionLevel; if (t->is()) elementLevel.pop_back(); return Action::Continue; } }; } void TypeBase::walkPackReferences( llvm::function_ref fn) { Type(this).walk(PackReferenceCollector(fn)); } void TypeBase::getTypeParameterPacks( SmallVectorImpl &rootParameterPacks) { walkPackReferences([&](Type t) { if (auto *paramTy = t->getAs()) { if (paramTy->isParameterPack()) rootParameterPacks.push_back(paramTy); } else if (auto *archetypeTy = t->getAs()) { if (archetypeTy->isRoot()) { rootParameterPacks.push_back(archetypeTy); } else { auto *genericEnv = archetypeTy->getGenericEnvironment(); auto paramTy = archetypeTy->getInterfaceType()->getRootGenericParam(); rootParameterPacks.push_back( genericEnv->mapTypeIntoContext(paramTy)); } } return false; }); } bool TypeBase::isParameterPack() { return getDependentMemberRoot()->isRootParameterPack(); } bool TypeBase::isRootParameterPack() { Type t(this); return t->is() && t->castTo()->isParameterPack(); } PackType *TypeBase::getPackSubstitutionAsPackType() { if (auto pack = getAs()) { return pack; } else { return PackType::getSingletonPackExpansion(this); } } static Type increasePackElementLevelImpl( Type type, unsigned level, unsigned outerLevel) { assert(level > 0); return type.transformRec([&](TypeBase *t) -> std::optional { if (auto *elementType = dyn_cast(t)) { if (elementType->getLevel() >= outerLevel) { elementType = PackElementType::get(elementType->getPackType(), elementType->getLevel() + level); } return Type(elementType); } if (auto *expansionType = dyn_cast(t)) { return Type(PackExpansionType::get( increasePackElementLevelImpl(expansionType->getPatternType(), level, outerLevel + 1), expansionType->getCountType())); } if (t->isParameterPack() || isa(t)) { if (outerLevel == 0) return Type(PackElementType::get(t, level)); return Type(t); } return std::nullopt; }); } Type TypeBase::increasePackElementLevel(unsigned level) { if (level == 0) return Type(this); return increasePackElementLevelImpl(Type(this), level, 0); } CanType PackExpansionType::getReducedShape() { auto reducedShape = countType->getReducedShape(); if (reducedShape == getASTContext().TheEmptyTupleType) return reducedShape; return CanType(PackExpansionType::get(reducedShape, reducedShape)); } unsigned TupleType::getNumScalarElements() const { unsigned n = 0; for (auto elt : getElements()) { if (!elt.getType()->is()) ++n; } return n; } bool TupleType::containsPackExpansionType() const { assert(!hasTypeVariable()); for (auto elt : getElements()) { auto eltTy = elt.getType(); assert(!eltTy->hasTypeVariable()); if (eltTy->is()) return true; } return false; } bool CanTupleType::containsPackExpansionTypeImpl(CanTupleType tuple) { for (auto eltType : tuple.getElementTypes()) { if (isa(eltType)) return true; } return false; } bool AnyFunctionType::containsPackExpansionType(ArrayRef params) { for (auto param : params) { auto paramTy = param.getPlainType(); assert(!paramTy->hasTypeVariable()); if (paramTy->is()) return true; } return false; } bool PackType::containsPackExpansionType() const { for (auto type : getElementTypes()) { if (type->is()) return true; } return false; } template static CanPackType getReducedShapeOfPack(const ASTContext &ctx, const T &elementTypes) { SmallVector elts; elts.reserve(elementTypes.size()); for (auto elt : elementTypes) { // T... => shape(T)... if (auto *packExpansionType = elt->template getAs()) { elts.push_back(packExpansionType->getReducedShape()); continue; } // Use () as a placeholder for scalar shape. assert(!elt->template is() && "Pack archetype outside of a pack expansion"); elts.push_back(ctx.TheEmptyTupleType); } return CanPackType(PackType::get(ctx, elts)); } CanPackType PackType::getReducedShape() { return getReducedShapeOfPack(getASTContext(), getElementTypes()); } CanPackType SILPackType::getReducedShape() const { return getReducedShapeOfPack(getASTContext(), getElementTypes()); } CanType TypeBase::getReducedShape() { if (isTypeParameter()) { auto *genericParam = getRootGenericParam(); if (genericParam->isParameterPack()) return genericParam->getCanonicalType(); // Use () as a placeholder for scalar shape. return getASTContext().TheEmptyTupleType; } if (auto *packArchetype = getAs()) return packArchetype->getReducedShape(); if (auto *packType = getAs()) return packType->getReducedShape(); if (auto *expansionType = getAs()) return expansionType->getReducedShape(); if (auto *silPackType = getAs()) { auto can = cast(silPackType->getCanonicalType()); return can->getReducedShape(); } SmallVector rootParameterPacks; getTypeParameterPacks(rootParameterPacks); if (!rootParameterPacks.empty()) return rootParameterPacks.front()->getReducedShape(); assert(!isTypeVariableOrMember()); assert(!hasTypeParameter()); // Use () as a placeholder for scalar shape. return getASTContext().TheEmptyTupleType; } unsigned ParameterList::getOrigParamIndex(SubstitutionMap subMap, unsigned substIndex) const { unsigned remappedIndex = substIndex; for (unsigned i = 0, e = size(); i < e; ++i) { auto *param = get(i); auto paramType = param->getInterfaceType(); unsigned substCount = 1; if (auto *packExpansionType = paramType->getAs()) { auto replacementType = packExpansionType->getCountType().subst(subMap); if (auto *packType = replacementType->getAs()) { substCount = packType->getNumElements(); } } if (remappedIndex < substCount) return i; remappedIndex -= substCount; } llvm::errs() << "Invalid substituted argument index: " << substIndex << "\n"; subMap.dump(llvm::errs()); dump(llvm::errs()); abort(); } /// Foo => Pack{T..., Int, String} SmallVector BoundGenericType::getExpandedGenericArgs() { // It would be nicer to use genericSig.getInnermostGenericParams() here, // but that triggers a request cycle if we're in the middle of computing // the generic signature already. SmallVector params; for (auto *paramDecl : getDecl()->getGenericParams()->getParams()) { params.push_back(paramDecl->getDeclaredInterfaceType() ->castTo()); } return PackType::getExpandedGenericArgs(params, getGenericArgs()); } /// Foo => Pack{T..., Int, String} SmallVector TypeAliasType::getExpandedGenericArgs() { if (!getDecl()->isGeneric()) return SmallVector(); auto genericSig = getGenericSignature(); return PackType::getExpandedGenericArgs( genericSig.getInnermostGenericParams(), getDirectGenericArgs()); } /// Pack{T, Pack{Int, String}} => {T..., Int, String} SmallVector PackType::getExpandedGenericArgs(ArrayRef params, ArrayRef args) { SmallVector wrappedArgs; assert(params.size() == args.size()); for (unsigned i = 0, e = params.size(); i < e; ++i) { auto arg = args[i]; if (params[i]->isParameterPack()) { // FIXME: A temporary fix to make it possible to debug expressions // with partially resolved variadic generic types. The issue stems // from the fact that `BoundGenericType` is allowed to have pack // parameters directly represented by type variables, as soon as // that is no longer the case this check should be removed. if (arg->is()) { wrappedArgs.push_back(arg); } else { auto argPackElements = arg->castTo()->getElementTypes(); wrappedArgs.append(argPackElements.begin(), argPackElements.end()); } continue; } wrappedArgs.push_back(arg); } return wrappedArgs; } PackType *PackType::getSingletonPackExpansion(Type param) { SmallVector rootParameterPacks; param->getTypeParameterPacks(rootParameterPacks); assert(rootParameterPacks.size() >= 1); auto count = rootParameterPacks[0]; return get(param->getASTContext(), {PackExpansionType::get(param, count)}); } CanPackType CanPackType::getSingletonPackExpansion(CanType param) { // Note: You can't just wrap the result in CanPackType() here because // PackExpansionType has the additional requirement that the count type // must be a reduced shape. return cast( PackType::getSingletonPackExpansion(param) ->getCanonicalType()); } PackExpansionType *PackType::unwrapSingletonPackExpansion() const { if (getNumElements() == 1) { if (auto expansion = getElementTypes()[0]->getAs()) { auto pattern = expansion->getPatternType(); if (pattern->isParameterPack() || pattern->is()) return expansion; } } return nullptr; } bool SILPackType::containsPackExpansionType() const { for (auto type : getElementTypes()) { if (isa(type)) return true; } return false; } CanPackType CanTupleType::getInducedPackTypeImpl(CanTupleType tuple) { return getInducedPackTypeImpl(tuple, 0, tuple->getNumElements()); } CanPackType CanTupleType::getInducedPackTypeImpl(CanTupleType tuple, unsigned start, unsigned count) { assert(start + count <= tuple->getNumElements() && "range out of range"); auto &ctx = tuple->getASTContext(); return CanPackType::get(ctx, tuple.getElementTypes().slice(start, count)); } static CanType getApproximateFormalElementType(const ASTContext &ctx, CanType loweredEltType) { CanType formalEltType = TupleType::getEmpty(ctx); if (auto expansion = dyn_cast(loweredEltType)) formalEltType = CanPackExpansionType::get(formalEltType, expansion.getCountType()); return formalEltType; } template static CanPackType getApproximateFormalPackType(const ASTContext &ctx, Collection loweredEltTypes) { // Build an array of formal element types, but be lazy about it: // use the original array unless we see an element type that doesn't // work as a legal formal type. std::optional> formalEltTypes; for (auto i : indices(loweredEltTypes)) { auto loweredEltType = loweredEltTypes[i]; bool isLegal = loweredEltType->isLegalFormalType(); // If the type isn't legal as a formal type, substitute the empty // tuple type (or an invariant expansion of it over the count type). CanType formalEltType = loweredEltType; if (!isLegal) { formalEltType = getApproximateFormalElementType(ctx, loweredEltType); } // If we're already building an array, unconditionally append to it. // Otherwise, if the type isn't legal, build the array up to this // point and then append. Otherwise, we're still being lazy. if (formalEltTypes) { formalEltTypes->push_back(formalEltType); } else if (!isLegal) { formalEltTypes.emplace(); formalEltTypes->reserve(loweredEltTypes.size()); formalEltTypes->append(loweredEltTypes.begin(), loweredEltTypes.begin() + i); formalEltTypes->push_back(formalEltType); } assert(isLegal || formalEltTypes.has_value()); } // Use the array we built if we made one (if we ever saw a non-legal // element type). if (formalEltTypes) { return CanPackType::get(ctx, *formalEltTypes); } else { return CanPackType::get(ctx, loweredEltTypes); } } CanPackType SILPackType::getApproximateFormalPackType() const { return ::getApproximateFormalPackType(getASTContext(), getElementTypes()); } CanPackType CanTupleType::getInducedApproximateFormalPackTypeImpl(CanTupleType tuple) { return ::getApproximateFormalPackType(tuple->getASTContext(), tuple.getElementTypes()); }