RequirementMachine: Move protocol linear order from ProtocolGraph to RewriteContext

This commit is contained in:
Slava Pestov
2021-10-21 18:58:18 -04:00
parent 941438d6c8
commit 0571b65cb8
16 changed files with 207 additions and 190 deletions

View File

@@ -12,7 +12,6 @@
#include "swift/AST/Decl.h"
#include "swift/AST/Types.h"
#include "ProtocolGraph.h"
#include "RequirementMachine.h"
#include "RewriteSystem.h"
#include "RewriteContext.h"
@@ -66,6 +65,77 @@ RewriteContext::RewriteContext(ASTContext &ctx)
Debug = parseDebugFlags(debugFlags);
}
const llvm::TinyPtrVector<const ProtocolDecl *> &
RewriteContext::getInheritedProtocols(const ProtocolDecl *proto) {
auto found = AllInherited.find(proto);
if (found != AllInherited.end())
return found->second;
AllInherited.insert(std::make_pair(proto, TinyPtrVector<const ProtocolDecl *>()));
llvm::SmallDenseSet<const ProtocolDecl *, 4> visited;
llvm::TinyPtrVector<const ProtocolDecl *> protos;
for (auto *inheritedProto : proto->getInheritedProtocols()) {
if (!visited.insert(inheritedProto).second)
continue;
protos.push_back(inheritedProto);
const auto &allInherited = getInheritedProtocols(inheritedProto);
for (auto *otherProto : allInherited) {
if (!visited.insert(otherProto).second)
continue;
protos.push_back(otherProto);
}
}
auto &result = AllInherited[proto];
std::swap(protos, result);
return result;
}
unsigned RewriteContext::getProtocolSupport(
const ProtocolDecl *proto) {
return getInheritedProtocols(proto).size() + 1;
}
unsigned RewriteContext::getProtocolSupport(
ArrayRef<const ProtocolDecl *> protos) {
auto found = Support.find(protos);
if (found != Support.end())
return found->second;
unsigned result;
if (protos.size() == 1) {
result = getProtocolSupport(protos[0]);
} else {
llvm::DenseSet<const ProtocolDecl *> visited;
for (const auto *proto : protos) {
visited.insert(proto);
for (const auto *inheritedProto : getInheritedProtocols(proto))
visited.insert(inheritedProto);
}
result = visited.size();
}
Support[protos] = result;
return result;
}
int RewriteContext::compareProtocols(const ProtocolDecl *lhs,
const ProtocolDecl *rhs) {
unsigned lhsSupport = getProtocolSupport(lhs);
unsigned rhsSupport = getProtocolSupport(rhs);
if (lhsSupport != rhsSupport)
return rhsSupport - lhsSupport;
return TypeDecl::compare(lhs, rhs);
}
Term RewriteContext::getTermForType(CanType paramType,
const ProtocolDecl *proto) {
return Term::get(getMutableTermForType(paramType, proto), *this);
@@ -167,64 +237,53 @@ MutableTerm RewriteContext::getMutableTermForType(CanType paramType,
/// Note that the protocol graph is not part of the caching key; each
/// protocol graph is a subgraph of the global inheritance graph, so
/// the specific choice of subgraph does not change the result.
AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(
Symbol symbol, const ProtocolGraph &protos) {
AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(Symbol symbol) {
auto found = AssocTypes.find(symbol);
if (found != AssocTypes.end())
return found->second;
assert(symbol.getKind() == Symbol::Kind::AssociatedType);
auto *proto = symbol.getProtocols()[0];
auto name = symbol.getName();
AssociatedTypeDecl *assocType = nullptr;
// Special case: handle unknown protocols, since they can appear in the
// invalid types that getCanonicalTypeInContext() must handle via
// concrete substitution; see the definition of getCanonicalTypeInContext()
// below for details.
if (!protos.isKnownProtocol(proto)) {
assert(symbol.getProtocols().size() == 1 &&
"Unknown associated type symbol must have a single protocol");
assocType = proto->getAssociatedType(name)->getAssociatedTypeAnchor();
} else {
// An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
// P0...Pn and an identifier 'A'.
//
// We map it back to a AssociatedTypeDecl as follows:
//
// - For each protocol Pn, look for associated types A in Pn itself,
// and all protocols that Pn refines.
//
// - For each candidate associated type An in protocol Qn where
// Pn refines Qn, get the associated type anchor An' defined in
// protocol Qn', where Qn refines Qn'.
//
// - Out of all the candidiate pairs (Qn', An'), pick the one where
// the protocol Qn' is the lowest element according to the linear
// order defined by TypeDecl::compare().
//
// The associated type An' is then the canonical associated type
// representative of the associated type symbol [P0&...&Pn:A].
//
for (auto *proto : symbol.getProtocols()) {
const auto &info = protos.getProtocolInfo(proto);
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
// An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
// P0...Pn and an identifier 'A'.
//
// We map it back to a AssociatedTypeDecl as follows:
//
// - For each protocol Pn, look for associated types A in Pn itself,
// and all protocols that Pn refines.
//
// - For each candidate associated type An in protocol Qn where
// Pn refines Qn, get the associated type anchor An' defined in
// protocol Qn', where Qn refines Qn'.
//
// - Out of all the candidiate pairs (Qn', An'), pick the one where
// the protocol Qn' is the lowest element according to the linear
// order defined by TypeDecl::compare().
//
// The associated type An' is then the canonical associated type
// representative of the associated type symbol [P0&...&Pn:A].
//
for (auto *proto : symbol.getProtocols()) {
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
if (otherAssocType->getName() == name &&
(assocType == nullptr ||
TypeDecl::compare(otherAssocType->getProtocol(),
assocType->getProtocol()) < 0)) {
assocType = otherAssocType;
}
};
for (auto *otherAssocType : info.AssociatedTypes) {
checkOtherAssocType(otherAssocType);
if (otherAssocType->getName() == name &&
(assocType == nullptr ||
TypeDecl::compare(otherAssocType->getProtocol(),
assocType->getProtocol()) < 0)) {
assocType = otherAssocType;
}
};
for (auto *otherAssocType : info.InheritedAssociatedTypes) {
for (auto *otherAssocType : proto->getAssociatedTypeMembers()) {
checkOtherAssocType(otherAssocType);
}
for (auto *inheritedProto : getInheritedProtocols(proto)) {
for (auto *otherAssocType : inheritedProto->getAssociatedTypeMembers()) {
checkOtherAssocType(otherAssocType);
}
}
@@ -244,7 +303,6 @@ AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(
template<typename Iter>
Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
TypeArrayView<GenericTypeParamType> genericParams,
const ProtocolGraph &protos,
const RewriteContext &ctx) {
Type result = root;
@@ -319,7 +377,7 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
// We should have a resolved type at this point.
auto *assocType =
const_cast<RewriteContext &>(ctx)
.getAssociatedTypeForSymbol(symbol, protos);
.getAssociatedTypeForSymbol(symbol);
result = DependentMemberType::get(result, assocType);
}
@@ -327,28 +385,25 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
}
Type RewriteContext::getTypeForTerm(Term term,
TypeArrayView<GenericTypeParamType> genericParams,
const ProtocolGraph &protos) const {
TypeArrayView<GenericTypeParamType> genericParams) const {
return getTypeForSymbolRange(term.begin(), term.end(), Type(),
genericParams, protos, *this);
genericParams, *this);
}
Type RewriteContext::getTypeForTerm(const MutableTerm &term,
TypeArrayView<GenericTypeParamType> genericParams,
const ProtocolGraph &protos) const {
TypeArrayView<GenericTypeParamType> genericParams) const {
return getTypeForSymbolRange(term.begin(), term.end(), Type(),
genericParams, protos, *this);
genericParams, *this);
}
Type RewriteContext::getRelativeTypeForTerm(
const MutableTerm &term, const MutableTerm &prefix,
const ProtocolGraph &protos) const {
const MutableTerm &term, const MutableTerm &prefix) const {
assert(std::equal(prefix.begin(), prefix.end(), term.begin()));
auto genericParam = CanGenericTypeParamType::get(0, 0, Context);
return getTypeForSymbolRange(
term.begin() + prefix.size(), term.end(), genericParam,
{ }, protos, *this);
{ }, *this);
}
/// Concrete type terms are written in terms of generic parameter types that
@@ -426,8 +481,7 @@ RewriteContext::getRelativeTermForType(CanType typeWitness,
Type RewriteContext::getTypeFromSubstitutionSchema(
Type schema, ArrayRef<Term> substitutions,
TypeArrayView<GenericTypeParamType> genericParams,
const MutableTerm &prefix,
const ProtocolGraph &protos) const {
const MutableTerm &prefix) const {
assert(!schema->isTypeParameter() && "Must have a concrete type here");
if (!schema->hasTypeParameter())
@@ -442,13 +496,13 @@ Type RewriteContext::getTypeFromSubstitutionSchema(
if (prefix.empty()) {
// Skip creation of a new MutableTerm in the case where the
// prefix is empty.
return getTypeForTerm(substitution, genericParams, protos);
return getTypeForTerm(substitution, genericParams);
} else {
// Otherwise build a new term by appending the substitution
// to the prefix.
MutableTerm result(prefix);
result.append(substitution);
return getTypeForTerm(result, genericParams, protos);
return getTypeForTerm(result, genericParams);
}
}