Files
swift-mirror/lib/AST/RequirementMachine/ProtocolGraph.cpp
Slava Pestov e36b95d620 RequirementMachine: Split up list of associated types in ProtocolGraph
Store the protocol's direct associated types separately from the inherited
associated types, since in a couple of places we only need the direct
associated types.

Also, factor out a new ProtocolGraph::compute() method that does all the
steps in the right order.
2021-07-14 00:11:37 -04:00

213 lines
6.9 KiB
C++

//===--- ProtocolGraph.cpp - Collect information about protocols ----------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2021 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
#include "ProtocolGraph.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Requirement.h"
using namespace swift;
using namespace rewriting;
/// Adds information about all protocols transitvely referenced from
/// \p reqs.
void ProtocolGraph::visitRequirements(ArrayRef<Requirement> reqs) {
for (auto req : reqs) {
if (req.getKind() == RequirementKind::Conformance) {
addProtocol(req.getProtocolDecl());
}
}
}
/// Return true if we know about this protocol.
bool ProtocolGraph::isKnownProtocol(const ProtocolDecl *proto) const {
return Info.count(proto) > 0;
}
/// Look up information about a known protocol.
const ProtocolInfo &ProtocolGraph::getProtocolInfo(
const ProtocolDecl *proto) const {
auto found = Info.find(proto);
assert(found != Info.end());
return found->second;
}
/// Record information about a protocol if we have no seen it yet.
void ProtocolGraph::addProtocol(const ProtocolDecl *proto) {
if (Info.count(proto) > 0)
return;
Info[proto] = {proto->getInheritedProtocols(),
proto->getAssociatedTypeMembers(),
proto->getRequirementSignature()};
Protocols.push_back(proto);
}
/// Record information about all protocols transtively referenced
/// from protocol requirement signatures.
void ProtocolGraph::computeTransitiveClosure() {
unsigned i = 0;
while (i < Protocols.size()) {
auto *proto = Protocols[i++];
visitRequirements(getProtocolInfo(proto).Requirements);
}
}
/// See ProtocolGraph::compareProtocols() for the definition of this linear
/// order.
void ProtocolGraph::computeLinearOrder() {
for (const auto *proto : Protocols) {
(void) computeProtocolDepth(proto);
}
std::sort(
Protocols.begin(), Protocols.end(),
[&](const ProtocolDecl *lhs,
const ProtocolDecl *rhs) -> bool {
const auto &lhsInfo = getProtocolInfo(lhs);
const auto &rhsInfo = getProtocolInfo(rhs);
// protocol Base {} // depth 1
// protocol Derived : Base {} // depth 2
//
// Derived < Base in the linear order.
if (lhsInfo.Depth != rhsInfo.Depth)
return lhsInfo.Depth > rhsInfo.Depth;
return TypeDecl::compare(lhs, rhs) < 0;
});
for (unsigned i : indices(Protocols)) {
Info[Protocols[i]].Index = i;
}
if (Debug) {
for (const auto *proto : Protocols) {
const auto &info = getProtocolInfo(proto);
llvm::dbgs() << "@ Protocol " << proto->getName()
<< " Depth=" << info.Depth
<< " Index=" << info.Index << "\n";
}
}
}
/// Update each ProtocolInfo's AssociatedTypes vector to add all associated
/// types from all transitively inherited protocols.
void ProtocolGraph::computeInheritedAssociatedTypes() {
// Visit protocols in reverse order, so that if P inherits from Q and
// Q inherits from R, we first visit R, then Q, then P, ensuring that
// R's associated types are added to P's list, etc.
for (const auto *proto : llvm::reverse(Protocols)) {
auto &info = Info[proto];
for (const auto *inherited : info.AllInherited) {
for (auto *inheritedType : getProtocolInfo(inherited).AssociatedTypes) {
info.InheritedAssociatedTypes.push_back(inheritedType);
}
}
}
}
// Update each protocol's AllInherited vector to add all transitively
// inherited protocols.
void ProtocolGraph::computeInheritedProtocols() {
// Visit protocols in reverse order, so that if P inherits from Q and
// Q inherits from R, we first visit R, then Q, then P, ensuring that
// R's inherited protocols are added to P's list, etc.
for (const auto *proto : llvm::reverse(Protocols)) {
auto &info = Info[proto];
// We might inherit the same protocol multiple times due to diamond
// inheritance, so make sure we only add each protocol once.
llvm::SmallDenseSet<const ProtocolDecl *, 4> visited;
visited.insert(proto);
for (const auto *inherited : info.Inherited) {
// Add directly-inherited protocols.
if (!visited.insert(inherited).second)
continue;
info.AllInherited.push_back(inherited);
// Add indirectly-inherited protocols.
for (auto *inheritedType : getProtocolInfo(inherited).AllInherited) {
if (!visited.insert(inheritedType).second)
continue;
info.AllInherited.push_back(inheritedType);
}
}
}
}
/// Recursively compute the 'depth' of a protocol, which is inductively defined
/// as one greater than the depth of all inherited protocols, with a protocol
/// that does not inherit any other protocol having a depth of one.
unsigned ProtocolGraph::computeProtocolDepth(const ProtocolDecl *proto) {
auto &info = Info[proto];
if (info.Mark) {
// Already computed, or we have a cycle. Cycles are diagnosed
// elsewhere in the type checker, so we don't have to do
// anything here.
return info.Depth;
}
info.Mark = true;
unsigned depth = 0;
for (auto *inherited : info.Inherited) {
unsigned inheritedDepth = computeProtocolDepth(inherited);
depth = std::max(inheritedDepth, depth);
}
depth++;
info.Depth = depth;
return depth;
}
/// Compute everything in the right order.
void ProtocolGraph::compute() {
computeTransitiveClosure();
computeLinearOrder();
computeInheritedProtocols();
computeInheritedAssociatedTypes();
}
/// Defines a linear order with the property that if a protocol P inherits
/// from another protocol Q, then P < Q. (The converse cannot be true, since
/// this is a linear order.)
///
/// We first compare the 'depth' of a protocol, which is defined in
/// ProtocolGraph::computeProtocolDepth() above.
///
/// If two protocols have the same depth, the tie is broken by the standard
/// TypeDecl::compare().
int ProtocolGraph::compareProtocols(const ProtocolDecl *lhs,
const ProtocolDecl *rhs) const {
const auto &infoLHS = getProtocolInfo(lhs);
const auto &infoRHS = getProtocolInfo(rhs);
return infoLHS.Index - infoRHS.Index;
}
/// Returns if \p thisProto transitively inherits from \p otherProto.
///
/// The result is false if the two protocols are equal.
bool ProtocolGraph::inheritsFrom(const ProtocolDecl *thisProto,
const ProtocolDecl *otherProto) const {
const auto &info = getProtocolInfo(thisProto);
return std::find(info.AllInherited.begin(),
info.AllInherited.end(),
otherProto) != info.AllInherited.end();
}