Sema: Fix associated type solution ranking

We folded away viable solutions with identical type witnesses;
the first one "wins". However, solutions also store the value
witnesses from which those type witnesses were derived, and
this determines their ranking.

Suppose we have three solutions S_1, S_2, S_3 ranked as follows:

    S_1 < S_2 < S_3

If S_1 and S_3 have identical type witnesses, then one of two
things would happen:

Scenario A:
- we find S_1, and record it.
- we find S_2, and record it.
- we find S_3; it's identical to S_1, so we drop it.

Scenario B:
- we find S_3, and record it.
- we find S_2, and record it.
- we find S_1; it's identical to S_3, so we drop it.

Now, the best solution in Scenario A is S_1, and the best
solution in Scenario B is S_3.

To fix this and ensure we always end up with S_1, remove this
folding of solutions, except for invalid solutions where it
doesn't matter.

To avoid recording too many viable solutions, instead prune the
solution list every time we add a new solution. This maintains the
invariant that no solution is clearly worse than the others; when
we get to the end, we just check if we have exactly one solution,
in which case we know it's the best one.

Fixes rdar://problem/122586685.
This commit is contained in:
Slava Pestov
2024-02-09 11:09:21 -05:00
parent 5ef198d692
commit 83cb420ee4
2 changed files with 90 additions and 103 deletions

View File

@@ -1061,16 +1061,6 @@ private:
bool isBetterSolution(const InferredTypeWitnessesSolution &first,
const InferredTypeWitnessesSolution &second);
/// Find the best solution.
///
/// \param solutions All of the solutions to consider. On success,
/// this will contain only the best solution.
///
/// \returns \c false if there was a single best solution,
/// \c true if no single best solution exists.
bool findBestSolution(
SmallVectorImpl<InferredTypeWitnessesSolution> &solutions);
/// Emit a diagnostic for the case where there are no solutions at all
/// to consider.
///
@@ -3015,8 +3005,11 @@ void AssociatedTypeInference::findSolutionsRec(
++NumSolutionStates;
// Validate and complete the solution.
// Fold the dependent member types within this type.
// Fold any concrete dependent member types that remain among our
// tentative type witnesses.
//
// FIXME: inferAbstractTypeWitnesses() also does this in a different way;
// combine the two.
for (auto assocType : proto->getAssociatedTypeMembers()) {
if (conformance->hasTypeWitness(assocType))
continue;
@@ -3039,33 +3032,6 @@ void AssociatedTypeInference::findSolutionsRec(
known->first = replaced;
}
// Check whether our current solution matches the given solution.
auto matchesSolution =
[&](const InferredTypeWitnessesSolution &solution) {
for (const auto &existingTypeWitness : solution.TypeWitnesses) {
auto typeWitness = typeWitnesses.begin(existingTypeWitness.first);
if (!typeWitness->first->isEqual(existingTypeWitness.second.first))
return false;
}
return true;
};
// If we've seen this solution already, bail out; there's no point in
// checking further.
if (llvm::any_of(solutions, matchesSolution)) {
LLVM_DEBUG(llvm::dbgs() << std::string(valueWitnesses.size(), '+')
<< "+ Duplicate valid solution found\n";);
++NumDuplicateSolutionStates;
return;
}
if (llvm::any_of(nonViableSolutions, matchesSolution)) {
LLVM_DEBUG(llvm::dbgs() << std::string(valueWitnesses.size(), '+')
<< "+ Duplicate invalid solution found\n";);
++NumDuplicateSolutionStates;
return;
}
/// Check the current set of type witnesses.
bool invalid = checkCurrentTypeWitnesses(valueWitnesses);
@@ -3077,9 +3043,8 @@ void AssociatedTypeInference::findSolutionsRec(
<< "+ Valid solution found\n";);
}
auto &solutionList = invalid ? nonViableSolutions : solutions;
solutionList.push_back(InferredTypeWitnessesSolution());
auto &solution = solutionList.back();
// Build the solution.
InferredTypeWitnessesSolution solution;
// Copy the type witnesses.
for (auto assocType : unresolvedAssocTypes) {
@@ -3092,14 +3057,58 @@ void AssociatedTypeInference::findSolutionsRec(
solution.NumValueWitnessesInProtocolExtensions
= numValueWitnessesInProtocolExtensions;
// If this solution was clearly better than the previous best solution,
// swap them.
if (solutionList.back().NumValueWitnessesInProtocolExtensions
< solutionList.front().NumValueWitnessesInProtocolExtensions) {
std::swap(solutionList.front(), solutionList.back());
// We fold away non-viable solutions that have the same type witnesses.
if (invalid) {
auto matchesSolution = [&](const InferredTypeWitnessesSolution &other) {
for (const auto &otherTypeWitness : other.TypeWitnesses) {
auto typeWitness = solution.TypeWitnesses.find(otherTypeWitness.first);
if (!typeWitness->second.first->isEqual(otherTypeWitness.second.first))
return false;
}
return true;
};
if (llvm::any_of(nonViableSolutions, matchesSolution)) {
LLVM_DEBUG(llvm::dbgs() << std::string(valueWitnesses.size(), '+')
<< "+ Duplicate invalid solution found\n";);
++NumDuplicateSolutionStates;
return;
}
nonViableSolutions.push_back(std::move(solution));
return;
}
// We're done recording the solution.
// For valid solutions, we want to find the best solution if one exists.
// We maintain the invariant that no viable solution is clearly worse than
// any other viable solution. If multiple viable solutions remain after
// we've considered the entire search space, we have an ambiguous situation.
// If this solution is clearly worse than some existing solution, give up.
if (llvm::any_of(solutions, [&](const InferredTypeWitnessesSolution &other) {
return isBetterSolution(other, solution);
})) {
LLVM_DEBUG(llvm::dbgs() << std::string(valueWitnesses.size(), '+')
<< "+ Solution is worse than some existing solution\n";);
++NumDuplicateSolutionStates;
return;
}
// If any existing solutions are clearly worse than this solution,
// remove them.
llvm::erase_if(solutions, [&](const InferredTypeWitnessesSolution &other) {
if (isBetterSolution(solution, other)) {
LLVM_DEBUG(llvm::dbgs() << std::string(valueWitnesses.size(), '+')
<< "+ Solution is better than some existing solution\n";);
++NumDuplicateSolutionStates;
return true;
}
return false;
});
solutions.push_back(std::move(solution));
return;
}
@@ -3414,6 +3423,23 @@ bool AssociatedTypeInference::isBetterSolution(
const InferredTypeWitnessesSolution &first,
const InferredTypeWitnessesSolution &second) {
assert(first.ValueWitnesses.size() == second.ValueWitnesses.size());
if (first.NumValueWitnessesInProtocolExtensions <
second.NumValueWitnessesInProtocolExtensions)
return true;
if (first.NumValueWitnessesInProtocolExtensions >
second.NumValueWitnessesInProtocolExtensions)
return false;
// Dear reader: this is not a lexicographic order on tuple of value witnesses;
// rather, (x_1, ..., x_n) < (y_1, ..., y_n) if and only if:
//
// - there exists at least one index i such that x_i < y_i.
// - there does not exist any i such that y_i < x_i.
//
// that is, the order relation is independent of the order in which value
// witnesses were pushed onto the stack.
bool firstBetter = false;
bool secondBetter = false;
for (unsigned i = 0, n = first.ValueWitnesses.size(); i != n; ++i) {
@@ -3446,58 +3472,6 @@ bool AssociatedTypeInference::isBetterSolution(
return firstBetter;
}
/// Reduce \p solutions to a single best solution, if one exists.
///
/// \param solutions The candidate solutions. When this returns \c false the
/// vector holds exactly one element; when it returns \c true for a non-empty
/// input, the vector may already have been pruned of solutions that use more
/// value witnesses from protocol extensions than the front element.
///
/// \returns \c false if there is a single best solution, \c true if the
/// input is empty or no single best solution exists (ambiguity).
bool AssociatedTypeInference::findBestSolution(
SmallVectorImpl<InferredTypeWitnessesSolution> &solutions) {
// No candidates at all: report failure. A lone candidate is trivially best.
if (solutions.empty()) return true;
if (solutions.size() == 1) return false;
// The solution at the front has the smallest number of value witnesses found
// in protocol extensions, by construction.
// NOTE(review): this relies on the producer of `solutions` keeping the
// minimum at the front; nothing in this function re-establishes it.
unsigned bestNumValueWitnessesInProtocolExtensions
= solutions.front().NumValueWitnessesInProtocolExtensions;
// Erase any solutions with more value witnesses in protocol
// extensions than the best.
solutions.erase(
std::remove_if(solutions.begin(), solutions.end(),
[&](const InferredTypeWitnessesSolution &solution) {
return solution.NumValueWitnessesInProtocolExtensions >
bestNumValueWitnessesInProtocolExtensions;
}),
solutions.end());
// If we're down to one solution, success!
if (solutions.size() == 1) return false;
// Find a solution that's at least as good as the solutions that follow it.
// This is a single forward pass: the candidate is replaced whenever a later
// solution compares as better than the current candidate.
unsigned bestIdx = 0;
for (unsigned i = 1, n = solutions.size(); i != n; ++i) {
if (isBetterSolution(solutions[i], solutions[bestIdx]))
bestIdx = i;
}
// Make sure that solution is better than any of the other solutions.
// NOTE(review): the scan starts at index 1, so when bestIdx != 0 the
// candidate is never re-checked against solutions[0]. Presumably this is
// sound only if isBetterSolution() is transitive -- confirm.
bool ambiguous = false;
for (unsigned i = 1, n = solutions.size(); i != n; ++i) {
if (i != bestIdx && !isBetterSolution(solutions[bestIdx], solutions[i])) {
ambiguous = true;
break;
}
}
// If the result was ambiguous, fail.
if (ambiguous) {
assert(solutions.size() != 1 && "should have succeeded somewhere above?");
return true;
}
// Keep the best solution, erasing all others.
// The winner is moved into slot 0 first so that the tail can be erased in
// one call.
if (bestIdx != 0)
solutions[0] = std::move(solutions[bestIdx]);
solutions.erase(solutions.begin() + 1, solutions.end());
return false;
}
namespace {
/// A failed type witness binding.
struct FailedTypeWitness {
@@ -3897,9 +3871,8 @@ auto AssociatedTypeInference::solve()
}
}
// Find the best solution.
if (!findBestSolution(solutions)) {
assert(solutions.size() == 1 && "Not a unique best solution?");
// Happy case: we found exactly one viable solution.
if (solutions.size() == 1) {
// Form the resulting solution.
auto &typeWitnesses = solutions.front().TypeWitnesses;
for (auto assocType : unresolvedAssocTypes) {

View File

@@ -0,0 +1,14 @@
// RUN: %target-typecheck-verify-swift -enable-experimental-associated-type-inference
// RUN: %target-typecheck-verify-swift -disable-experimental-associated-type-inference
// Regression test for associated type inference when a conforming type
// declares no members of its own (rdar://problem/122586685). No diagnostics
// are annotated in this file, so both RUN lines require it to type-check
// cleanly.

// `S` has an empty body; it supplies no type or value witnesses itself.
public struct S: P {}
// `P` refines `Collection` and adds no requirements of its own.
public protocol P: Collection {}
// Members offered to every conformer of `P` from a protocol extension.
// NOTE(review): presumably the type witnesses for `S` (index, element and
// iterator types) are meant to be inferred from these signatures -- confirm
// against the inference implementation. The declaration order below is left
// untouched because the bug being tested was sensitive to the order in which
// solutions were discovered.
extension P {
public func index(after i: Int) -> Int { fatalError() }
public var startIndex: Int { fatalError() }
public var endIndex: Int { fatalError() }
public subscript(index: Int) -> String { fatalError() }
public func makeIterator() -> AnyIterator<String> { fatalError() }
}