Support vanishing tuples in the the general abstraction pattern routines.

This commit is contained in:
John McCall
2023-04-03 23:14:47 -04:00
parent cbf8519084
commit 30817b11dd
3 changed files with 145 additions and 28 deletions

View File

@@ -283,7 +283,7 @@ LayoutConstraint AbstractionPattern::getLayoutConstraint() const {
}
}
bool AbstractionPattern::matchesTuple(CanTupleType substType) const {
bool AbstractionPattern::matchesTuple(CanType substType) const {
switch (getKind()) {
case Kind::Invalid:
llvm_unreachable("querying invalid abstraction pattern!");
@@ -311,11 +311,19 @@ bool AbstractionPattern::matchesTuple(CanTupleType substType) const {
return false;
LLVM_FALLTHROUGH;
case Kind::Tuple: {
if (getVanishingTupleElementPatternType()) {
// TODO: recurse into elements.
return true;
}
auto substTupleType = dyn_cast<TupleType>(substType);
if (!substTupleType) return false;
size_t nextSubstIndex = 0;
auto nextComponentIsAcceptable = [&](bool isPackExpansion) -> bool {
if (nextSubstIndex == substType->getNumElements())
if (nextSubstIndex == substTupleType->getNumElements())
return false;
auto substComponentType = substType.getElementType(nextSubstIndex++);
auto substComponentType = substTupleType.getElementType(nextSubstIndex++);
return (isPackExpansion == isa<PackExpansionType>(substComponentType));
};
for (auto elt : getTupleElementTypes()) {
@@ -333,7 +341,7 @@ bool AbstractionPattern::matchesTuple(CanTupleType substType) const {
return false;
}
}
return nextSubstIndex == substType->getNumElements();
return nextSubstIndex == substTupleType->getNumElements();
}
}
llvm_unreachable("bad kind");
@@ -469,7 +477,63 @@ bool AbstractionPattern::doesTupleContainPackExpansionType() const {
llvm_unreachable("bad kind");
}
void AbstractionPattern::forEachTupleElement(CanTupleType substType,
Optional<AbstractionPattern>
AbstractionPattern::getVanishingTupleElementPatternType() const {
if (!isTuple()) return None;
if (!GenericSubs) return None;
// Substitution causes tuples to vanish when substituting the elements
// produces a singleton tuple and it didn't start that way.
auto numOrigElts = getNumTupleElements();
// Track whether we've found a single element.
Optional<AbstractionPattern> singletonEltType;
bool hadOrigExpansion = false;
for (auto index : range(numOrigElts)) {
auto eltType = getTupleElementType(index);
// If this pattern isn't a pack expansion, we've got a new candidate
// singleton. If this is the second such candidate, of course, it's
// not a singleton.
if (!eltType.isPackExpansion()) {
if (singletonEltType) return None;
singletonEltType = eltType;
// Otherwise, check what the expansion shape expands to.
} else {
hadOrigExpansion = true;
auto expansionType = cast<PackExpansionType>(eltType.getType());
auto substShape = cast<PackType>(
expansionType.getCountType().subst(GenericSubs)->getCanonicalType());
auto expansionCount = substShape->getNumElements();
// If it expands to multiple elements or to a single expansion, we
// won't have a singleton tuple. If it expands to a single scalar
// element, this is a singleton candidate.
if (expansionCount > 1) {
return None;
} else if (expansionCount == 1) {
auto substExpansion =
dyn_cast<PackExpansionType>(substShape.getElementType(0));
if (substExpansion)
return None;
if (singletonEltType)
return None;
singletonEltType = eltType.getPackExpansionPatternType();
}
}
}
// If we found a singleton scalar element, and we didn't start with
// a singleton element, that's the index we want to return.
if (singletonEltType && !(numOrigElts == 1 && !hadOrigExpansion))
return singletonEltType;
return None;
}
void AbstractionPattern::forEachTupleElement(CanType substType,
llvm::function_ref<void(TupleElementGenerator &)> handleElement) const {
TupleElementGenerator elt(*this, substType);
for (; !elt.isFinished(); elt.advance()) {
@@ -480,35 +544,46 @@ void AbstractionPattern::forEachTupleElement(CanTupleType substType,
TupleElementGenerator::TupleElementGenerator(
AbstractionPattern origTupleType,
CanTupleType substTupleType)
: origTupleType(origTupleType), substTupleType(substTupleType) {
CanType substType)
: origTupleType(origTupleType), substType(substType) {
assert(origTupleType.isTuple());
assert(origTupleType.matchesTuple(substTupleType));
assert(origTupleType.matchesTuple(substType));
origTupleVanishes =
origTupleType.getVanishingTupleElementPatternType().hasValue();
origTupleTypeIsOpaque = origTupleType.isOpaqueTuple();
numOrigElts = origTupleType.getNumTupleElements();
if (!isFinished()) loadElement();
}
void AbstractionPattern::forEachExpandedTupleElement(CanTupleType substType,
void AbstractionPattern::forEachExpandedTupleElement(CanType substType,
llvm::function_ref<void(AbstractionPattern origEltType,
CanType substEltType,
const TupleTypeElt &elt)>
handleElement) const {
assert(matchesTuple(substType));
auto substEltTypes = substType.getElementTypes();
// Handle opaque patterns by just iterating the substituted components.
if (!isTuple()) {
auto substTupleType = cast<TupleType>(substType);
auto substEltTypes = substTupleType.getElementTypes();
for (auto i : indices(substEltTypes)) {
handleElement(getTupleElementType(i), substEltTypes[i],
substType->getElement(i));
substTupleType->getElement(i));
}
return;
}
// For vanishing tuples, just call the callback once.
if (auto origEltType = getVanishingTupleElementPatternType()) {
handleElement(*origEltType, substType, TupleTypeElt(substType));
return;
}
auto substTupleType = cast<TupleType>(substType);
auto substEltTypes = substTupleType.getElementTypes();
// For non-opaque patterns, we have to iterate the original components
// in order to match things up properly, but we'll still end up calling
// once per substituted element.
@@ -517,7 +592,7 @@ void AbstractionPattern::forEachExpandedTupleElement(CanTupleType substType,
auto origEltType = getTupleElementType(origEltIndex);
if (!origEltType.isPackExpansion()) {
handleElement(origEltType, substEltTypes[substEltIndex],
substType->getElement(substEltIndex));
substTupleType->getElement(substEltIndex));
substEltIndex++;
} else {
auto origPatternType = origEltType.getPackExpansionPatternType();
@@ -532,7 +607,8 @@ void AbstractionPattern::forEachExpandedTupleElement(CanTupleType substType,
// be misleading in one way or another.
handleElement(isa<PackExpansionType>(substEltType)
? origEltType : origPatternType,
substEltType, substType->getElement(substEltIndex));
substEltType,
substTupleType->getElement(substEltIndex));
substEltIndex++;
}
}