[CSOptimizer] Keep track of mismatches while evaluating candidates

This commit is contained in:
Pavel Yaskevich
2023-02-10 11:33:37 -08:00
parent c2f7451c7b
commit a094c3ebb0

View File

@@ -16,6 +16,7 @@
#include "swift/Sema/ConstraintGraph.h"
#include "swift/Sema/ConstraintSystem.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TinyPtrVector.h"
@@ -231,20 +232,45 @@ static void determineBestChoicesInContext(
if (paramType->is<FunctionType>())
continue;
double argScore = 0.0;
for (auto const &candidate : candidateArgumentTypes[i]) {
auto candidateType = candidate.first;
if (candidateArgumentTypes[i].empty())
continue;
// The idea here is to match the parameter type against
// all of the argument candidate types and pick the best
// match (i.e. exact equality one).
//
// If none of the candidates match exactly and they are
// all bound concrete types, we consider this is mismatch
// at this parameter position and remove the overload choice
// from consideration.
double bestCandidateScore = 0;
llvm::BitVector mismatches(candidateArgumentTypes[i].size());
for (unsigned candidateIdx : indices(candidateArgumentTypes[i])) {
// If one of the candidates matched exactly there is no reason
// to continue checking.
if (bestCandidateScore == 1)
break;
Type candidateType;
bool isLiteralDefault;
std::tie(candidateType, isLiteralDefault) =
candidateArgumentTypes[i][candidateIdx];
// `inout` parameter accepts only l-value argument.
if (paramFlags.isInOut() && !candidateType->is<LValueType>())
if (paramFlags.isInOut() && !candidateType->is<LValueType>()) {
mismatches.set(candidateIdx);
continue;
}
// The specifier only matters for `inout` check.
candidateType = candidateType->getWithoutSpecifierType();
// Exact match on one of the candidate bindings.
if (candidateType->isEqual(paramType)) {
argScore = std::max(
argScore, /*fromLiteral=*/candidate.second ? 0.3 : 1.0);
bestCandidateScore =
std::max(bestCandidateScore, isLiteralDefault ? 0.3 : 1.0);
continue;
}
@@ -253,12 +279,23 @@ static void determineBestChoicesInContext(
// and expect a particular `Scalar` if it's known.
if (isSIMDType(candidateType) &&
isArithmeticOperatorOnSIMDProtocol(decl)) {
argScore = std::max(argScore, 1.0);
bestCandidateScore = 1.0;
continue;
}
// Only established arguments could be considered mismatches,
// literal default types should be regarded as holes if they
// didn't match.
if (!isLiteralDefault && !candidateType->hasTypeVariable())
mismatches.set(candidateIdx);
}
score += argScore;
// If none of the candidates for this parameter matched, let's
// drop this overload from any further consideration.
if (mismatches.all())
return;
score += bestCandidateScore;
}
// Average the score to avoid disfavoring disjunctions with fewer