mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
RequirementMachine: Move protocol linear order from ProtocolGraph to RewriteContext
This commit is contained in:
@@ -12,7 +12,6 @@
|
||||
|
||||
#include "swift/AST/Decl.h"
|
||||
#include "swift/AST/Types.h"
|
||||
#include "ProtocolGraph.h"
|
||||
#include "RequirementMachine.h"
|
||||
#include "RewriteSystem.h"
|
||||
#include "RewriteContext.h"
|
||||
@@ -66,6 +65,77 @@ RewriteContext::RewriteContext(ASTContext &ctx)
|
||||
Debug = parseDebugFlags(debugFlags);
|
||||
}
|
||||
|
||||
const llvm::TinyPtrVector<const ProtocolDecl *> &
|
||||
RewriteContext::getInheritedProtocols(const ProtocolDecl *proto) {
|
||||
auto found = AllInherited.find(proto);
|
||||
if (found != AllInherited.end())
|
||||
return found->second;
|
||||
|
||||
AllInherited.insert(std::make_pair(proto, TinyPtrVector<const ProtocolDecl *>()));
|
||||
|
||||
llvm::SmallDenseSet<const ProtocolDecl *, 4> visited;
|
||||
llvm::TinyPtrVector<const ProtocolDecl *> protos;
|
||||
|
||||
for (auto *inheritedProto : proto->getInheritedProtocols()) {
|
||||
if (!visited.insert(inheritedProto).second)
|
||||
continue;
|
||||
|
||||
protos.push_back(inheritedProto);
|
||||
const auto &allInherited = getInheritedProtocols(inheritedProto);
|
||||
|
||||
for (auto *otherProto : allInherited) {
|
||||
if (!visited.insert(otherProto).second)
|
||||
continue;
|
||||
|
||||
protos.push_back(otherProto);
|
||||
}
|
||||
}
|
||||
|
||||
auto &result = AllInherited[proto];
|
||||
std::swap(protos, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
unsigned RewriteContext::getProtocolSupport(
|
||||
const ProtocolDecl *proto) {
|
||||
return getInheritedProtocols(proto).size() + 1;
|
||||
}
|
||||
|
||||
unsigned RewriteContext::getProtocolSupport(
|
||||
ArrayRef<const ProtocolDecl *> protos) {
|
||||
auto found = Support.find(protos);
|
||||
if (found != Support.end())
|
||||
return found->second;
|
||||
|
||||
unsigned result;
|
||||
if (protos.size() == 1) {
|
||||
result = getProtocolSupport(protos[0]);
|
||||
} else {
|
||||
llvm::DenseSet<const ProtocolDecl *> visited;
|
||||
for (const auto *proto : protos) {
|
||||
visited.insert(proto);
|
||||
for (const auto *inheritedProto : getInheritedProtocols(proto))
|
||||
visited.insert(inheritedProto);
|
||||
}
|
||||
|
||||
result = visited.size();
|
||||
}
|
||||
|
||||
Support[protos] = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
int RewriteContext::compareProtocols(const ProtocolDecl *lhs,
|
||||
const ProtocolDecl *rhs) {
|
||||
unsigned lhsSupport = getProtocolSupport(lhs);
|
||||
unsigned rhsSupport = getProtocolSupport(rhs);
|
||||
|
||||
if (lhsSupport != rhsSupport)
|
||||
return rhsSupport - lhsSupport;
|
||||
|
||||
return TypeDecl::compare(lhs, rhs);
|
||||
}
|
||||
|
||||
Term RewriteContext::getTermForType(CanType paramType,
|
||||
const ProtocolDecl *proto) {
|
||||
return Term::get(getMutableTermForType(paramType, proto), *this);
|
||||
@@ -167,64 +237,53 @@ MutableTerm RewriteContext::getMutableTermForType(CanType paramType,
|
||||
/// Note that the protocol graph is not part of the caching key; each
|
||||
/// protocol graph is a subgraph of the global inheritance graph, so
|
||||
/// the specific choice of subgraph does not change the result.
|
||||
AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(
|
||||
Symbol symbol, const ProtocolGraph &protos) {
|
||||
AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(Symbol symbol) {
|
||||
auto found = AssocTypes.find(symbol);
|
||||
if (found != AssocTypes.end())
|
||||
return found->second;
|
||||
|
||||
assert(symbol.getKind() == Symbol::Kind::AssociatedType);
|
||||
auto *proto = symbol.getProtocols()[0];
|
||||
auto name = symbol.getName();
|
||||
|
||||
AssociatedTypeDecl *assocType = nullptr;
|
||||
|
||||
// Special case: handle unknown protocols, since they can appear in the
|
||||
// invalid types that getCanonicalTypeInContext() must handle via
|
||||
// concrete substitution; see the definition of getCanonicalTypeInContext()
|
||||
// below for details.
|
||||
if (!protos.isKnownProtocol(proto)) {
|
||||
assert(symbol.getProtocols().size() == 1 &&
|
||||
"Unknown associated type symbol must have a single protocol");
|
||||
assocType = proto->getAssociatedType(name)->getAssociatedTypeAnchor();
|
||||
} else {
|
||||
// An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
|
||||
// P0...Pn and an identifier 'A'.
|
||||
//
|
||||
// We map it back to a AssociatedTypeDecl as follows:
|
||||
//
|
||||
// - For each protocol Pn, look for associated types A in Pn itself,
|
||||
// and all protocols that Pn refines.
|
||||
//
|
||||
// - For each candidate associated type An in protocol Qn where
|
||||
// Pn refines Qn, get the associated type anchor An' defined in
|
||||
// protocol Qn', where Qn refines Qn'.
|
||||
//
|
||||
// - Out of all the candidiate pairs (Qn', An'), pick the one where
|
||||
// the protocol Qn' is the lowest element according to the linear
|
||||
// order defined by TypeDecl::compare().
|
||||
//
|
||||
// The associated type An' is then the canonical associated type
|
||||
// representative of the associated type symbol [P0&...&Pn:A].
|
||||
//
|
||||
for (auto *proto : symbol.getProtocols()) {
|
||||
const auto &info = protos.getProtocolInfo(proto);
|
||||
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
|
||||
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
|
||||
// An associated type symbol [P1&P1&...&Pn:A] has one or more protocols
|
||||
// P0...Pn and an identifier 'A'.
|
||||
//
|
||||
// We map it back to a AssociatedTypeDecl as follows:
|
||||
//
|
||||
// - For each protocol Pn, look for associated types A in Pn itself,
|
||||
// and all protocols that Pn refines.
|
||||
//
|
||||
// - For each candidate associated type An in protocol Qn where
|
||||
// Pn refines Qn, get the associated type anchor An' defined in
|
||||
// protocol Qn', where Qn refines Qn'.
|
||||
//
|
||||
// - Out of all the candidiate pairs (Qn', An'), pick the one where
|
||||
// the protocol Qn' is the lowest element according to the linear
|
||||
// order defined by TypeDecl::compare().
|
||||
//
|
||||
// The associated type An' is then the canonical associated type
|
||||
// representative of the associated type symbol [P0&...&Pn:A].
|
||||
//
|
||||
for (auto *proto : symbol.getProtocols()) {
|
||||
auto checkOtherAssocType = [&](AssociatedTypeDecl *otherAssocType) {
|
||||
otherAssocType = otherAssocType->getAssociatedTypeAnchor();
|
||||
|
||||
if (otherAssocType->getName() == name &&
|
||||
(assocType == nullptr ||
|
||||
TypeDecl::compare(otherAssocType->getProtocol(),
|
||||
assocType->getProtocol()) < 0)) {
|
||||
assocType = otherAssocType;
|
||||
}
|
||||
};
|
||||
|
||||
for (auto *otherAssocType : info.AssociatedTypes) {
|
||||
checkOtherAssocType(otherAssocType);
|
||||
if (otherAssocType->getName() == name &&
|
||||
(assocType == nullptr ||
|
||||
TypeDecl::compare(otherAssocType->getProtocol(),
|
||||
assocType->getProtocol()) < 0)) {
|
||||
assocType = otherAssocType;
|
||||
}
|
||||
};
|
||||
|
||||
for (auto *otherAssocType : info.InheritedAssociatedTypes) {
|
||||
for (auto *otherAssocType : proto->getAssociatedTypeMembers()) {
|
||||
checkOtherAssocType(otherAssocType);
|
||||
}
|
||||
|
||||
for (auto *inheritedProto : getInheritedProtocols(proto)) {
|
||||
for (auto *otherAssocType : inheritedProto->getAssociatedTypeMembers()) {
|
||||
checkOtherAssocType(otherAssocType);
|
||||
}
|
||||
}
|
||||
@@ -244,7 +303,6 @@ AssociatedTypeDecl *RewriteContext::getAssociatedTypeForSymbol(
|
||||
template<typename Iter>
|
||||
Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
|
||||
TypeArrayView<GenericTypeParamType> genericParams,
|
||||
const ProtocolGraph &protos,
|
||||
const RewriteContext &ctx) {
|
||||
Type result = root;
|
||||
|
||||
@@ -319,7 +377,7 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
|
||||
// We should have a resolved type at this point.
|
||||
auto *assocType =
|
||||
const_cast<RewriteContext &>(ctx)
|
||||
.getAssociatedTypeForSymbol(symbol, protos);
|
||||
.getAssociatedTypeForSymbol(symbol);
|
||||
result = DependentMemberType::get(result, assocType);
|
||||
}
|
||||
|
||||
@@ -327,28 +385,25 @@ Type getTypeForSymbolRange(Iter begin, Iter end, Type root,
|
||||
}
|
||||
|
||||
Type RewriteContext::getTypeForTerm(Term term,
|
||||
TypeArrayView<GenericTypeParamType> genericParams,
|
||||
const ProtocolGraph &protos) const {
|
||||
TypeArrayView<GenericTypeParamType> genericParams) const {
|
||||
return getTypeForSymbolRange(term.begin(), term.end(), Type(),
|
||||
genericParams, protos, *this);
|
||||
genericParams, *this);
|
||||
}
|
||||
|
||||
Type RewriteContext::getTypeForTerm(const MutableTerm &term,
|
||||
TypeArrayView<GenericTypeParamType> genericParams,
|
||||
const ProtocolGraph &protos) const {
|
||||
TypeArrayView<GenericTypeParamType> genericParams) const {
|
||||
return getTypeForSymbolRange(term.begin(), term.end(), Type(),
|
||||
genericParams, protos, *this);
|
||||
genericParams, *this);
|
||||
}
|
||||
|
||||
Type RewriteContext::getRelativeTypeForTerm(
|
||||
const MutableTerm &term, const MutableTerm &prefix,
|
||||
const ProtocolGraph &protos) const {
|
||||
const MutableTerm &term, const MutableTerm &prefix) const {
|
||||
assert(std::equal(prefix.begin(), prefix.end(), term.begin()));
|
||||
|
||||
auto genericParam = CanGenericTypeParamType::get(0, 0, Context);
|
||||
return getTypeForSymbolRange(
|
||||
term.begin() + prefix.size(), term.end(), genericParam,
|
||||
{ }, protos, *this);
|
||||
{ }, *this);
|
||||
}
|
||||
|
||||
/// Concrete type terms are written in terms of generic parameter types that
|
||||
@@ -426,8 +481,7 @@ RewriteContext::getRelativeTermForType(CanType typeWitness,
|
||||
Type RewriteContext::getTypeFromSubstitutionSchema(
|
||||
Type schema, ArrayRef<Term> substitutions,
|
||||
TypeArrayView<GenericTypeParamType> genericParams,
|
||||
const MutableTerm &prefix,
|
||||
const ProtocolGraph &protos) const {
|
||||
const MutableTerm &prefix) const {
|
||||
assert(!schema->isTypeParameter() && "Must have a concrete type here");
|
||||
|
||||
if (!schema->hasTypeParameter())
|
||||
@@ -442,13 +496,13 @@ Type RewriteContext::getTypeFromSubstitutionSchema(
|
||||
if (prefix.empty()) {
|
||||
// Skip creation of a new MutableTerm in the case where the
|
||||
// prefix is empty.
|
||||
return getTypeForTerm(substitution, genericParams, protos);
|
||||
return getTypeForTerm(substitution, genericParams);
|
||||
} else {
|
||||
// Otherwise build a new term by appending the substitution
|
||||
// to the prefix.
|
||||
MutableTerm result(prefix);
|
||||
result.append(substitution);
|
||||
return getTypeForTerm(result, genericParams, protos);
|
||||
return getTypeForTerm(result, genericParams);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user