Perform component-wise substitution of pack expansions immediately.

Substitution of a pack expansion type may now produce a pack type.
We immediately expand that pack when transforming a tuple, a function
parameter, or a pack.

I had to duplicate the component-wise transformation logic in the
simplifyType transform, which I'm not pleased about, but a little
code duplication seemed a lot better than trying to unify the code
in two very different places.

I think we're very close to being able to assert that pack expansion
shapes are either pack archetypes or pack parameters; unfortunately,
the pack matchers intentionally produce expansions of packs, and I
didn't want to add that to an already-large patch.
This commit is contained in:
John McCall
2023-03-25 18:46:29 -04:00
parent d16fed6e9a
commit c041d1061a
9 changed files with 398 additions and 386 deletions

View File

@@ -167,207 +167,65 @@ ProtocolConformanceRef PackConformance::subst(SubstitutionMap subMap,
return subst(IFS);
}
// TODO: Move this elsewhere since it's generally useful
static bool arePackShapesEqual(PackType *lhs, PackType *rhs) {
if (lhs->getNumElements() != rhs->getNumElements())
return false;
for (unsigned i = 0, e = lhs->getNumElements(); i < e; ++i) {
auto lhsElt = lhs->getElementType(i);
auto rhsElt = rhs->getElementType(i);
if (lhsElt->is<PackExpansionType>() != rhsElt->is<PackExpansionType>())
return false;
}
return true;
}
static bool isRootParameterPack(Type t) {
if (auto *paramTy = t->getAs<GenericTypeParamType>()) {
return paramTy->isParameterPack();
} else if (auto *archetypeTy = t->getAs<PackArchetypeType>()) {
return archetypeTy->isRoot();
}
return false;
}
static bool isRootedInParameterPack(Type t) {
if (auto *archetypeTy = t->getAs<PackArchetypeType>()) {
return true;
}
return t->getRootGenericParam()->isParameterPack();
}
namespace {
template<typename ImplClass>
class PackExpander {
protected:
struct PackConformanceExpander {
InFlightSubstitution &IFS;
ArrayRef<ProtocolConformanceRef> origConformances;
PackExpander(InFlightSubstitution &IFS) : IFS(IFS) {}
public:
// Results built up by the expansion.
SmallVector<Type, 4> substElementTypes;
SmallVector<ProtocolConformanceRef, 4> substConformances;
ImplClass *asImpl() {
return static_cast<ImplClass *>(this);
PackConformanceExpander(InFlightSubstitution &IFS,
ArrayRef<ProtocolConformanceRef> origConformances)
: IFS(IFS), origConformances(origConformances) {}
private:
/// Substitute a scalar element of the original pack.
void substScalar(Type origElementType,
ProtocolConformanceRef origConformance) {
auto substElementType = origElementType.subst(IFS);
auto substConformance = origConformance.subst(origElementType, IFS);
substElementTypes.push_back(substElementType);
substConformances.push_back(substConformance);
}
/// We're replacing a pack expansion type with a pack -- flatten the pack
/// using the pack expansion's pattern.
void addExpandedExpansion(Type origPatternType, PackType *expandedCountType,
unsigned i) {
/// Substitute and expand an expansion element of the original pack.
void substExpansion(PackExpansionType *origExpansionType,
ProtocolConformanceRef origConformance) {
IFS.expandPackExpansionType(origExpansionType,
[&](Type substComponentType) {
auto origPatternType = origExpansionType->getPatternType();
// Get all pack parameters referenced from the pattern.
SmallVector<Type, 2> rootParameterPacks;
origPatternType->getTypeParameterPacks(rootParameterPacks);
// Just substitute the conformance. We don't directly represent
// pack expansion conformances here; it's sort of implicit in the
// corresponding pack element type.
auto substConformance = origConformance.subst(origPatternType, IFS);
// Each pack parameter referenced from the pattern must be replaced
// with a pack type, and all pack types must have the same shape as
// the expanded count pack type.
llvm::SmallDenseMap<Type, PackType *, 2> expandedPacks;
for (auto origParamType : rootParameterPacks) {
auto substParamType = origParamType.subst(IFS);
if (auto expandedParamType = substParamType->template getAs<PackType>()) {
assert(arePackShapesEqual(expandedParamType, expandedCountType) &&
"TODO: Return an invalid conformance if this fails");
auto inserted = expandedPacks.insert(
std::make_pair(origParamType->getCanonicalType(),
expandedParamType)).second;
assert(inserted &&
"getTypeParameterPacks() should not return duplicates");
} else {
assert(false &&
"TODO: Return an invalid conformance if this fails");
}
}
// For each element of the expanded count, compute the substituted
// pattern type.
for (unsigned j = 0, ee = expandedCountType->getNumElements(); j < ee; ++j) {
auto projectedSubs = [&](SubstitutableType *type) -> Type {
// Nested sequence archetypes get passed in here, but we must
// handle them via the standard nested type path.
if (auto *archetypeType = dyn_cast<ArchetypeType>(type)) {
if (!archetypeType->isRoot())
return Type();
}
// Compute the substituted type using our parent substitutions.
auto substType = Type(type).subst(IFS);
// If the substituted type is a pack, project the jth element.
if (isRootParameterPack(type)) {
// FIXME: What if you have something like G<T...>... where G<> is
// variadic?
assert(substType->template is<PackType>() &&
"TODO: Return an invalid conformance if this fails");
auto *packType = substType->template castTo<PackType>();
assert(arePackShapesEqual(packType, expandedCountType) &&
"TODO: Return an invalid conformance if this fails");
return packType->getElementType(j);
}
return IFS.substType(type);
};
auto projectedConformances = [&](CanType origType, Type substType,
ProtocolDecl *proto) -> ProtocolConformanceRef {
auto substConformance =
IFS.lookupConformance(origType, substType, proto);
// If the substituted conformance is a pack, project the jth element.
if (isRootedInParameterPack(origType)) {
return substConformance.getPack()->getPatternConformances()[j];
}
return substConformance;
};
auto origCountElement = expandedCountType->getElementType(j);
auto substCountElement = origCountElement.subst(
projectedSubs, projectedConformances, IFS.getOptions());
asImpl()->add(origCountElement, substCountElement, i);
}
}
/// A pack expansion remains unexpanded, so we substitute the pattern and
/// form a new pack expansion.
void addUnexpandedExpansion(Type origPatternType, Type substCountType,
unsigned i) {
auto substPatternType = origPatternType.subst(IFS);
auto substExpansion = PackExpansionType::get(substPatternType, substCountType);
asImpl()->add(origPatternType, substExpansion, i);
}
/// Scalar elements of the original pack are substituted and added to the
/// flattened pack.
void addScalar(Type origElement, unsigned i) {
auto substElement = origElement.subst(IFS);
asImpl()->add(origElement, substElement, i);
}
/// Potentially expand an element of the original pack.
void maybeExpandExpansion(PackExpansionType *origExpansion, unsigned i) {
auto origPatternType = origExpansion->getPatternType();
auto origCountType = origExpansion->getCountType();
auto substCountType = origCountType.subst(IFS);
// If the substituted count type is a pack, we're expanding the
// original element.
if (auto *expandedCountType = substCountType->template getAs<PackType>()) {
addExpandedExpansion(origPatternType, expandedCountType, i);
return;
}
addUnexpandedExpansion(origPatternType, substCountType, i);
substElementTypes.push_back(substComponentType);
substConformances.push_back(substConformance);
});
}
public:
void expand(PackType *origPackType) {
for (unsigned i = 0, e = origPackType->getNumElements(); i < e; ++i) {
auto origElement = origPackType->getElementType(i);
assert(origPackType->getNumElements() == origConformances.size());
// Check if the original element is potentially being expanded.
if (auto *origExpansion = origElement->getAs<PackExpansionType>()) {
maybeExpandExpansion(origExpansion, i);
continue;
for (auto i : range(origPackType->getNumElements())) {
auto origElementType = origPackType->getElementType(i);
if (auto *origExpansion = origElementType->getAs<PackExpansionType>()) {
substExpansion(origExpansion, origConformances[i]);
} else {
substScalar(origElementType, origConformances[i]);
}
addScalar(origElement, i);
}
}
};
class PackConformanceExpander : public PackExpander<PackConformanceExpander> {
public:
SmallVector<Type, 4> substElements;
SmallVector<ProtocolConformanceRef, 4> substConformances;
ArrayRef<ProtocolConformanceRef> origConformances;
PackConformanceExpander(InFlightSubstitution &IFS,
ArrayRef<ProtocolConformanceRef> origConformances)
: PackExpander(IFS), origConformances(origConformances) {}
void add(Type origType, Type substType, unsigned i) {
substElements.push_back(substType);
// FIXME: Pass down projection callbacks
substConformances.push_back(origConformances[i].subst(
origType, IFS));
}
};
}
} // end anonymous namespace
ProtocolConformanceRef PackConformance::subst(TypeSubstitutionFn subs,
LookupConformanceFn conformances,
@@ -382,7 +240,7 @@ PackConformance::subst(InFlightSubstitution &IFS) const {
expander.expand(ConformingType);
auto &ctx = Protocol->getASTContext();
auto *substConformingType = PackType::get(ctx, expander.substElements);
auto *substConformingType = PackType::get(ctx, expander.substElementTypes);
auto substConformance = PackConformance::get(substConformingType, Protocol,
expander.substConformances);