diff --git a/include/swift/AST/InFlightSubstitution.h b/include/swift/AST/InFlightSubstitution.h index 1317b1508ce..bda9dd5d561 100644 --- a/include/swift/AST/InFlightSubstitution.h +++ b/include/swift/AST/InFlightSubstitution.h @@ -32,6 +32,12 @@ class InFlightSubstitution { TypeSubstitutionFn BaselineSubstType; LookupConformanceFn BaselineLookupConformance; + struct ActivePackExpansion { + bool isSubstExpansion = false; + unsigned expansionIndex = 0; + }; + SmallVector ActivePackExpansions; + public: InFlightSubstitution(TypeSubstitutionFn substType, LookupConformanceFn lookupConformance, @@ -43,16 +49,72 @@ public: InFlightSubstitution(const InFlightSubstitution &) = delete; InFlightSubstitution &operator=(const InFlightSubstitution &) = delete; - Type substType(SubstitutableType *ty) { - return BaselineSubstType(ty); - } + // TODO: when we add PackElementType, we should recognize it during + // substitution and either call different methods on this class or + // pass an extra argument for the pack-expansion depth D. We should + // be able to rely on that to mark a pack-element reference instead + // of checking whether the original type was a pack. Substitution + // should use the D'th entry from the end of ActivePackExpansions to + // guide the element substitution: + // - project the given index of the pack substitution + // - wrap it in a PackElementType if it's a subst expansion + // - the depth of that PackElementType is the number of subst + // expansions between the depth entry and the end of + // ActivePackExpansions + /// Perform primitive substitution on the given type. Returns Type() + /// if the type should not be substituted as a whole. + Type substType(SubstitutableType *origType); + + /// Perform primitive conformance lookup on the given type. ProtocolConformanceRef lookupConformance(CanType dependentType, Type conformingReplacementType, - ProtocolDecl *conformedProtocol) { - return BaselineLookupConformance(dependentType, - conformingReplacementType, - conformedProtocol); + ProtocolDecl *conformedProtocol); + + /// Given the shape type of a pack expansion, invoke the given callback + /// for each expanded component of it. If the substituted component + /// is an expansion component, the desired shape of that expansion + /// is passed as the argument; otherwise, the argument is Type(). + /// In either case, an active expansion is entered on this IFS for + /// the duration of the call to handleComponent, and subsequent + /// pack-element type references will substitute to the corresponding + /// element of the substitution of the pack. + void expandPackExpansionShape(Type origShape, + llvm::function_ref handleComponent); + + /// Call the given function for each expanded component type of the + /// given pack expansion type. The function will be invoked with the + /// active expansion still active. + void expandPackExpansionType(PackExpansionType *origExpansionType, + llvm::function_ref handleComponentType) { + expandPackExpansionShape(origExpansionType->getCountType(), + [&](Type substComponentShape) { + auto origPatternType = origExpansionType->getPatternType(); + auto substEltType = origPatternType.subst(*this); + + auto substComponentType = + (substComponentShape + ? PackExpansionType::get(substEltType, substComponentShape) + : substEltType); + handleComponentType(substComponentType); + }); + } + + /// Return a list of component types that the pack expansion expands to. + SmallVector + expandPackExpansionType(PackExpansionType *origExpansionType) { + SmallVector substComponentTypes; + expandPackExpansionType(origExpansionType, substComponentTypes); + return substComponentTypes; + } + + /// Expand the list of component types that the pack expansion expands + /// to into the given array. + void expandPackExpansionType(PackExpansionType *origExpansionType, + SmallVectorImpl &substComponentTypes) { + expandPackExpansionType(origExpansionType, [&](Type substComponentType) { + substComponentTypes.push_back(substComponentType); + }); } class OptionsAdjustmentScope { diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 176a1b2c185..b8213f4c125 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -3257,6 +3257,12 @@ CanPackExpansionType::get(CanType patternType, CanType countType) { } PackExpansionType *PackExpansionType::get(Type patternType, Type countType) { + assert(!patternType->is()); + assert(!countType->is()); + // FIXME: stop doing this deliberately in PackExpansionMatcher + //assert(!patternType->is()); + //assert(!countType->is()); + auto properties = patternType->getRecursiveProperties(); properties |= countType->getRecursiveProperties(); diff --git a/lib/AST/PackConformance.cpp b/lib/AST/PackConformance.cpp index e4006e6c8ae..93662faea8e 100644 --- a/lib/AST/PackConformance.cpp +++ b/lib/AST/PackConformance.cpp @@ -167,207 +167,65 @@ ProtocolConformanceRef PackConformance::subst(SubstitutionMap subMap, return subst(IFS); } -// TODO: Move this elsewhere since it's generally useful -static bool arePackShapesEqual(PackType *lhs, PackType *rhs) { - if (lhs->getNumElements() != rhs->getNumElements()) - return false; - - for (unsigned i = 0, e = lhs->getNumElements(); i < e; ++i) { - auto lhsElt = lhs->getElementType(i); - auto rhsElt = rhs->getElementType(i); - - if (lhsElt->is() != rhsElt->is()) - return false; - } - - return true; -} - -static bool isRootParameterPack(Type t) { - if (auto *paramTy = t->getAs()) { - return paramTy->isParameterPack(); - } else if (auto *archetypeTy = t->getAs()) { - return archetypeTy->isRoot(); - } - - return false; -} - -static bool isRootedInParameterPack(Type t) { - if (auto *archetypeTy = t->getAs()) { - return true; - } - - return t->getRootGenericParam()->isParameterPack(); -} - namespace { -template -class PackExpander { -protected: +struct PackConformanceExpander { InFlightSubstitution &IFS; + ArrayRef origConformances; - PackExpander(InFlightSubstitution &IFS) : IFS(IFS) {} +public: + // Results built up by the expansion. + SmallVector substElementTypes; + SmallVector substConformances; - ImplClass *asImpl() { - return static_cast(this); + PackConformanceExpander(InFlightSubstitution &IFS, + ArrayRef origConformances) + : IFS(IFS), origConformances(origConformances) {} + +private: + /// Substitute a scalar element of the original pack. + void substScalar(Type origElementType, + ProtocolConformanceRef origConformance) { + auto substElementType = origElementType.subst(IFS); + auto substConformance = origConformance.subst(origElementType, IFS); + + substElementTypes.push_back(substElementType); + substConformances.push_back(substConformance); } - /// We're replacing a pack expansion type with a pack -- flatten the pack - /// using the pack expansion's pattern. - void addExpandedExpansion(Type origPatternType, PackType *expandedCountType, - unsigned i) { + /// Substitute and expand an expansion element of the original pack. + void substExpansion(PackExpansionType *origExpansionType, + ProtocolConformanceRef origConformance) { + IFS.expandPackExpansionType(origExpansionType, + [&](Type substComponentType) { + auto origPatternType = origExpansionType->getPatternType(); - // Get all pack parameters referenced from the pattern. - SmallVector rootParameterPacks; - origPatternType->getTypeParameterPacks(rootParameterPacks); + // Just substitute the conformance. We don't directly represent + // pack expansion conformances here; it's sort of implicit in the + // corresponding pack element type. + auto substConformance = origConformance.subst(origPatternType, IFS); - // Each pack parameter referenced from the pattern must be replaced - // with a pack type, and all pack types must have the same shape as - // the expanded count pack type. - llvm::SmallDenseMap expandedPacks; - for (auto origParamType : rootParameterPacks) { - auto substParamType = origParamType.subst(IFS); - - if (auto expandedParamType = substParamType->template getAs()) { - assert(arePackShapesEqual(expandedParamType, expandedCountType) && - "TODO: Return an invalid conformance if this fails"); - - auto inserted = expandedPacks.insert( - std::make_pair(origParamType->getCanonicalType(), - expandedParamType)).second; - assert(inserted && - "getTypeParameterPacks() should not return duplicates"); - } else { - assert(false && - "TODO: Return an invalid conformance if this fails"); - } - } - - // For each element of the expanded count, compute the substituted - // pattern type. - for (unsigned j = 0, ee = expandedCountType->getNumElements(); j < ee; ++j) { - auto projectedSubs = [&](SubstitutableType *type) -> Type { - // Nested sequence archetypes get passed in here, but we must - // handle them via the standard nested type path. - if (auto *archetypeType = dyn_cast(type)) { - if (!archetypeType->isRoot()) - return Type(); - } - - // Compute the substituted type using our parent substitutions. - auto substType = Type(type).subst(IFS); - - // If the substituted type is a pack, project the jth element. - if (isRootParameterPack(type)) { - // FIXME: What if you have something like G... where G<> is - // variadic? - assert(substType->template is() && - "TODO: Return an invalid conformance if this fails"); - auto *packType = substType->template castTo(); - assert(arePackShapesEqual(packType, expandedCountType) && - "TODO: Return an invalid conformance if this fails"); - - return packType->getElementType(j); - } - - return IFS.substType(type); - }; - - auto projectedConformances = [&](CanType origType, Type substType, - ProtocolDecl *proto) -> ProtocolConformanceRef { - auto substConformance = - IFS.lookupConformance(origType, substType, proto); - - // If the substituted conformance is a pack, project the jth element. - if (isRootedInParameterPack(origType)) { - return substConformance.getPack()->getPatternConformances()[j]; - } - - return substConformance; - }; - - auto origCountElement = expandedCountType->getElementType(j); - auto substCountElement = origCountElement.subst( - projectedSubs, projectedConformances, IFS.getOptions()); - - asImpl()->add(origCountElement, substCountElement, i); - } - } - - /// A pack expansion remains unexpanded, so we substitute the pattern and - /// form a new pack expansion. - void addUnexpandedExpansion(Type origPatternType, Type substCountType, - unsigned i) { - auto substPatternType = origPatternType.subst(IFS); - auto substExpansion = PackExpansionType::get(substPatternType, substCountType); - - asImpl()->add(origPatternType, substExpansion, i); - } - - /// Scalar elements of the original pack are substituted and added to the - /// flattened pack. - void addScalar(Type origElement, unsigned i) { - auto substElement = origElement.subst(IFS); - - asImpl()->add(origElement, substElement, i); - } - - /// Potentially expand an element of the original pack. - void maybeExpandExpansion(PackExpansionType *origExpansion, unsigned i) { - auto origPatternType = origExpansion->getPatternType(); - auto origCountType = origExpansion->getCountType(); - - auto substCountType = origCountType.subst(IFS); - - // If the substituted count type is a pack, we're expanding the - // original element. - if (auto *expandedCountType = substCountType->template getAs()) { - addExpandedExpansion(origPatternType, expandedCountType, i); - return; - } - - addUnexpandedExpansion(origPatternType, substCountType, i); + substElementTypes.push_back(substComponentType); + substConformances.push_back(substConformance); + }); } public: void expand(PackType *origPackType) { - for (unsigned i = 0, e = origPackType->getNumElements(); i < e; ++i) { - auto origElement = origPackType->getElementType(i); + assert(origPackType->getNumElements() == origConformances.size()); - // Check if the original element is potentially being expanded. - if (auto *origExpansion = origElement->getAs()) { - maybeExpandExpansion(origExpansion, i); - continue; + for (auto i : range(origPackType->getNumElements())) { + auto origElementType = origPackType->getElementType(i); + if (auto *origExpansion = origElementType->getAs()) { + substExpansion(origExpansion, origConformances[i]); + } else { + substScalar(origElementType, origConformances[i]); } - - addScalar(origElement, i); } } }; -class PackConformanceExpander : public PackExpander { -public: - SmallVector substElements; - SmallVector substConformances; - - ArrayRef origConformances; - - PackConformanceExpander(InFlightSubstitution &IFS, - ArrayRef origConformances) - : PackExpander(IFS), origConformances(origConformances) {} - - void add(Type origType, Type substType, unsigned i) { - substElements.push_back(substType); - - // FIXME: Pass down projection callbacks - substConformances.push_back(origConformances[i].subst( - origType, IFS)); - } -}; - -} +} // end anonymous namespace ProtocolConformanceRef PackConformance::subst(TypeSubstitutionFn subs, LookupConformanceFn conformances, @@ -382,7 +240,7 @@ PackConformance::subst(InFlightSubstitution &IFS) const { expander.expand(ConformingType); auto &ctx = Protocol->getASTContext(); - auto *substConformingType = PackType::get(ctx, expander.substElements); + auto *substConformingType = PackType::get(ctx, expander.substElementTypes); auto substConformance = PackConformance::get(substConformingType, Protocol, expander.substConformances); diff --git a/lib/AST/ParameterPack.cpp b/lib/AST/ParameterPack.cpp index d3957e718d4..fb2e93b8021 100644 --- a/lib/AST/ParameterPack.cpp +++ b/lib/AST/ParameterPack.cpp @@ -449,15 +449,15 @@ PackType *PackType::get(const ASTContext &C, auto arg = args[i]; if (params[i]->isParameterPack()) { - wrappedArgs.push_back(PackExpansionType::get( - arg, arg->getReducedShape())); + auto argPackElements = arg->castTo()->getElementTypes(); + wrappedArgs.append(argPackElements.begin(), argPackElements.end()); continue; } wrappedArgs.push_back(arg); } - return get(C, wrappedArgs)->flattenPackTypes(); + return get(C, wrappedArgs); } PackType *PackType::getSingletonPackExpansion(Type param) { diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index a43c332ebcf..90d9612b333 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -4654,6 +4654,84 @@ static Type substGenericFunctionType(GenericFunctionType *genericFnType, fnType->getResult(), fnType->getExtInfo()); } +void InFlightSubstitution::expandPackExpansionShape(Type origShape, + llvm::function_ref handleComponent) { + + // Substitute the shape using the baseline substitutions, not the + // current elementwise projections. + auto substShape = origShape.subst(BaselineSubstType, + BaselineLookupConformance, + Options); + + auto substPackShape = substShape->getAs(); + if (!substPackShape) { + ActivePackExpansions.push_back({/*is subst expansion*/true, 0}); + handleComponent(substShape); + ActivePackExpansions.pop_back(); + return; + } + + ActivePackExpansions.push_back({false, 0}); + for (auto substElt : substPackShape->getElementTypes()) { + auto substExpansion = substElt->getAs(); + auto substExpansionShape = + (substExpansion ? substExpansion->getCountType() : Type()); + + ActivePackExpansions.back().isSubstExpansion = + (substExpansion != nullptr); + handleComponent(substExpansionShape); + ActivePackExpansions.back().expansionIndex++; + } + ActivePackExpansions.pop_back(); +} + +Type InFlightSubstitution::substType(SubstitutableType *origType) { + auto substType = BaselineSubstType(origType); + if (!substType || ActivePackExpansions.empty()) + return substType; + + auto substPackType = substType->getAs(); + if (!substPackType) + return substType; + + auto &activeExpansion = ActivePackExpansions.back(); + auto index = activeExpansion.expansionIndex; + assert(index < substPackType->getNumElements() && + "replacement for pack parameter did not have the right " + "size for expansion"); + auto substEltType = substPackType->getElementType(index); + if (activeExpansion.isSubstExpansion) { + assert(substEltType->is() && + "substituted shape mismatch: expected an expansion component"); + substEltType = substEltType->castTo()->getPatternType(); + } else { + assert(!substEltType->is() && + "substituted shape mismatch: expected a scalar component"); + } + return substEltType; +} + +ProtocolConformanceRef +InFlightSubstitution::lookupConformance(CanType dependentType, + Type conformingReplacementType, + ProtocolDecl *conformedProtocol) { + auto substConfRef = BaselineLookupConformance(dependentType, + conformingReplacementType, + conformedProtocol); + if (!substConfRef || + ActivePackExpansions.empty() || + !substConfRef.isPack()) + return substConfRef; + + auto substPackConf = substConfRef.getPack(); + auto substPackPatterns = substPackConf->getPatternConformances(); + auto index = ActivePackExpansions.back().expansionIndex; + assert(index < substPackPatterns.size() && + "replacement for pack parameter did not have the right " + "size for expansion"); + return substPackPatterns[index]; +} + bool InFlightSubstitution::isInvariant(Type derivedType) const { return !derivedType->hasArchetype() && !derivedType->hasTypeParameter() @@ -4691,12 +4769,9 @@ static Type substType(Type derivedType, InFlightSubstitution &IFS) { } if (auto packExpansionTy = dyn_cast(type)) { - auto patternTy = substType(packExpansionTy->getPatternType(), IFS); - auto countTy = substType(packExpansionTy->getCountType(), IFS); - if (auto *archetypeTy = countTy->getAs()) - countTy = archetypeTy->getReducedShape(); - - return Type(PackExpansionType::get(patternTy, countTy)->expand()); + auto eltTys = IFS.expandPackExpansionType(packExpansionTy); + if (eltTys.size() == 1) return eltTys[0]; + return Type(PackType::get(packExpansionTy->getASTContext(), eltTys)); } if (auto silFnTy = dyn_cast(type)) { @@ -5194,6 +5269,23 @@ Type Type::transform(llvm::function_ref fn) const { }); } +static PackType *getTransformedPack(Type substType) { + if (auto pack = substType->getAs()) { + return pack; + } + + // The pack matchers like to make expansions out of packs, and + // these types then propagate out into transforms. Make sure we + // flatten them exactly if they were the underlying pack. + // FIXME: stop doing this and make PackExpansionType::get assert + // that we never construct these types + if (auto expansion = substType->getAs()) { + return expansion->getPatternType()->getAs(); + } + + return nullptr; +} + Type Type::transformRec( llvm::function_ref(TypeBase *)> fn) const { return transformWithPosition(TypePosition::Invariant, @@ -5631,13 +5723,19 @@ case TypeKind::Id: anyChanged = true; } - elements.push_back(transformedEltTy); + // If the transformed type is a pack, immediately expand it. + if (auto eltPack = getTransformedPack(transformedEltTy)) { + auto eltElements = eltPack->getElementTypes(); + elements.append(eltElements.begin(), eltElements.end()); + } else { + elements.push_back(transformedEltTy); + } } if (!anyChanged) return *this; - return PackType::get(Ptr->getASTContext(), elements)->flattenPackTypes(); + return PackType::get(Ptr->getASTContext(), elements); } case TypeKind::SILPack: { @@ -5689,6 +5787,8 @@ case TypeKind::Id: case TypeKind::PackExpansion: { auto expand = cast(base); + // Substitution completely replaces this. + Type transformedPat = expand->getPatternType().transformWithPosition(pos, fn); if (!transformedPat) @@ -5703,7 +5803,14 @@ case TypeKind::Id: transformedCount.getPointer() == expand->getCountType().getPointer()) return *this; - return PackExpansionType::get(transformedPat, transformedCount)->expand(); + // // If we transform the count to a pack type, expand the pattern. + // // This is necessary because of how we piece together types in + // // the constraint system. + // if (auto countPack = transformedCount->getAs()) { + // return PackExpansionType::expand(transformedPat, countPack); + // } + + return PackExpansionType::get(transformedPat, transformedCount); } case TypeKind::Tuple: { @@ -5734,13 +5841,35 @@ case TypeKind::Id: } // Add the new tuple element, with the transformed type. - elements.push_back(elt.getWithType(transformedEltTy)); + // Expand packs immediately. + if (auto eltPack = getTransformedPack(transformedEltTy)) { + bool first = true; + for (auto eltElement : eltPack->getElementTypes()) { + if (first) { + elements.push_back(elt.getWithType(eltElement)); + first = false; + } else { + elements.push_back(TupleTypeElt(eltElement)); + } + } + } else { + elements.push_back(elt.getWithType(transformedEltTy)); + } } if (!anyChanged) return *this; - return TupleType::get(elements, Ptr->getASTContext())->flattenPackTypes(); + // If the transform would yield a singleton tuple, and we didn't + // start with one, flatten to produce the element type. + if (elements.size() == 1 && + !elements[0].getType()->is() && + !(tuple->getNumElements() == 1 && + !tuple->getElementType(0)->is())) { + return elements[0].getType(); + } + + return TupleType::get(elements, Ptr->getASTContext()); } @@ -5794,7 +5923,21 @@ case TypeKind::Id: flags = flags.withInOut(true); } - substParams.emplace_back(substType, label, flags, internalLabel); + if (auto substPack = getTransformedPack(substType)) { + bool first = true; + for (auto substEltType : substPack->getElementTypes()) { + if (first) { + substParams.emplace_back(substEltType, label, flags, + internalLabel); + first = false; + } else { + substParams.emplace_back(substEltType, Identifier(), flags, + Identifier()); + } + } + } else { + substParams.emplace_back(substType, label, flags, internalLabel); + } } // Transform result type. @@ -5836,8 +5979,7 @@ case TypeKind::Id: return GenericFunctionType::get(genericSig, substParams, resultTy); return GenericFunctionType::get(genericSig, substParams, resultTy, function->getExtInfo() - .withGlobalActor(globalActorType)) - ->flattenPackTypes(); + .withGlobalActor(globalActorType)); } if (isUnchanged) return *this; @@ -5846,8 +5988,7 @@ case TypeKind::Id: return FunctionType::get(substParams, resultTy); return FunctionType::get(substParams, resultTy, function->getExtInfo() - .withGlobalActor(globalActorType)) - ->flattenPackTypes(); + .withGlobalActor(globalActorType)); } case TypeKind::ArraySlice: { diff --git a/lib/SIL/IR/SILTypeSubstitution.cpp b/lib/SIL/IR/SILTypeSubstitution.cpp index 51e8d242f61..b74c754cf6f 100644 --- a/lib/SIL/IR/SILTypeSubstitution.cpp +++ b/lib/SIL/IR/SILTypeSubstitution.cpp @@ -42,27 +42,6 @@ class SILTypeSubstituter : // context signature. CanGenericSignature Sig; - struct PackExpansion { - /// The shape class of pack parameters that are expanded by this - /// expansion. Set during construction and not changed. - CanType OrigShapeClass; - - /// The count type of the pack expansion in the current lane of - /// expansion, if any. Pack elements in this lane should be - /// expansions with this shape. - CanType SubstPackExpansionCount; - - /// The index of the current lane of expansion. Basic - /// substitution of pack parameters with the same shape as - /// OrigShapeClass should yield a pack, and lanewise - /// substitution should produce this element of that pack. - unsigned Index; - - PackExpansion(CanType origShapeClass) - : OrigShapeClass(origShapeClass), Index(0) {} - }; - SmallVector ActivePackExpansions; - TypeExpansionContext typeExpansionContext; public: @@ -356,63 +335,21 @@ public: } CanType visitPackExpansionType(CanPackExpansionType origType) { - CanType patternType = visit(origType.getPatternType()); - CanType countType = substASTType(origType.getCountType()); - - return CanType(PackExpansionType::get(patternType, countType)); + llvm_unreachable("shouldn't substitute an independent lowered pack " + "expansion type"); } void substPackExpansion(CanPackExpansionType origType, llvm::function_ref addExpandedType) { - CanType origCountType = origType.getCountType(); - CanType origPatternType = origType.getPatternType(); - - // Substitute the count type (as an AST type). - CanType substCountType = substASTType(origCountType); - - // If that produces a pack type, expand the pattern element-wise. - if (auto substCountPackType = dyn_cast(substCountType)) { - // Set up for element-wise expansion. - ActivePackExpansions.emplace_back(origCountType); - - for (CanType substCountEltType : substCountPackType.getElementTypes()) { - auto expansionType = dyn_cast(substCountEltType); - ActivePackExpansions.back().SubstPackExpansionCount = - (expansionType ? expansionType.getCountType() : CanType()); - - // Expand the pattern type in the element-wise context. - CanType expandedType = visit(origPatternType); - - // Turn that into a pack expansion if appropriate for the - // count element. - if (expansionType) { - expandedType = - CanPackExpansionType::get(expandedType, - expansionType.getCountType()); - } - - addExpandedType(expandedType); - - // Move to the next element. - ActivePackExpansions.back().Index++; + IFS.expandPackExpansionShape(origType.getCountType(), + [&](Type substExpansionShape) { + CanType substComponentType = visit(origType.getPatternType()); + if (substExpansionShape) { + substComponentType = CanPackExpansionType::get(substComponentType, + substExpansionShape->getCanonicalType()); } - - // Leave the element-wise context. - ActivePackExpansions.pop_back(); - return; - } - - // Otherwise, transform the pattern type abstractly and just add a - // type expansion. - CanType substPatternType = visit(origPatternType); - - CanType expandedType; - if (substCountType == origCountType && substPatternType == origPatternType) - expandedType = origType; - else - expandedType = - CanPackExpansionType::get(substPatternType, substCountType); - addExpandedType(expandedType); + addExpandedType(substComponentType); + }); } /// Tuples need to have their component types substituted by these @@ -511,73 +448,12 @@ public: substType); } - struct SubstRespectingExpansions { - SILTypeSubstituter *_this; - SubstRespectingExpansions(SILTypeSubstituter *_this) : _this(_this) {} - - Type operator()(SubstitutableType *origType) const { - auto substType = _this->IFS.substType(origType); - if (!substType) return substType; - auto substPackType = dyn_cast(substType->getCanonicalType()); - if (!substPackType) return substType; - auto activeExpansion = _this->getActivePackExpansion(CanType(origType)); - if (!activeExpansion) return substType; - auto substEltType = - substPackType.getElementType(activeExpansion->Index); - auto substExpansion = dyn_cast(substEltType); - assert((bool) substExpansion == - (bool) activeExpansion->SubstPackExpansionCount); - if (substExpansion) { - assert(_this->hasSameShape(substExpansion.getCountType(), - activeExpansion->SubstPackExpansionCount)); - return substExpansion.getPatternType(); - } - return substEltType; - } - }; - - struct SubstConformanceRespectingExpansions { - SILTypeSubstituter *_this; - SubstConformanceRespectingExpansions(SILTypeSubstituter *_this) - : _this(_this) {} - - ProtocolConformanceRef operator()(CanType dependentType, - Type conformingReplacementType, - ProtocolDecl *conformingProtocol) const { - auto conformance = - _this->IFS.lookupConformance(dependentType, - conformingReplacementType, - conformingProtocol); - if (!conformance || !conformance.isPack()) return conformance; - auto activeExpansion = _this->getActivePackExpansion(dependentType); - if (!activeExpansion) return conformance; - auto pack = conformance.getPack(); - auto substEltConf = - pack->getPatternConformances()[activeExpansion->Index]; - // There isn't currently a ProtocolConformanceExpansion that - // we would need to look through here. - return substEltConf; - }; - }; - CanType substASTType(CanType origType) { - if (ActivePackExpansions.empty()) - return origType.subst(IFS)->getCanonicalType(); - - return origType.subst(SubstRespectingExpansions(this), - SubstConformanceRespectingExpansions(this), - IFS.getOptions())->getCanonicalType(); + return origType.subst(IFS)->getCanonicalType(); } SubstitutionMap substSubstitutions(SubstitutionMap subs) { - SubstitutionMap newSubs; - - if (ActivePackExpansions.empty()) - newSubs = subs.subst(IFS); - else - newSubs = subs.subst(SubstRespectingExpansions(this), - SubstConformanceRespectingExpansions(this), - IFS.getOptions()); + SubstitutionMap newSubs = subs.subst(IFS); // If we need to look through opaque types in this context, re-substitute // according to the expansion context. @@ -585,28 +461,6 @@ public: return newSubs; } - - PackExpansion *getActivePackExpansion(CanType dependentType) { - // We push new expansions onto the end of this vector, and we - // want to honor the innermost expansion, so we have to traverse - // in it reverse. - for (auto &entry : reverse(ActivePackExpansions)) { - if (hasSameShape(dependentType, entry.OrigShapeClass)) - return &entry; - } - return nullptr; - } - - bool hasSameShape(CanType lhs, CanType rhs) { - if (lhs->isTypeParameter() && rhs->isTypeParameter()) { - assert(Sig); - return Sig->haveSameShape(lhs, rhs); - } - - auto lhsArchetype = cast(lhs); - auto rhsArchetype = cast(rhs); - return lhsArchetype->getReducedShape() == rhsArchetype->getReducedShape(); - } }; } // end anonymous namespace diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 19b8cd4d83d..ef5d9edd473 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -7125,8 +7125,8 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType, auto *expansion = dyn_cast(expr); auto *elementEnv = expansion->getGenericEnvironment(); - auto toElementType = elementEnv->mapPackTypeIntoElementContext( - toExpansionType->getPatternType()->mapTypeOutOfContext()); + auto toElementType = elementEnv->mapContextualPackTypeIntoElementContext( + toExpansionType->getPatternType()); auto *pattern = coerceToType(expansion->getPatternExpr(), toElementType, locator); diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index df47209c617..894a7a28f38 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -3713,16 +3713,56 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator, } } -Type ConstraintSystem::simplifyTypeImpl(Type type, - llvm::function_ref getFixedTypeFn) const { - return type.transform([&](Type type) -> Type { - if (auto tvt = dyn_cast(type.getPointer())) - return getFixedTypeFn(tvt); +namespace { + +struct TypeSimplifier { + const ConstraintSystem &CS; + llvm::function_ref GetFixedTypeFn; + + struct ActivePackExpansion { + bool isPackExpansion = false; + unsigned index = 0; + }; + SmallVector ActivePackExpansions; + + TypeSimplifier(const ConstraintSystem &CS, + llvm::function_ref getFixedTypeFn) + : CS(CS), GetFixedTypeFn(getFixedTypeFn) {} + + Type operator()(Type type) { + if (auto tvt = dyn_cast(type.getPointer())) { + auto fixedTy = GetFixedTypeFn(tvt); + + // TODO: the following logic should be applied when rewriting + // PackElementType. + if (ActivePackExpansions.empty()) { + return fixedTy; + } + + if (auto fixedPack = fixedTy->getAs()) { + auto &activeExpansion = ActivePackExpansions.back(); + if (activeExpansion.index >= fixedPack->getNumElements()) { + return tvt; + } + + auto fixedElt = fixedPack->getElementType(activeExpansion.index); + auto fixedExpansion = fixedElt->getAs(); + if (activeExpansion.isPackExpansion && fixedExpansion) { + return fixedExpansion->getPatternType(); + } else if (!activeExpansion.isPackExpansion && !fixedExpansion) { + return fixedElt; + } else { + return tvt; + } + } + + return fixedTy; + } if (auto tuple = dyn_cast(type.getPointer())) { if (tuple->getNumElements() == 1) { auto element = tuple->getElement(0); - auto elementType = simplifyTypeImpl(element.getType(), getFixedTypeFn); + auto elementType = element.getType().transform(*this); // Flatten single-element tuples containing type variables that cannot // bind to packs. @@ -3733,14 +3773,47 @@ Type ConstraintSystem::simplifyTypeImpl(Type type, } } + if (auto expansion = dyn_cast(type.getPointer())) { + // Transform the count type, ignoring any active pack expansions. + auto countType = expansion->getCountType().transform( + TypeSimplifier(CS, GetFixedTypeFn)); + + if (auto countPack = countType->getAs()) { + SmallVector elts; + ActivePackExpansions.push_back({false, 0}); + for (auto countElt : countPack->getElementTypes()) { + auto countExpansion = countElt->getAs(); + ActivePackExpansions.back().isPackExpansion = + (countExpansion != nullptr); + + auto elt = expansion->getPatternType().transform(*this); + if (countExpansion) + elt = PackExpansionType::get(elt, countExpansion->getCountType()); + elts.push_back(elt); + + ActivePackExpansions.back().index++; + } + ActivePackExpansions.pop_back(); + + if (elts.size() == 1) + return elts[0]; + return PackType::get(CS.getASTContext(), elts); + } else { + ActivePackExpansions.push_back({true, 0}); + auto patternType = expansion->getPatternType().transform(*this); + ActivePackExpansions.pop_back(); + return PackExpansionType::get(patternType, countType); + } + } + // If this is a dependent member type for which we end up simplifying // the base to a non-type-variable, perform lookup. if (auto depMemTy = dyn_cast(type.getPointer())) { // Simplify the base. - Type newBase = simplifyTypeImpl(depMemTy->getBase(), getFixedTypeFn); + Type newBase = depMemTy->getBase().transform(*this); if (newBase->isPlaceholder()) { - return PlaceholderType::get(getASTContext(), depMemTy); + return PlaceholderType::get(CS.getASTContext(), depMemTy); } // If nothing changed, we're done. @@ -3760,7 +3833,7 @@ Type ConstraintSystem::simplifyTypeImpl(Type type, if (lookupBaseType->mayHaveMembers() || lookupBaseType->is()) { auto *proto = assocType->getProtocol(); - auto conformance = DC->getParentModule()->lookupConformance( + auto conformance = CS.DC->getParentModule()->lookupConformance( lookupBaseType, proto); if (!conformance) { // FIXME: This regresses diagnostics if removed, but really the @@ -3774,9 +3847,9 @@ Type ConstraintSystem::simplifyTypeImpl(Type type, // so the concrete dependent member type is considered a "hole" in // order to continue solving. auto memberTy = DependentMemberType::get(lookupBaseType, assocType); - if (shouldAttemptFixes() && - getPhase() == ConstraintSystemPhase::Solving) { - return PlaceholderType::get(getASTContext(), memberTy); + if (CS.shouldAttemptFixes() && + CS.getPhase() == ConstraintSystemPhase::Solving) { + return PlaceholderType::get(CS.getASTContext(), memberTy); } return memberTy; @@ -3792,7 +3865,14 @@ Type ConstraintSystem::simplifyTypeImpl(Type type, } return type; - }); + } +}; + +} // end anonymous namespace + +Type ConstraintSystem::simplifyTypeImpl(Type type, + llvm::function_ref getFixedTypeFn) const { + return type.transform(TypeSimplifier(*this, getFixedTypeFn)); } Type ConstraintSystem::simplifyType(Type type) const { diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index 08315473e14..44cead7af37 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -918,11 +918,22 @@ static Type applyGenericArguments(Type type, TypeResolution resolution, assert(found != matcher.pairs.end()); auto arg = found->rhs; - if (auto *expansionType = arg->getAs()) - arg = expansionType->getPatternType(); - if (arg->isParameterPack()) - arg = PackType::getSingletonPackExpansion(arg); + // PackMatcher will always produce a PackExpansionType as the + // arg for a pack parameter, if necessary by wrapping a PackType + // in one. (It's a weird representation.) Look for that pattern + // and unwrap the pack. Otherwise, we must have matched with a + // single component which happened to be an expansion; wrap that + // in a PackType. In either case, we always want arg to end up + // a PackType. + if (auto *expansionType = arg->getAs()) { + auto pattern = expansionType->getPatternType(); + if (auto pack = pattern->getAs()) { + arg = pack; + } else { + arg = PackType::get(ctx, {expansionType}); + } + } args.push_back(arg); }