//===--- RewriteSystem.cpp - Generics with term rewriting -----------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2021 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// #include "swift/AST/Decl.h" #include "swift/AST/Types.h" #include "swift/AST/TypeWalker.h" #include "llvm/ADT/FoldingSet.h" #include "llvm/Support/raw_ostream.h" #include #include #include "RewriteContext.h" #include "RewriteSystem.h" using namespace swift; using namespace rewriting; /// If this is a rule of the form T.[p] => T where [p] is a property symbol, /// returns the symbol. Otherwise, returns None. /// /// Note that this is meant to be used with a simplified rewrite system, /// where the right hand sides of rules are canonical, since this also means /// that T is canonical. Optional Rule::isPropertyRule() const { auto property = LHS.back(); if (!property.isProperty()) return None; if (LHS.size() - 1 != RHS.size()) return None; if (!std::equal(RHS.begin(), RHS.end(), LHS.begin())) return None; return property; } /// If this is a rule of the form T.[P] => T where [P] is a protocol symbol, /// return the protocol P, otherwise return nullptr. const ProtocolDecl *Rule::isProtocolConformanceRule() const { if (auto property = isPropertyRule()) { if (property->getKind() == Symbol::Kind::Protocol) return property->getProtocol(); } return nullptr; } /// If this is a rule of the form T.[concrete: C : P] => T where /// [concrete: C : P] is a concrete conformance symbol, return the protocol P, /// otherwise return nullptr. 
const ProtocolDecl *Rule::isAnyConformanceRule() const { if (auto property = isPropertyRule()) { switch (property->getKind()) { case Symbol::Kind::ConcreteConformance: case Symbol::Kind::Protocol: return property->getProtocol(); case Symbol::Kind::Layout: case Symbol::Kind::Superclass: case Symbol::Kind::ConcreteType: return nullptr; case Symbol::Kind::Name: case Symbol::Kind::AssociatedType: case Symbol::Kind::GenericParam: break; } llvm_unreachable("Bad symbol kind"); } return nullptr; } /// If this is a rule of the form [P].[P] => [P] where [P] is a protocol /// symbol, return true, otherwise return false. bool Rule::isIdentityConformanceRule() const { return (LHS.size() == 2 && RHS.size() == 1 && LHS[0] == RHS[0] && LHS[0] == LHS[1] && LHS[0].getKind() == Symbol::Kind::Protocol); } /// If this is a rule of the form [P].[Q] => [P] where [P] and [Q] are /// protocol symbols, return true, otherwise return false. bool Rule::isProtocolRefinementRule() const { if (LHS.size() == 2 && RHS.size() == 1 && LHS[0] == RHS[0] && LHS[0].getKind() == Symbol::Kind::Protocol && (LHS[1].getKind() == Symbol::Kind::Protocol || LHS[1].getKind() == Symbol::Kind::ConcreteConformance) && LHS[0] != LHS[1]) { // A protocol refinement rule must be from a directly-stated // inheritance clause entry. It can only become redundant if it is // written in terms of other protocol refinement rules; otherwise, it // must appear in the protocol's requirement signature. // // See RewriteSystem::isValidRefinementPath() for an explanation. 
auto *proto = LHS[0].getProtocol(); auto *otherProto = LHS[1].getProtocol(); auto inherited = proto->getInheritedProtocols(); return (std::find(inherited.begin(), inherited.end(), otherProto) != inherited.end()); } return false; } /// A protocol typealias rule takes one of the following two forms, /// where T is a name symbol: /// /// 1) [P].T => X /// 2) [P].T.[concrete: C] => [P].T /// /// The first case is where the protocol's underlying type is another /// type parameter. The second case is where the protocol's underlying /// type is a concrete type. /// /// In the first case, X must be fully resolved, that is, it must not /// contain any name symbols. /// /// If this rule is a protocol typealias rule, returns its name. Otherwise /// returns None. Optional Rule::isProtocolTypeAliasRule() const { if (LHS.size() != 2 && LHS.size() != 3) return None; if (LHS[0].getKind() != Symbol::Kind::Protocol || LHS[1].getKind() != Symbol::Kind::Name) return None; if (LHS.size() == 2) { // This is the case where the underlying type is a type parameter. // // We shouldn't have unresolved symbols on the right hand side; // they should have been simplified away. if (RHS.containsUnresolvedSymbols()) return None; } else { // This is the case where the underlying type is concrete. assert(LHS.size() == 3); auto prop = isPropertyRule(); if (!prop || prop->getKind() != Symbol::Kind::ConcreteType) return None; } return LHS[1].getName(); } /// Returns the length of the left hand side. unsigned Rule::getDepth() const { auto result = LHS.size(); if (LHS.back().hasSubstitutions()) { for (auto substitution : LHS.back().getSubstitutions()) { result = std::max(result, substitution.size()); } } return result; } /// Returns the nesting depth of the concrete symbol at the end of the /// left hand side, or 0 if there isn't one. 
unsigned Rule::getNesting() const { if (LHS.back().hasSubstitutions()) { auto type = LHS.back().getConcreteType(); struct Walker : TypeWalker { unsigned Nesting = 0; unsigned MaxNesting = 0; Action walkToTypePre(Type ty) override { ++Nesting; MaxNesting = std::max(Nesting, MaxNesting); return Action::Continue; } Action walkToTypePost(Type ty) override { --Nesting; return Action::Continue; } }; Walker walker; type.walk(walker); return walker.MaxNesting; } return 0; } /// Linear order on rules; compares LHS followed by RHS. Optional Rule::compare(const Rule &other, RewriteContext &ctx) const { Optional compare = LHS.compare(other.LHS, ctx); if (!compare.hasValue() || *compare != 0) return compare; return RHS.compare(other.RHS, ctx); } void Rule::dump(llvm::raw_ostream &out) const { out << LHS << " => " << RHS; if (Permanent) out << " [permanent]"; if (Explicit) out << " [explicit]"; if (LHSSimplified) out << " [lhs↓]"; if (RHSSimplified) out << " [rhs↓]"; if (SubstitutionSimplified) out << " [subst↓]"; if (Redundant) out << " [redundant]"; if (Conflicting) out << " [conflicting]"; } RewriteSystem::RewriteSystem(RewriteContext &ctx) : Context(ctx), Debug(ctx.getDebugOptions()) { Initialized = 0; Complete = 0; Minimized = 0; RecordLoops = 0; } RewriteSystem::~RewriteSystem() { Trie.updateHistograms(Context.RuleTrieHistogram, Context.RuleTrieRootHistogram); } void RewriteSystem::initialize( bool recordLoops, ArrayRef protos, std::vector> &&permanentRules, std::vector> &&requirementRules) { assert(!Initialized); Initialized = 1; RecordLoops = recordLoops; Protos = protos; for (const auto &rule : permanentRules) addPermanentRule(rule.first, rule.second); for (const auto &rule : requirementRules) addExplicitRule(rule.first, rule.second); } /// Reduce a term by applying all rewrite rules until fixed point. /// /// If \p path is non-null, records the series of rewrite steps taken. 
bool RewriteSystem::simplify(MutableTerm &term, RewritePath *path) const { bool changed = false; MutableTerm original; RewritePath subpath; bool debug = false; if (Debug.contains(DebugFlags::Simplify)) { original = term; debug = true; } while (true) { bool tryAgain = false; auto from = term.begin(); auto end = term.end(); while (from < end) { auto ruleID = Trie.find(from, end); if (ruleID) { const auto &rule = getRule(*ruleID); auto to = from + rule.getLHS().size(); assert(std::equal(from, to, rule.getLHS().begin())); unsigned startOffset = (unsigned)(from - term.begin()); unsigned endOffset = term.size() - rule.getLHS().size() - startOffset; term.rewriteSubTerm(from, to, rule.getRHS()); if (path || debug) { subpath.add(RewriteStep::forRewriteRule(startOffset, endOffset, *ruleID, /*inverse=*/false)); } changed = true; tryAgain = true; break; } ++from; } if (!tryAgain) break; } if (debug) { if (changed) { llvm::dbgs() << "= Simplified " << original << " to " << term << " via "; subpath.dump(llvm::dbgs(), original, *this); llvm::dbgs() << "\n"; } else { llvm::dbgs() << "= Irreducible term: " << term << "\n"; } } if (path != nullptr) { assert(changed != subpath.empty()); path->append(subpath); } return changed; } /// Adds a rewrite rule, returning true if the new rule was non-trivial. /// /// If both sides simplify to the same term, the rule is trivial and discarded, /// and this method returns false. /// /// If \p path is non-null, the new rule is derived from existing rules in the /// rewrite system; the path records a series of rewrite steps which transform /// \p lhs to \p rhs. 
bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs,
                            const RewritePath *path) {
  // FIXME:
  // assert(!Complete || path != nullptr &&
  //        "Rules added by completion must have a path");

  assert(!lhs.empty());
  assert(!rhs.empty());

  if (Debug.contains(DebugFlags::Add)) {
    llvm::dbgs() << "# Adding rule " << lhs << " == " << rhs << "\n\n";
  }

  // Now simplify both sides as much as possible with the rules we have so far.
  //
  // This avoids unnecessary work in the completion algorithm.
  RewritePath lhsPath;
  RewritePath rhsPath;

  simplify(lhs, &lhsPath);
  simplify(rhs, &rhsPath);

  RewritePath loop;
  if (path) {
    // Produce a path from the simplified lhs to the simplified rhs.

    // (1) First, apply lhsPath in reverse to produce the original lhs.
    lhsPath.invert();
    loop.append(lhsPath);

    // (2) Now, apply the path from the original lhs to the original rhs
    // given to us by the completion procedure.
    loop.append(*path);

    // (3) Finally, apply rhsPath to produce the simplified rhs, which
    // is the same as the simplified lhs.
    loop.append(rhsPath);
  }

  // If the left hand side and right hand side are already equivalent, we're
  // done.
  //
  // NOTE(review): the comparison result is dereferenced unconditionally;
  // presumably the two simplified terms are always comparable here — confirm.
  Optional<int> result = lhs.compare(rhs, Context);
  if (*result == 0) {
    // If this rule is a consequence of existing rules, add a homotopy
    // generator.
    if (path) {
      // We already have a loop, since the simplified lhs is identical to the
      // simplified rhs.
      recordRewriteLoop(lhs, loop);

      if (Debug.contains(DebugFlags::Add)) {
        llvm::dbgs() << "## Recorded trivial loop at " << lhs << ": ";
        loop.dump(llvm::dbgs(), lhs, *this);
        llvm::dbgs() << "\n\n";
      }
    }

    return false;
  }

  // Orient the two terms so that the left hand side is greater than the
  // right hand side.
  if (*result < 0) {
    std::swap(lhs, rhs);
    loop.invert();
  }

  assert(*lhs.compare(rhs, Context) > 0);

  if (Debug.contains(DebugFlags::Add)) {
    llvm::dbgs() << "## Simplified and oriented rule " << lhs << " => " << rhs
                 << "\n\n";
  }

  unsigned newRuleID = Rules.size();

  Rules.emplace_back(Term::get(lhs, Context), Term::get(rhs, Context));

  if (path) {
    // We have a rewrite path from the simplified lhs to the simplified rhs;
    // add a rewrite step applying the new rule in reverse to close the loop.
    loop.add(RewriteStep::forRewriteRule(/*startOffset=*/0, /*endOffset=*/0,
                                         newRuleID, /*inverse=*/true));
    recordRewriteLoop(lhs, loop);

    if (Debug.contains(DebugFlags::Add)) {
      llvm::dbgs() << "## Recorded non-trivial loop at " << lhs << ": ";
      loop.dump(llvm::dbgs(), lhs, *this);
      llvm::dbgs() << "\n\n";
    }
  }

  // The left hand side was just simplified, so no existing rule should have
  // the same left hand side. If the trie says otherwise, print diagnostics
  // and abort.
  auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), newRuleID);
  if (oldRuleID) {
    llvm::errs() << "Duplicate rewrite rule!\n";
    const auto &oldRule = getRule(*oldRuleID);
    llvm::errs() << "Old rule #" << *oldRuleID << ": ";
    oldRule.dump(llvm::errs());
    llvm::errs() << "\nTrying to replay what happened when I simplified this term:\n";

    // Replay the simplification on a scratch copy with tracing turned on,
    // leaving lhs itself untouched.
    Debug |= DebugFlags::Simplify;
    MutableTerm term = lhs;
    simplify(term);

    dump(llvm::errs());
    abort();
  }

  // Tell the caller that we added a new rule.
  return true;
}

/// Add a new rule, marking it permanent.
///
/// Returns true if a new rule was added; the rule is only marked in that
/// case.
bool RewriteSystem::addPermanentRule(MutableTerm lhs, MutableTerm rhs) {
  bool added = addRule(std::move(lhs), std::move(rhs));
  if (added)
    Rules.back().markPermanent();

  return added;
}

/// Add a new rule, marking it explicit.
///
/// Returns true if a new rule was added; the rule is only marked in that
/// case.
bool RewriteSystem::addExplicitRule(MutableTerm lhs, MutableTerm rhs) {
  bool added = addRule(std::move(lhs), std::move(rhs));
  if (added)
    Rules.back().markExplicit();

  return added;
}

/// Delete any rules whose left hand sides can be reduced by other rules.
///
/// Must be run after the completion procedure, since the deletion of
/// rules is only valid to perform if the rewrite system is confluent.
void RewriteSystem::simplifyLeftHandSides() { assert(Complete); for (unsigned ruleID = 0, e = Rules.size(); ruleID < e; ++ruleID) { auto &rule = getRule(ruleID); if (rule.isLHSSimplified()) continue; // First, see if the left hand side of this rule can be reduced using // some other rule. auto lhs = rule.getLHS(); auto begin = lhs.begin(); auto end = lhs.end(); while (begin < end) { if (auto otherRuleID = Trie.find(begin++, end)) { // A rule does not obsolete itself. if (*otherRuleID == ruleID) continue; // Ignore other deleted rules. const auto &otherRule = getRule(*otherRuleID); if (otherRule.isLHSSimplified()) continue; if (Debug.contains(DebugFlags::Completion)) { const auto &otherRule = getRule(*otherRuleID); llvm::dbgs() << "$ Deleting rule " << rule << " because " << "its left hand side contains " << otherRule << "\n"; } rule.markLHSSimplified(); break; } } } } /// Reduce the right hand sides of all remaining rules as much as /// possible. /// /// Must be run after the completion procedure, since the deletion of /// rules is only valid to perform if the rewrite system is confluent. void RewriteSystem::simplifyRightHandSides() { assert(Complete); for (unsigned ruleID = 0, e = Rules.size(); ruleID < e; ++ruleID) { auto &rule = getRule(ruleID); if (rule.isRHSSimplified()) continue; // Now, try to reduce the right hand side. RewritePath rhsPath; MutableTerm rhs(rule.getRHS()); if (!simplify(rhs, &rhsPath)) continue; auto lhs = rule.getLHS(); // We're adding a new rule, so the old rule won't apply anymore. rule.markRHSSimplified(); unsigned newRuleID = Rules.size(); // Add a new rule with the simplified right hand side. Rules.emplace_back(lhs, Term::get(rhs, Context)); auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), newRuleID); assert(oldRuleID == ruleID); (void) oldRuleID; // Produce a loop at the original lhs. RewritePath loop; // (1) First, apply the original rule to produce the original rhs. 
loop.add(RewriteStep::forRewriteRule(/*startOffset=*/0, /*endOffset=*/0, ruleID, /*inverse=*/false)); // (2) Next, apply rhsPath to produce the simplified rhs. loop.append(rhsPath); // (3) Finally, apply the new rule in reverse to produce the original lhs. loop.add(RewriteStep::forRewriteRule(/*startOffset=*/0, /*endOffset=*/0, newRuleID, /*inverse=*/true)); if (Debug.contains(DebugFlags::Completion)) { llvm::dbgs() << "$ Right hand side simplification recorded a loop at "; llvm::dbgs() << lhs << ": "; loop.dump(llvm::dbgs(), MutableTerm(lhs), *this); llvm::dbgs() << "\n"; } recordRewriteLoop(MutableTerm(lhs), loop); } } /// When minimizing a generic signature, we only care about loops where the /// basepoint is a generic parameter symbol. /// /// When minimizing protocol requirement signatures, we only care about loops /// where the basepoint is a protocol symbol or associated type symbol whose /// protocol is part of the connected component. /// /// All other loops can be discarded since they do not encode redundancies /// that are relevant to us. bool RewriteSystem::isInMinimizationDomain(const ProtocolDecl *proto) const { assert(Protos.empty() || proto != nullptr); if (proto == nullptr && Protos.empty()) return true; if (std::find(Protos.begin(), Protos.end(), proto) != Protos.end()) return true; return false; } void RewriteSystem::recordRewriteLoop(MutableTerm basepoint, RewritePath path) { RewriteLoop loop(basepoint, path); loop.verify(*this); if (!RecordLoops) return; // Ignore the rewrite rule if it is not part of our minimization domain. 
if (!isInMinimizationDomain(basepoint.getRootProtocol())) return; Loops.push_back(loop); } void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const { #ifndef NDEBUG #define ASSERT_RULE(expr) \ if (!(expr)) { \ llvm::errs() << "&&& Malformed rewrite rule: " << rule << "\n"; \ llvm::errs() << "&&& " << #expr << "\n\n"; \ dump(llvm::errs()); \ assert(expr); \ } for (const auto &rule : Rules) { const auto &lhs = rule.getLHS(); const auto &rhs = rule.getRHS(); for (unsigned index : indices(lhs)) { auto symbol = lhs[index]; if (index != lhs.size() - 1) { ASSERT_RULE(symbol.getKind() != Symbol::Kind::Layout); ASSERT_RULE(!symbol.hasSubstitutions()); } if (index != 0) { ASSERT_RULE(symbol.getKind() != Symbol::Kind::GenericParam); } // Completion can produce rules like [P:T].[Q].[R] => [P:T].[Q] // which are immediately simplified away. if (!rule.isLHSSimplified() && index != 0 && index != lhs.size() - 1) { ASSERT_RULE(symbol.getKind() != Symbol::Kind::Protocol); } } for (unsigned index : indices(rhs)) { auto symbol = rhs[index]; // RHS-simplified rules might have unresolved name symbols on the // right hand side. Also, completion can introduce rules of the // form T.X.[concrete: C] => T.X, where T is some resolved term, // and X is a name symbol for a protocol typealias. if (!rule.isLHSSimplified() && !rule.isRHSSimplified() && !(rule.isPropertyRule() && index == rhs.size() - 1)) { // This is only true if the input requirements were valid. if (policy == DisallowInvalidRequirements) { ASSERT_RULE(symbol.getKind() != Symbol::Kind::Name); } else { // FIXME: Assert that we diagnosed an error } } ASSERT_RULE(symbol.getKind() != Symbol::Kind::Layout); ASSERT_RULE(!symbol.hasSubstitutions()); if (index != 0) { ASSERT_RULE(symbol.getKind() != Symbol::Kind::GenericParam); } // Completion can produce rules like [P:T].[Q].[R] => [P:T].[Q] // which are immediately simplified away. 
if (!rule.isRHSSimplified() && index != 0) { ASSERT_RULE(symbol.getKind() != Symbol::Kind::Protocol); } } auto lhsDomain = lhs.getRootProtocol(); auto rhsDomain = rhs.getRootProtocol(); ASSERT_RULE(lhsDomain == rhsDomain); } #undef ASSERT_RULE #endif } void RewriteSystem::dump(llvm::raw_ostream &out) const { out << "Rewrite system: {\n"; for (const auto &rule : Rules) { out << "- " << rule << "\n"; } out << "}\n"; if (!Relations.empty()) { out << "Relations: {\n"; for (const auto &relation : Relations) { out << "- " << relation.first << " =>> " << relation.second << "\n"; } out << "}\n"; } if (!Differences.empty()) { out << "Type differences: {\n"; for (const auto &difference : Differences) { difference.dump(out); out << "\n"; } out << "}\n"; } if (!Loops.empty()) { out << "Rewrite loops: {\n"; for (unsigned loopID : indices(Loops)) { const auto &loop = Loops[loopID]; if (loop.isDeleted()) continue; out << "- (#" << loopID << ") "; loop.dump(out, *this); out << "\n"; } } out << "}\n"; }