mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
The PropertyMap wants to use a trie to map to PropertyBags, and it needs longest-suffix rather than shortest-prefix matching. Implement both by making Trie into a template class with two parameters: the ValueType and the MatchKind. Note that while the MatchKind encodes the longest vs. shortest match part, matching on the prefix vs. suffix of a term is up to the caller: the find() and insert() methods of Trie take a pair of iterators, so simply passing in begin()/end() vs. rbegin()/rend() selects the direction.
283 lines
7.6 KiB
C++
283 lines
7.6 KiB
C++
//===--- RewriteSystem.cpp - Generics with term rewriting -----------------===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2021 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "swift/AST/Decl.h"
#include "swift/AST/Types.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "ProtocolGraph.h"
#include "RewriteContext.h"
#include "RewriteSystem.h"
|
|
|
|
using namespace swift;
|
|
using namespace rewriting;
|
|
|
|
/// Construct a rewrite system bound to the rewrite context \p ctx.
///
/// Every debugging flag starts out disabled.
RewriteSystem::RewriteSystem(RewriteContext &ctx)
    : Context(ctx) {
  // The flags are independent of one another; switch all of them off.
  DebugCompletion = false;
  DebugMerge = false;
  DebugAdd = false;
  DebugSimplify = false;
}
|
|
|
|
/// Before the rewrite system goes away, let the rule trie update the two
/// histograms stored in the rewrite context.
RewriteSystem::~RewriteSystem() {
  auto &trieHistogram = Context.RuleTrieHistogram;
  auto &trieRootHistogram = Context.RuleTrieRootHistogram;

  Trie.updateHistograms(trieHistogram, trieRootHistogram);
}
|
|
|
|
/// Print this rule to \p out in the form "LHS => RHS", followed by a
/// " [deleted]" marker when the rule has been deleted.
void Rule::dump(llvm::raw_ostream &out) const {
  out << LHS;
  out << " => ";
  out << RHS;

  // Live rules carry no suffix.
  if (!deleted)
    return;

  out << " [deleted]";
}
|
|
|
|
void RewriteSystem::initialize(
|
|
std::vector<std::pair<MutableTerm, MutableTerm>> &&rules,
|
|
ProtocolGraph &&graph) {
|
|
Protos = graph;
|
|
|
|
for (const auto &rule : rules)
|
|
addRule(rule.first, rule.second);
|
|
}
|
|
|
|
/// Return the result of transforming the concrete substitutions of
/// \p symbol, where each substitution term is replaced by its simplified form.
///
/// A substitution term that simplify() leaves unchanged is passed back
/// as-is.
Symbol RewriteSystem::simplifySubstitutionsInSuperclassOrConcreteSymbol(
    Symbol symbol) const {
  auto simplifySubstitution = [&](Term term) -> Term {
    MutableTerm scratch(term);

    // Only build a new Term when the rules actually rewrote something.
    if (simplify(scratch))
      return Term::get(scratch, Context);

    return term;
  };

  return symbol.transformConcreteSubstitutions(simplifySubstitution, Context);
}
|
|
|
|
/// Add a rewrite rule relating \p lhs and \p rhs, unless the two terms
/// already reduce to the same term under the existing rules.
///
/// Both terms are simplified and then oriented so that the greater term
/// (according to MutableTerm::compare() with Protos) becomes the left hand
/// side. The rule is appended to Rules and inserted into the trie under its
/// left hand side, and overlap checks against every other non-deleted rule
/// are queued on the worklist.
///
/// If the trie already has an entry for the new left hand side, diagnostics
/// are printed and the process aborts.
///
/// \param lhs One side of the rule; must not be empty.
/// \param rhs The other side of the rule; must not be empty.
/// \returns true if a new rule was added, false if both sides were already
///          equivalent.
bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
  assert(!lhs.empty());
  assert(!rhs.empty());

  // First, simplify terms appearing inside concrete substitutions before
  // doing anything else.
  //
  // NOTE(review): because of the 'else if', only one side is processed when
  // both sides end in a superclass or concrete type symbol -- confirm that
  // this is intended.
  if (lhs.back().isSuperclassOrConcreteType())
    lhs.back() = simplifySubstitutionsInSuperclassOrConcreteSymbol(lhs.back());
  else if (rhs.back().isSuperclassOrConcreteType())
    rhs.back() = simplifySubstitutionsInSuperclassOrConcreteSymbol(rhs.back());

  // Now simplify both sides as much as possible with the rules we have so far.
  //
  // This avoids unnecessary work in the completion algorithm.
  simplify(lhs);
  simplify(rhs);

  // If the left hand side and right hand side are already equivalent, we're
  // done.
  int result = lhs.compare(rhs, Protos);
  if (result == 0)
    return false;

  // Orient the two terms so that the left hand side is greater than the
  // right hand side.
  if (result < 0)
    std::swap(lhs, rhs);

  assert(lhs.compare(rhs, Protos) > 0);

  if (DebugAdd) {
    llvm::dbgs() << "# Adding rule " << lhs << " => " << rhs << "\n";
  }

  // The new rule's ID is its index in the Rules vector.
  unsigned i = Rules.size();
  Rules.emplace_back(Term::get(lhs, Context), Term::get(rhs, Context));

  // A returned ID means the trie already held an entry for this left hand
  // side. That is treated as fatal.
  auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), i);
  if (oldRuleID) {
    llvm::errs() << "Duplicate rewrite rule!\n";
    const auto &oldRule = Rules[*oldRuleID];
    llvm::errs() << "Old rule #" << *oldRuleID << ": ";
    oldRule.dump(llvm::errs());
    llvm::errs() << "\nTrying to replay what happened when I simplified this term:\n";
    DebugSimplify = true;

    // Replay the simplification on a scratch copy so that the rule's left
    // hand side itself is left untouched.
    MutableTerm term = lhs;
    simplify(term);

    abort();
  }

  // Check if we have a rule of the form
  //
  //   X.[P1:T] => X.[P2:T]
  //
  // If so, record this rule for later. We'll try to merge the associated
  // types in RewriteSystem::processMergedAssociatedTypes().
  if (lhs.size() == rhs.size() &&
      std::equal(lhs.begin(), lhs.end() - 1, rhs.begin()) &&
      lhs.back().getKind() == Symbol::Kind::AssociatedType &&
      rhs.back().getKind() == Symbol::Kind::AssociatedType &&
      lhs.back().getName() == rhs.back().getName()) {
    MergedAssociatedTypes.emplace_back(lhs, rhs);
  }

  // Since we added a new rule, we have to check for overlaps between the
  // new rule and all existing rules.
  for (unsigned j : indices(Rules)) {
    // A rule does not overlap with itself.
    if (i == j)
      continue;

    // We don't have to check for overlap with deleted rules.
    if (Rules[j].isDeleted())
      continue;

    // The overlap check is not commutative so we have to check both
    // directions.
    Worklist.emplace_back(i, j);
    Worklist.emplace_back(j, i);

    if (DebugCompletion) {
      llvm::dbgs() << "$ Queued up (" << i << ", " << j << ") and ";
      llvm::dbgs() << "(" << j << ", " << i << ")\n";
    }
  }

  // Tell the caller that we added a new rule.
  return true;
}
|
|
|
|
/// Reduce a term by applying all rewrite rules until fixed point.
|
|
bool RewriteSystem::simplify(MutableTerm &term) const {
|
|
bool changed = false;
|
|
|
|
if (DebugSimplify) {
|
|
llvm::dbgs() << "= Term " << term << "\n";
|
|
}
|
|
|
|
while (true) {
|
|
bool tryAgain = false;
|
|
|
|
auto from = term.begin();
|
|
auto end = term.end();
|
|
while (from < end) {
|
|
auto ruleID = Trie.find(from, end);
|
|
if (ruleID) {
|
|
const auto &rule = Rules[*ruleID];
|
|
if (!rule.isDeleted()) {
|
|
if (DebugSimplify) {
|
|
llvm::dbgs() << "== Rule #" << *ruleID << ": " << rule << "\n";
|
|
}
|
|
|
|
auto to = from + rule.getLHS().size();
|
|
assert(std::equal(from, to, rule.getLHS().begin()));
|
|
|
|
term.rewriteSubTerm(from, to, rule.getRHS());
|
|
|
|
if (DebugSimplify) {
|
|
llvm::dbgs() << "=== Result " << term << "\n";
|
|
}
|
|
|
|
changed = true;
|
|
tryAgain = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
++from;
|
|
}
|
|
|
|
if (!tryAgain)
|
|
break;
|
|
}
|
|
|
|
return changed;
|
|
}
|
|
|
|
void RewriteSystem::simplifyRightHandSides() {
|
|
for (auto &rule : Rules) {
|
|
if (rule.isDeleted())
|
|
continue;
|
|
|
|
MutableTerm rhs(rule.getRHS());
|
|
if (!simplify(rhs))
|
|
continue;
|
|
|
|
rule = Rule(rule.getLHS(), Term::get(rhs, Context));
|
|
}
|
|
|
|
#ifndef NDEBUG
|
|
|
|
#define ASSERT_RULE(expr) \
|
|
if (!(expr)) { \
|
|
llvm::errs() << "&&& Malformed rewrite rule: " << rule << "\n"; \
|
|
llvm::errs() << "&&& " << #expr << "\n\n"; \
|
|
dump(llvm::errs()); \
|
|
assert(expr); \
|
|
}
|
|
|
|
for (const auto &rule : Rules) {
|
|
if (rule.isDeleted())
|
|
continue;
|
|
|
|
const auto &lhs = rule.getLHS();
|
|
const auto &rhs = rule.getRHS();
|
|
|
|
for (unsigned index : indices(lhs)) {
|
|
auto symbol = lhs[index];
|
|
|
|
if (index != lhs.size() - 1) {
|
|
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Layout);
|
|
ASSERT_RULE(!symbol.isSuperclassOrConcreteType());
|
|
}
|
|
|
|
if (index != 0) {
|
|
ASSERT_RULE(symbol.getKind() != Symbol::Kind::GenericParam);
|
|
}
|
|
|
|
if (index != 0 && index != lhs.size() - 1) {
|
|
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Protocol);
|
|
}
|
|
}
|
|
|
|
for (unsigned index : indices(rhs)) {
|
|
auto symbol = rhs[index];
|
|
|
|
// FIXME: This is only true if the input requirements were valid.
|
|
// On invalid code, we'll need to skip this assertion (and instead
|
|
// assert that we diagnosed an error!)
|
|
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Name);
|
|
|
|
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Layout);
|
|
ASSERT_RULE(!symbol.isSuperclassOrConcreteType());
|
|
|
|
if (index != 0) {
|
|
ASSERT_RULE(symbol.getKind() != Symbol::Kind::GenericParam);
|
|
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Protocol);
|
|
}
|
|
}
|
|
|
|
auto lhsDomain = lhs.getRootProtocols();
|
|
auto rhsDomain = rhs.getRootProtocols();
|
|
|
|
ASSERT_RULE(lhsDomain == rhsDomain);
|
|
}
|
|
|
|
#undef ASSERT_RULE
|
|
#endif
|
|
}
|
|
|
|
/// Print the rewrite system to \p out: a header line, one "- <rule>" line
/// for each entry of Rules in order, and a closing brace.
void RewriteSystem::dump(llvm::raw_ostream &out) const {
  out << "Rewrite system: {\n";

  for (unsigned ruleID : indices(Rules))
    out << "- " << Rules[ruleID] << "\n";

  out << "}\n";
}
|