Files
swift-mirror/lib/AST/RequirementMachine/RewriteSystem.cpp
Slava Pestov 65f27d3ab7 RequirementMachine: Fix subtle corner case where we could insert duplicate rewrite rules
RewriteSystem::addRule() did two things before adding the rewrite rule:

- Simplify both sides. If both sides are now equal, discard the rule.

- If the last symbol on the left hand side was a superclass or concrete type
  symbol, simplify the substitution terms in that symbol.

The problem is that the second step can produce a term which can be further
simplified, and in particular, one that is exactly equal to the left hand
side of some other rule.

To fix this, swap the order of the two steps. The only wrinkle is that we now
have to check for a concrete type symbol at the end of _both_ the left hand
side and the right hand side, since we don't orient the rule until after we
have simplified both sides.

I don't have a reduced test case for this one, but it was revealed by
compiler_crashers_2_fixed/0109-sr4737.swift after I introduced the trie.
2021-08-05 21:42:50 -04:00

253 lines
6.7 KiB
C++

//===--- RewriteSystem.cpp - Generics with term rewriting -----------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2021 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
#include "swift/AST/Decl.h"
#include "swift/AST/Types.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "ProtocolGraph.h"
#include "RewriteContext.h"
#include "RewriteSystem.h"
using namespace swift;
using namespace rewriting;
/// Construct an empty rewrite system bound to \p ctx.
///
/// All debugging output flags start out disabled.
RewriteSystem::RewriteSystem(RewriteContext &ctx)
    : Context(ctx) {
  DebugSimplify = DebugAdd = DebugMerge = DebugCompletion = false;
}
/// Print this rule as "LHS => RHS" to \p out, followed by " [deleted]" when
/// the rule has been marked as deleted.
void Rule::dump(llvm::raw_ostream &out) const {
  out << LHS << " => " << RHS;
  if (!deleted)
    return;
  out << " [deleted]";
}
/// Seed the rewrite system with a protocol graph and an initial set of rules.
///
/// \param rules The initial rules as (lhs, rhs) pairs. Each pair is handed to
///   addRule(), which simplifies, orients, and possibly discards it. The
///   vector is consumed: its terms are moved from.
/// \param graph The protocol graph used to order terms; it is moved into
///   Protos. It must be installed before any rule is added, because addRule()
///   compares terms using Protos.
void RewriteSystem::initialize(
    std::vector<std::pair<MutableTerm, MutableTerm>> &&rules,
    ProtocolGraph &&graph) {
  // A named rvalue reference is itself an lvalue, so without std::move this
  // assignment would copy the entire graph instead of taking it over.
  Protos = std::move(graph);

  // The caller gave up ownership of `rules`, so move each term into addRule()
  // (which takes its arguments by value) instead of copying it.
  for (auto &rule : rules)
    addRule(std::move(rule.first), std::move(rule.second));
}
/// Return a copy of \p symbol in which every concrete substitution term has
/// been reduced with the current rewrite rules.
///
/// Substitution terms that no rule applies to are passed through unchanged,
/// so Term::get() is only called for terms that actually changed.
Symbol RewriteSystem::simplifySubstitutionsInSuperclassOrConcreteSymbol(
    Symbol symbol) const {
  auto reduceSubstitution = [&](Term substitution) -> Term {
    MutableTerm scratch(substitution);
    return simplify(scratch) ? Term::get(scratch, Context) : substitution;
  };

  return symbol.transformConcreteSubstitutions(reduceSubstitution, Context);
}
/// Add a rewrite rule relating \p lhs and \p rhs, unless the two terms are
/// already equivalent under the existing rules.
///
/// Before the rule is recorded, both terms are reduced as far as possible.
/// First, the substitution terms inside a trailing superclass or concrete type
/// symbol are simplified. Then the terms themselves are simplified. The order
/// matters: simplifying substitutions can produce a term that is further
/// reducible, for example one equal to the left hand side of an existing rule.
///
/// The rule is then oriented so that its left hand side compares greater than
/// its right hand side and appended to Rules. Every pair formed with an
/// existing live rule is queued on the Worklist in both directions, for the
/// overlap check.
///
/// \returns true if a new rule was added; false if both sides simplified to
/// the same term and the rule was discarded.
bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
  assert(!lhs.empty());
  assert(!rhs.empty());

  // First, simplify terms appearing inside concrete substitutions before
  // doing anything else. The rule has not been oriented yet, so either side
  // may end in a superclass or concrete type symbol. Check each side
  // independently so that neither is left with unsimplified substitutions.
  if (lhs.back().isSuperclassOrConcreteType())
    lhs.back() = simplifySubstitutionsInSuperclassOrConcreteSymbol(lhs.back());
  if (rhs.back().isSuperclassOrConcreteType())
    rhs.back() = simplifySubstitutionsInSuperclassOrConcreteSymbol(rhs.back());

  // Now simplify both sides as much as possible with the rules we have so far.
  //
  // This avoids unnecessary work in the completion algorithm.
  simplify(lhs);
  simplify(rhs);

  // If the left hand side and right hand side are already equivalent, we're
  // done.
  int result = lhs.compare(rhs, Protos);
  if (result == 0)
    return false;

  // Orient the two terms so that the left hand side is greater than the
  // right hand side.
  if (result < 0)
    std::swap(lhs, rhs);

  assert(lhs.compare(rhs, Protos) > 0);

  if (DebugAdd) {
    llvm::dbgs() << "# Adding rule " << lhs << " => " << rhs << "\n";
  }

  unsigned i = Rules.size();
  Rules.emplace_back(Term::get(lhs, Context), Term::get(rhs, Context));

  // Check if we have a rule of the form
  //
  //   X.[P1:T] => X.[P2:T]
  //
  // If so, record this rule for later. We'll try to merge the associated
  // types in RewriteSystem::processMergedAssociatedTypes().
  if (lhs.size() == rhs.size() &&
      std::equal(lhs.begin(), lhs.end() - 1, rhs.begin()) &&
      lhs.back().getKind() == Symbol::Kind::AssociatedType &&
      rhs.back().getKind() == Symbol::Kind::AssociatedType &&
      lhs.back().getName() == rhs.back().getName()) {
    MergedAssociatedTypes.emplace_back(lhs, rhs);
  }

  // Since we added a new rule, we have to check for overlaps between the
  // new rule and all existing rules.
  for (unsigned j : indices(Rules)) {
    // A rule does not overlap with itself.
    if (i == j)
      continue;

    // We don't have to check for overlap with deleted rules.
    if (Rules[j].isDeleted())
      continue;

    // The overlap check is not commutative so we have to check both
    // directions.
    Worklist.emplace_back(i, j);
    Worklist.emplace_back(j, i);

    if (DebugCompletion) {
      llvm::dbgs() << "$ Queued up (" << i << ", " << j << ") and ";
      llvm::dbgs() << "(" << j << ", " << i << ")\n";
    }
  }

  // Tell the caller that we added a new rule.
  return true;
}
/// Rewrite \p term in place until no live rule applies any more.
///
/// Each pass tries every non-deleted rule in order. Passes repeat for as long
/// as the previous pass changed the term.
///
/// \returns true if at least one rule was applied to \p term.
bool RewriteSystem::simplify(MutableTerm &term) const {
  if (DebugSimplify) {
    llvm::dbgs() << "= Term " << term << "\n";
  }

  bool everRewritten = false;
  bool passRewrote;
  do {
    passRewrote = false;
    for (const auto &candidate : Rules) {
      if (candidate.isDeleted())
        continue;

      if (DebugSimplify)
        llvm::dbgs() << "== Rule " << candidate << "\n";

      if (!candidate.apply(term))
        continue;

      if (DebugSimplify)
        llvm::dbgs() << "=== Result " << term << "\n";

      everRewritten = true;
      passRewrote = true;
    }
  } while (passRewrote);

  return everRewritten;
}
/// Reduce the right hand side of every live rule with the current rule set.
/// A rule is rebuilt in place whenever its right hand side changed.
///
/// In builds with assertions enabled (NDEBUG not defined), this then checks
/// structural invariants on every live rule. On the first violation it does
/// three things: prints the offending rule and the failed condition to
/// llvm::errs(), dumps the whole rewrite system, and asserts.
void RewriteSystem::simplifyRightHandSides() {
  for (auto &rule : Rules) {
    // Deleted rules are skipped entirely.
    if (rule.isDeleted())
      continue;

    MutableTerm rhs(rule.getRHS());
    // simplify() returns false when no rule applied, so the rule is kept as
    // is.
    if (!simplify(rhs))
      continue;

    // Replace the rule with one that has the same LHS and the reduced RHS.
    rule = Rule(rule.getLHS(), Term::get(rhs, Context));
  }

#ifndef NDEBUG

  // ASSERT_RULE(expr): when `expr` is false, print the current `rule` and the
  // stringified condition, dump this rewrite system, then assert(expr).
  // It refers to a variable named `rule` in the enclosing scope.
#define ASSERT_RULE(expr) \
  if (!(expr)) { \
    llvm::errs() << "&&& Malformed rewrite rule: " << rule << "\n"; \
    llvm::errs() << "&&& " << #expr << "\n\n"; \
    dump(llvm::errs()); \
    assert(expr); \
  }

  for (const auto &rule : Rules) {
    if (rule.isDeleted())
      continue;

    const auto &lhs = rule.getLHS();
    const auto &rhs = rule.getRHS();

    // Left hand side invariants:
    // - layout and superclass/concrete type symbols may only be the last
    //   symbol;
    // - generic parameter symbols may only be the first symbol;
    // - protocol symbols may only be the first or the last symbol.
    for (unsigned index : indices(lhs)) {
      auto symbol = lhs[index];

      if (index != lhs.size() - 1) {
        ASSERT_RULE(symbol.getKind() != Symbol::Kind::Layout);
        ASSERT_RULE(!symbol.isSuperclassOrConcreteType());
      }

      if (index != 0) {
        ASSERT_RULE(symbol.getKind() != Symbol::Kind::GenericParam);
      }

      if (index != 0 && index != lhs.size() - 1) {
        ASSERT_RULE(symbol.getKind() != Symbol::Kind::Protocol);
      }
    }

    // Right hand side invariants:
    // - no name, layout, or superclass/concrete type symbols at any position;
    // - generic parameter and protocol symbols may only be the first symbol.
    for (unsigned index : indices(rhs)) {
      auto symbol = rhs[index];

      // FIXME: This is only true if the input requirements were valid.
      // On invalid code, we'll need to skip this assertion (and instead
      // assert that we diagnosed an error!)
      ASSERT_RULE(symbol.getKind() != Symbol::Kind::Name);

      ASSERT_RULE(symbol.getKind() != Symbol::Kind::Layout);
      ASSERT_RULE(!symbol.isSuperclassOrConcreteType());

      if (index != 0) {
        ASSERT_RULE(symbol.getKind() != Symbol::Kind::GenericParam);
        ASSERT_RULE(symbol.getKind() != Symbol::Kind::Protocol);
      }
    }

    // Both sides must report the same root protocols.
    auto lhsDomain = lhs.getRootProtocols();
    auto rhsDomain = rhs.getRootProtocols();

    ASSERT_RULE(lhsDomain == rhsDomain);
  }

#undef ASSERT_RULE
#endif
}
/// Print every rule in the system, including deleted ones, to \p out.
/// Rules appear one per line, prefixed with "- ", inside a
/// "Rewrite system: { ... }" wrapper.
void RewriteSystem::dump(llvm::raw_ostream &out) const {
  out << "Rewrite system: {\n";
  std::for_each(Rules.begin(), Rules.end(), [&out](const auto &entry) {
    out << "- " << entry << "\n";
  });
  out << "}\n";
}