Solver: Keep track of a solution's score as we're computing it.

No functionality change here; just staging for some future optimizations.


Swift SVN r11028
This commit is contained in:
Doug Gregor
2013-12-09 17:12:07 +00:00
parent ac7ee4ba59
commit 79f8175e0b
8 changed files with 209 additions and 83 deletions

View File

@@ -3353,54 +3353,3 @@ Solution::convertToArrayBound(Expr *expr, ConstraintLocator *locator) const {
return result; return result;
} }
int Solution::getFixedScore() const {
if (fixedScore)
return *fixedScore;
int score = 0;
// Consider overload choices.
for (auto overload : overloadChoices) {
auto choice = overload.second.choice;
if (choice.getKind() != OverloadChoiceKind::Decl)
continue;
// -3 penalty for each user-defined conversion.
if (choice.getDecl()->getAttrs().isConversion())
score -= 3;
}
// Consider type bindings.
auto &tc = getConstraintSystem().getTypeChecker();
for (auto binding : typeBindings) {
// Look for type variables corresponding directly to an expression.
auto typeVar = binding.first;
auto locator = typeVar->getImpl().getLocator();
if (!locator || !locator->getAnchor() || !locator->getPath().empty())
continue;
// Check whether there is a literal protocol corresponding to the
// anchor expression.
auto literalProtocol
= tc.getLiteralProtocol(locator->getAnchor());
if (!literalProtocol)
continue;
// Retrieve the default type for this literal protocol, if there is one.
auto defaultType = tc.getDefaultType(literalProtocol,
getConstraintSystem().DC);
if (!defaultType)
continue;
// +1 if the bound type matches the default type for this literal protocol.
// Literal types are always nominal, so we simply check the nominal
// declaration. This covers e.g., Slice vs. Slice<T>.
if (defaultType->getAnyNominal() == binding.second->getAnyNominal())
++score;
}
// Save the fixed score.
fixedScore = score;
return score;
}

View File

@@ -27,6 +27,20 @@ using namespace constraints;
#define DEBUG_TYPE "Constraint solver overall" #define DEBUG_TYPE "Constraint solver overall"
STATISTIC(NumDiscardedSolutions, "# of solutions discarded"); STATISTIC(NumDiscardedSolutions, "# of solutions discarded");
void ConstraintSystem::increaseScore(ScoreKind kind) {
unsigned index = static_cast<unsigned>(kind);
++CurrentScore.Data[index];
}
llvm::raw_ostream &constraints::operator<<(llvm::raw_ostream &out,
const Score &score) {
for (unsigned i = 0; i != NumScoreKinds; ++i) {
if (i) out << ' ';
out << score.Data[i];
}
return out;
}
/// \brief Remove the initializers from any tuple types within the /// \brief Remove the initializers from any tuple types within the
/// given type. /// given type.
static Type stripInitializers(TypeChecker &tc, Type origType) { static Type stripInitializers(TypeChecker &tc, Type origType) {
@@ -446,6 +460,13 @@ Comparison TypeChecker::compareDeclarations(DeclContext *dc,
return decl1Better? Comparison::Better : Comparison::Worse; return decl1Better? Comparison::Better : Comparison::Worse;
} }
/// Simplify a score into a single integer.
/// FIXME: Temporary hack.
static int simplifyScore(const Score &score) {
return (int)score.Data[SK_UserConversion] * -3
+ (int)score.Data[SK_NonDefaultLiteral] * -1;
}
SolutionCompareResult ConstraintSystem::compareSolutions( SolutionCompareResult ConstraintSystem::compareSolutions(
ConstraintSystem &cs, ConstraintSystem &cs,
ArrayRef<Solution> solutions, ArrayRef<Solution> solutions,
@@ -458,8 +479,8 @@ SolutionCompareResult ConstraintSystem::compareSolutions(
// Solution comparison uses a scoring system to determine whether one // Solution comparison uses a scoring system to determine whether one
// solution is better than the other. Retrieve the fixed scores for each of // solution is better than the other. Retrieve the fixed scores for each of
// the solutions, which we'll modify with relative scoring. // the solutions, which we'll modify with relative scoring.
int score1 = solutions[idx1].getFixedScore(); int score1 = simplifyScore(solutions[idx1].getFixedScore());
int score2 = solutions[idx2].getFixedScore(); int score2 = simplifyScore(solutions[idx2].getFixedScore());
// Compare overload sets. // Compare overload sets.
for (auto &overload : diff.overloads) { for (auto &overload : diff.overloads) {

View File

@@ -526,6 +526,9 @@ tryUserConversion(ConstraintSystem &cs, Type type, ConstraintKind kind,
cs.addConstraint(kind, outputTV, otherType, resultLocator); cs.addConstraint(kind, outputTV, otherType, resultLocator);
} }
// We're adding a user-defined conversion.
cs.increaseScore(SK_UserConversion);
return ConstraintSystem::SolutionKind::Solved; return ConstraintSystem::SolutionKind::Solved;
} }

View File

@@ -73,7 +73,7 @@ static Optional<Type> checkTypeOfBinding(ConstraintSystem &cs,
Solution ConstraintSystem::finalize( Solution ConstraintSystem::finalize(
FreeTypeVariableBinding allowFreeTypeVariables) { FreeTypeVariableBinding allowFreeTypeVariables) {
// Create the solution. // Create the solution.
Solution solution(*this); Solution solution(*this, CurrentScore);
// For any of the type variables that has no associated fixed type, assign a // For any of the type variables that has no associated fixed type, assign a
// fresh generic type parameters. // fresh generic type parameters.
@@ -123,6 +123,9 @@ Solution ConstraintSystem::finalize(
} }
void ConstraintSystem::applySolution(const Solution &solution) { void ConstraintSystem::applySolution(const Solution &solution) {
// Update the score.
CurrentScore += solution.getFixedScore();
// Assign fixed types to the type variables solved by this solution. // Assign fixed types to the type variables solved by this solution.
llvm::SmallPtrSet<TypeVariableType *, 4> llvm::SmallPtrSet<TypeVariableType *, 4>
knownTypeVariables(TypeVariables.begin(), TypeVariables.end()); knownTypeVariables(TypeVariables.begin(), TypeVariables.end());
@@ -134,7 +137,7 @@ void ConstraintSystem::applySolution(const Solution &solution) {
// If we don't already have a fixed type for this type variable, // If we don't already have a fixed type for this type variable,
// assign the fixed type from the solution. // assign the fixed type from the solution.
if (!getFixedType(binding.first) && !binding.second->hasTypeVariable()) if (!getFixedType(binding.first) && !binding.second->hasTypeVariable())
assignFixedType(binding.first, binding.second); assignFixedType(binding.first, binding.second, /*updateScore=*/false);
} }
// Register overload choices. // Register overload choices.
@@ -583,6 +586,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
numConstraintRestrictions = cs.solverState->constraintRestrictions.size(); numConstraintRestrictions = cs.solverState->constraintRestrictions.size();
oldGeneratedConstraints = cs.solverState->generatedConstraints; oldGeneratedConstraints = cs.solverState->generatedConstraints;
cs.solverState->generatedConstraints = &generatedConstraints; cs.solverState->generatedConstraints = &generatedConstraints;
PreviousScore = cs.CurrentScore;
++cs.solverState->NumStatesExplored; ++cs.solverState->NumStatesExplored;
@@ -618,6 +622,9 @@ ConstraintSystem::SolverScope::~SolverScope() {
// Reset the prior generated-constraints pointer. // Reset the prior generated-constraints pointer.
cs.solverState->generatedConstraints = oldGeneratedConstraints; cs.solverState->generatedConstraints = oldGeneratedConstraints;
// Reset the previous score.
cs.CurrentScore = PreviousScore;
// Clear out other "failed" state. // Clear out other "failed" state.
cs.failedConstraint = nullptr; cs.failedConstraint = nullptr;
} }
@@ -932,7 +939,8 @@ bool ConstraintSystem::solve(SmallVectorImpl<Solution> &solutions,
auto solution = finalize(allowFreeTypeVariables); auto solution = finalize(allowFreeTypeVariables);
if (TC.getLangOpts().DebugConstraintSolver) { if (TC.getLangOpts().DebugConstraintSolver) {
auto &log = getASTContext().TypeCheckerDebug->getStream(); auto &log = getASTContext().TypeCheckerDebug->getStream();
log.indent(solverState->depth * 2) << "(found solution)\n"; log.indent(solverState->depth * 2)
<< "(found solution " << CurrentScore << ")\n";
} }
solutions.push_back(std::move(solution)); solutions.push_back(std::move(solution));
@@ -1069,6 +1077,11 @@ bool ConstraintSystem::solve(SmallVectorImpl<Solution> &solutions,
// Move the type variables back, clear out constraints; we're // Move the type variables back, clear out constraints; we're
// ready for the next component. // ready for the next component.
TypeVariables = std::move(allTypeVariables); TypeVariables = std::move(allTypeVariables);
// For each of the partial solutions, substract off the current score.
// It doesn't contribute.
for (auto &solution : partialSolutions[component])
solution.getFixedScore() -= CurrentScore;
} }
// Move the constraints back. The system is back in a normal state. // Move the constraints back. The system is back in a normal state.
@@ -1103,7 +1116,8 @@ bool ConstraintSystem::solve(SmallVectorImpl<Solution> &solutions,
auto solution = finalize(allowFreeTypeVariables); auto solution = finalize(allowFreeTypeVariables);
if (TC.getLangOpts().DebugConstraintSolver) { if (TC.getLangOpts().DebugConstraintSolver) {
auto &log = getASTContext().TypeCheckerDebug->getStream(); auto &log = getASTContext().TypeCheckerDebug->getStream();
log.indent(solverState->depth * 2) << "(composed solution)\n"; log.indent(solverState->depth * 2)
<< "(composed solution " << CurrentScore << ")\n";
} }
// Save this solution. // Save this solution.

View File

@@ -69,9 +69,51 @@ void ConstraintSystem::mergeEquivalenceClasses(TypeVariableType *typeVar1,
} }
} }
void ConstraintSystem::assignFixedType(TypeVariableType *typeVar, Type type) { void ConstraintSystem::assignFixedType(TypeVariableType *typeVar, Type type,
bool updateScore) {
typeVar->getImpl().assignFixedType(type, getSavedBindings()); typeVar->getImpl().assignFixedType(type, getSavedBindings());
if (updateScore && !type->is<TypeVariableType>()) {
// If this type variable represents a literal, check whether we picked the
// default literal type. First, find the corresponding protocol.
ProtocolDecl *literalProtocol = nullptr;
if (CG) {
// If we have the constraint graph, we can check all type variables in
// the equivalence class. This is the More Correct path.
// FIXME: Eliminate the less-correct path.
auto typeVarRep = getRepresentative(typeVar);
for (auto tv : (*CG)[typeVarRep].getEquivalenceClass()) {
auto locator = tv->getImpl().getLocator();
if (!locator || !locator->getPath().empty())
continue;
auto anchor = locator->getAnchor();
if (!anchor)
continue;
literalProtocol = TC.getLiteralProtocol(anchor);
if (literalProtocol)
break;
}
} else {
// FIXME: This is the less-correct path.
auto locator = typeVar->getImpl().getLocator();
if (locator && locator->getPath().empty() && locator->getAnchor()) {
literalProtocol = TC.getLiteralProtocol(locator->getAnchor());
}
}
// If the protocol has a default type, check it.
if (literalProtocol) {
if (auto defaultType = TC.getDefaultType(literalProtocol, DC)) {
// Check whether the nominal types match. This makes sure that we
// properly handle Slice vs. Slice<T>.
if (defaultType->getAnyNominal() != type->getAnyNominal())
increaseScore(SK_NonDefaultLiteral);
}
}
}
// Notify the constraint graph. // Notify the constraint graph.
if (CG) { if (CG) {
CG->bindTypeVariable(typeVar, type); CG->bindTypeVariable(typeVar, type);

View File

@@ -611,6 +611,94 @@ struct SelectedOverload {
Type openedType; Type openedType;
}; };
/// Describes an aspect of a solution that affects is overall score, i.e., a
/// user-defined conversions.
enum ScoreKind {
/// A user-defined conversion.
SK_UserConversion = 0,
/// A literal expression bound to a non-default literal type.
SK_NonDefaultLiteral = 1
};
/// The number of score kinds.
const unsigned NumScoreKinds = 2;
/// Describes the fixed score of a solution to the constraint system.
struct Score {
unsigned Data[NumScoreKinds] = { 0, 0 };
friend Score &operator+=(Score &x, const Score &y) {
for (unsigned i = 0; i != NumScoreKinds; ++i) {
x.Data[i] += y.Data[i];
}
return x;
}
friend Score operator+(const Score &x, const Score &y) {
Score result;
for (unsigned i = 0; i != NumScoreKinds; ++i) {
result.Data[i] = x.Data[i] + y.Data[i];
}
return result;
}
friend Score operator-(const Score &x, const Score &y) {
Score result;
for (unsigned i = 0; i != NumScoreKinds; ++i) {
result.Data[i] = x.Data[i] - y.Data[i];
}
return result;
}
friend Score &operator-=(Score &x, const Score &y) {
for (unsigned i = 0; i != NumScoreKinds; ++i) {
x.Data[i] -= y.Data[i];
}
return x;
}
friend bool operator==(const Score &x, const Score &y) {
for (unsigned i = 0; i != NumScoreKinds; ++i) {
if (x.Data[i] != y.Data[i])
return false;
}
return true;
}
friend bool operator!=(const Score &x, const Score &y) {
return !(x == y);
}
friend bool operator<(const Score &x, const Score &y) {
for (unsigned i = 0; i != NumScoreKinds; ++i) {
if (x.Data[i] < y.Data[i])
return true;
if (x.Data[i] > y.Data[i])
return false;
}
return true;
}
friend bool operator<=(const Score &x, const Score &y) {
return !(y < x);
}
friend bool operator>(const Score &x, const Score &y) {
return y < x;
}
friend bool operator>=(const Score &x, const Score &y) {
return !(x < y);
}
};
/// Display a score.
llvm::raw_ostream &operator<<(llvm::raw_ostream &out, const Score &score);
/// \brief A complete solution to a constraint system. /// \brief A complete solution to a constraint system.
/// ///
/// A solution to a constraint system consists of type variable bindings to /// A solution to a constraint system consists of type variable bindings to
@@ -622,31 +710,19 @@ class Solution {
ConstraintSystem *constraintSystem; ConstraintSystem *constraintSystem;
/// \brief The fixed score for this solution. /// \brief The fixed score for this solution.
mutable Optional<int> fixedScore; Score FixedScore;
public: public:
/// \brief Create a solution for the given constraint system. /// \brief Create a solution for the given constraint system.
Solution(ConstraintSystem &cs) : constraintSystem(&cs) {} Solution(ConstraintSystem &cs, const Score &score)
: constraintSystem(&cs), FixedScore(score) {}
// Solution is a non-copyable type for performance reasons. // Solution is a non-copyable type for performance reasons.
Solution(const Solution &other) = delete; Solution(const Solution &other) = delete;
Solution &operator=(const Solution &other) = delete; Solution &operator=(const Solution &other) = delete;
Solution(Solution &&other) Solution(Solution &&other) = default;
: constraintSystem(other.constraintSystem), Solution &operator=(Solution &&other) = default;
typeBindings(std::move(other.typeBindings)),
overloadChoices(std::move(other.overloadChoices)),
constraintRestrictions(std::move(other.constraintRestrictions))
{
}
Solution &operator=(Solution &&other) {
constraintSystem = other.constraintSystem;
typeBindings = std::move(other.typeBindings);
overloadChoices = std::move(other.overloadChoices);
constraintRestrictions = std::move(other.constraintRestrictions);
return *this;
}
/// \brief Retrieve the constraint system that this solution solves. /// \brief Retrieve the constraint system that this solution solves.
ConstraintSystem &getConstraintSystem() const { return *constraintSystem; } ConstraintSystem &getConstraintSystem() const { return *constraintSystem; }
@@ -726,10 +802,12 @@ public:
Type openedType, Type openedType,
SmallVectorImpl<Substitution> &substitutions) const; SmallVectorImpl<Substitution> &substitutions) const;
/// \brief Retrieve the fixed score of this solution, which considers /// \brief Retrieve the fixed score of this solution
/// the number of user-defined conversions. const Score &getFixedScore() const { return FixedScore; }
int getFixedScore() const;
/// \brief Retrieve the fixed score of this solution
Score &getFixedScore() { return FixedScore; }
/// \brief Retrieve the fixed type for the given type variable. /// \brief Retrieve the fixed type for the given type variable.
Type getFixedType(TypeVariableType *typeVar) const; Type getFixedType(TypeVariableType *typeVar) const;
@@ -911,6 +989,10 @@ private:
/// \brief The overload sets that have been resolved along the current path. /// \brief The overload sets that have been resolved along the current path.
ResolvedOverloadSetListItem *resolvedOverloadSets = nullptr; ResolvedOverloadSetListItem *resolvedOverloadSets = nullptr;
/// The current fixed score for this constraint system and the (partial)
/// solution it represents.
Score CurrentScore;
SmallVector<TypeVariableType *, 16> TypeVariables; SmallVector<TypeVariableType *, 16> TypeVariables;
ConstraintList Constraints; ConstraintList Constraints;
@@ -1012,6 +1094,9 @@ public:
/// \brief The length of \c constraintRestrictions. /// \brief The length of \c constraintRestrictions.
unsigned numConstraintRestrictions; unsigned numConstraintRestrictions;
/// The previous score.
Score PreviousScore;
/// Constraint graph scope associated with this solver scope. /// Constraint graph scope associated with this solver scope.
/// ///
/// FIXME: This is optional so we can easily enabled/disable the /// FIXME: This is optional so we can easily enabled/disable the
@@ -1314,7 +1399,12 @@ public:
bool wantRValue); bool wantRValue);
/// \brief Assign a fixed type to the given type variable. /// \brief Assign a fixed type to the given type variable.
void assignFixedType(TypeVariableType *typeVar, Type type); ///
/// \param typeVar The type variable to bind.
/// \param type The fixed type to which the type variable will be bound.
/// \param updateScore Whether to update the score based on this binding.
void assignFixedType(TypeVariableType *typeVar, Type type,
bool updateScore = true);
private: private:
/// Introduce the constraints associated with the given type variable /// Introduce the constraints associated with the given type variable
@@ -1720,6 +1810,10 @@ private:
unsigned idx2); unsigned idx2);
public: public:
/// Increase the score of the given kind for the current (partial) solution
/// along the.
void increaseScore(ScoreKind kind);
/// \brief Given a set of viable solutions, find the best /// \brief Given a set of viable solutions, find the best
/// solution. /// solution.
/// ///

View File

@@ -1504,7 +1504,8 @@ void Solution::dump(SourceManager *sm) const {
} }
void Solution::dump(SourceManager *sm, raw_ostream &out) const { void Solution::dump(SourceManager *sm, raw_ostream &out) const {
out << "Fixed score: " << getFixedScore() << "\n\n"; out << "Fixed score: " << FixedScore << "\n";
out << "Type variables:\n"; out << "Type variables:\n";
for (auto binding : typeBindings) { for (auto binding : typeBindings) {
out.indent(2); out.indent(2);
@@ -1552,6 +1553,7 @@ void ConstraintSystem::dump() {
} }
void ConstraintSystem::dump(raw_ostream &out) { void ConstraintSystem::dump(raw_ostream &out) {
out << "Score: " << CurrentScore << "\n";
out << "Type Variables:\n"; out << "Type Variables:\n";
for (auto tv : TypeVariables) { for (auto tv : TypeVariables) {
out.indent(2); out.indent(2);

View File

@@ -111,6 +111,7 @@ ProtocolDecl *TypeChecker::getLiteralProtocol(Expr *expr) {
if (isa<InterpolatedStringLiteralExpr>(expr)) if (isa<InterpolatedStringLiteralExpr>(expr))
return getProtocol(expr->getLoc(), return getProtocol(expr->getLoc(),
KnownProtocolKind::StringInterpolationConvertible); KnownProtocolKind::StringInterpolationConvertible);
if (auto E = dyn_cast<MagicIdentifierLiteralExpr>(expr)) { if (auto E = dyn_cast<MagicIdentifierLiteralExpr>(expr)) {
switch (E->getKind()) { switch (E->getKind()) {
case MagicIdentifierLiteralExpr::File: case MagicIdentifierLiteralExpr::File:
@@ -124,7 +125,7 @@ ProtocolDecl *TypeChecker::getLiteralProtocol(Expr *expr) {
} }
} }
llvm_unreachable("Unhandled literal kind"); return nullptr;
} }
Module *TypeChecker::getStdlibModule(const DeclContext *dc) { Module *TypeChecker::getStdlibModule(const DeclContext *dc) {