Merge pull request #23088 from DougGregor/solver-disjunction-favoring

[Constraint solver] Generalize disjunction favoring
This commit is contained in:
Doug Gregor
2019-03-05 14:41:56 -08:00
committed by GitHub
7 changed files with 157 additions and 163 deletions

View File

@@ -595,107 +595,74 @@ namespace {
/// of the overload set and call arguments.
///
/// \param expr The application.
/// \param isFavored Determine whether the given overload is favored.
/// \param isFavored Determine whether the given overload is favored, passing
/// it the "effective" overload type when it's being called.
/// \param mustConsider If provided, a function to detect the presence of
/// overloads which inhibit any overload from being favored.
void favorCallOverloads(ApplyExpr *expr,
ConstraintSystem &CS,
llvm::function_ref<bool(ValueDecl *)> isFavored,
llvm::function_ref<bool(ValueDecl *, Type)> isFavored,
std::function<bool(ValueDecl *)>
mustConsider = nullptr) {
// Find the type variable associated with the function, if any.
auto tyvarType = CS.getType(expr->getFn())->getAs<TypeVariableType>();
if (!tyvarType)
if (!tyvarType || CS.getFixedType(tyvarType))
return;
// This type variable is only currently associated with the function
// being applied, and the only constraint attached to it should
// be the disjunction constraint for the overload group.
auto &CG = CS.getConstraintGraph();
llvm::SetVector<Constraint *> disjunctions;
CG.gatherConstraints(tyvarType, disjunctions,
ConstraintGraph::GatheringKind::EquivalenceClass,
[](Constraint *constraint) -> bool {
return constraint->getKind() ==
ConstraintKind::Disjunction;
});
if (disjunctions.empty())
auto disjunction = CS.getUnboundBindOverloadDisjunction(tyvarType);
if (!disjunction)
return;
// Look for the disjunction that binds the overload set.
for (auto *disjunction : disjunctions) {
auto oldConstraints = disjunction->getNestedConstraints();
auto csLoc = CS.getConstraintLocator(expr->getFn());
// Only replace the disjunctive overload constraint.
if (oldConstraints[0]->getKind() != ConstraintKind::BindOverload) {
// Find the favored constraints and mark them.
SmallVector<Constraint *, 4> newlyFavoredConstraints;
unsigned numFavoredConstraints = 0;
Constraint *firstFavored = nullptr;
for (auto constraint : disjunction->getNestedConstraints()) {
if (!constraint->getOverloadChoice().isDecl())
continue;
auto decl = constraint->getOverloadChoice().getDecl();
if (mustConsider && mustConsider(decl)) {
// Roll back any constraints we favored.
for (auto favored : newlyFavoredConstraints)
favored->setFavored(false);
return;
}
if (mustConsider) {
bool hasMustConsider = false;
for (auto oldConstraint : oldConstraints) {
auto overloadChoice = oldConstraint->getOverloadChoice();
if (overloadChoice.isDecl() &&
mustConsider(overloadChoice.getDecl()))
hasMustConsider = true;
}
if (hasMustConsider) {
continue;
}
Type overloadType =
CS.getEffectiveOverloadType(constraint->getOverloadChoice(),
/*allowMembers=*/true, CS.DC);
if (!overloadType)
continue;
if (!decl->getAttrs().isUnavailable(CS.getASTContext()) &&
isFavored(decl, overloadType)) {
// If we might need to roll back the favored constraints, keep
// track of those we are favoring.
if (mustConsider && !constraint->isFavored())
newlyFavoredConstraints.push_back(constraint);
constraint->setFavored();
++numFavoredConstraints;
if (!firstFavored)
firstFavored = constraint;
}
}
// Copy over the existing bindings, dividing the constraints up
// into "favored" and non-favored lists.
SmallVector<Constraint *, 4> favoredConstraints;
SmallVector<Constraint *, 4> fallbackConstraints;
for (auto oldConstraint : oldConstraints) {
if (!oldConstraint->getOverloadChoice().isDecl())
continue;
auto decl = oldConstraint->getOverloadChoice().getDecl();
if (!decl->getAttrs().isUnavailable(CS.getASTContext()) &&
isFavored(decl))
favoredConstraints.push_back(oldConstraint);
else
fallbackConstraints.push_back(oldConstraint);
}
// If we did not find any favored constraints, we're done.
if (favoredConstraints.empty()) break;
if (favoredConstraints.size() == 1) {
auto overloadChoice = favoredConstraints[0]->getOverloadChoice();
auto overloadType = overloadChoice.getDecl()->getInterfaceType();
auto resultType = overloadType->getAs<AnyFunctionType>()->getResult();
// If there was one favored constraint, set the favored type based on its
// result type.
if (numFavoredConstraints == 1) {
auto overloadChoice = firstFavored->getOverloadChoice();
auto overloadType =
CS.getEffectiveOverloadType(overloadChoice, /*allowMembers=*/true,
CS.DC);
auto resultType = overloadType->castTo<AnyFunctionType>()->getResult();
if (!resultType->hasTypeParameter())
CS.setFavoredType(expr, resultType.getPointer());
}
// Remove the original constraint from the inactive constraint
// list and add the new one.
CS.removeInactiveConstraint(disjunction);
// Create the disjunction of favored constraints.
auto favoredConstraintsDisjunction =
Constraint::createDisjunction(CS,
favoredConstraints,
csLoc);
favoredConstraintsDisjunction->setFavored();
llvm::SmallVector<Constraint *, 2> aggregateConstraints;
aggregateConstraints.push_back(favoredConstraintsDisjunction);
if (!fallbackConstraints.empty()) {
// Find the disjunction of fallback constraints. If any
// constraints were added here, create a new disjunction.
Constraint *fallbackConstraintsDisjunction =
Constraint::createDisjunction(CS, fallbackConstraints, csLoc);
aggregateConstraints.push_back(fallbackConstraintsDisjunction);
}
CS.addDisjunctionConstraint(aggregateConstraints, csLoc);
break;
}
}
@@ -738,18 +705,11 @@ namespace {
void favorMatchingUnaryOperators(ApplyExpr *expr,
ConstraintSystem &CS) {
// Determine whether the given declaration is favored.
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
auto valueTy = value->getInterfaceType();
auto fnTy = valueTy->getAs<AnyFunctionType>();
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
auto fnTy = type->getAs<AnyFunctionType>();
if (!fnTy)
return false;
// Figure out the parameter type.
if (value->getDeclContext()->isTypeContext()) {
fnTy = fnTy->getResult()->castTo<AnyFunctionType>();
}
Type paramTy = FunctionType::composeInput(CS.getASTContext(),
fnTy->getParams(), false);
auto resultTy = fnTy->getResult();
@@ -791,10 +751,8 @@ namespace {
}
// Determine whether the given declaration is favored.
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
auto valueTy = value->getInterfaceType();
if (!valueTy->is<AnyFunctionType>())
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
if (!type->is<AnyFunctionType>())
return false;
auto paramCount = getParamCount(value);
@@ -809,23 +767,11 @@ namespace {
if (auto favoredTy = CS.getFavoredType(expr->getArg())) {
// Determine whether the given declaration is favored.
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
auto valueTy = value->getInterfaceType();
auto fnTy = valueTy->getAs<AnyFunctionType>();
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
auto fnTy = type->getAs<AnyFunctionType>();
if (!fnTy)
return false;
// Figure out the parameter type, accounting for the implicit 'self' if
// necessary.
if (auto *FD = dyn_cast<AbstractFunctionDecl>(value)) {
if (FD->hasImplicitSelfDecl()) {
if (auto resFnTy = fnTy->getResult()->getAs<AnyFunctionType>()) {
fnTy = resFnTy;
}
}
}
auto paramTy =
AnyFunctionType::composeInput(CS.getASTContext(), fnTy->getParams(),
/*canonicalVararg*/ false);
@@ -884,10 +830,8 @@ namespace {
};
// Determine whether the given declaration is favored.
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
auto valueTy = value->getInterfaceType();
auto fnTy = valueTy->getAs<AnyFunctionType>();
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
auto fnTy = type->getAs<AnyFunctionType>();
if (!fnTy)
return false;
@@ -913,11 +857,6 @@ namespace {
}
}
// Figure out the parameter type.
if (value->getDeclContext()->isTypeContext()) {
fnTy = fnTy->getResult()->castTo<AnyFunctionType>();
}
auto params = fnTy->getParams();
if (params.size() != 2)
return false;

View File

@@ -4538,7 +4538,6 @@ ConstraintSystem::simplifyEscapableFunctionOfConstraint(
return SolutionKind::Unsolved;
};
type2 = getFixedTypeRecursive(type2, flags, /*wantRValue=*/true);
if (auto fn2 = type2->getAs<FunctionType>()) {
// Solve forward by binding the other type variable to the escapable
@@ -5037,6 +5036,33 @@ retry_after_fail:
break;
}
// Collect the active overload choices.
SmallVector<OverloadChoice, 4> choices;
for (auto constraint : disjunction->getNestedConstraints()) {
if (constraint->isDisabled())
continue;
choices.push_back(constraint->getOverloadChoice());
}
// If we can favor one generic result over another, do so.
if (auto favoredChoice = tryOptimizeGenericDisjunction(choices)) {
unsigned favoredIndex = favoredChoice - choices.data();
for (auto constraint : disjunction->getNestedConstraints()) {
if (constraint->isDisabled())
continue;
if (favoredIndex == 0) {
if (solverState)
solverState->favorConstraint(constraint);
else
constraint->setFavored();
break;
} else {
--favoredIndex;
}
}
}
// If there was a constraint that we couldn't reason about, don't use the
// results of any common-type computations.

View File

@@ -424,6 +424,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
numCheckedConformances = cs.CheckedConformances.size();
numMissingMembers = cs.MissingMembers.size();
numDisabledConstraints = cs.solverState->getNumDisabledConstraints();
numFavoredConstraints = cs.solverState->getNumFavoredConstraints();
PreviousScore = cs.CurrentScore;
@@ -1909,21 +1910,6 @@ void ConstraintSystem::partitionForDesignatedTypes(
void ConstraintSystem::partitionDisjunction(
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
SmallVectorImpl<unsigned> &PartitionBeginning) {
// Maintain the original ordering, and make a single partition of
// disjunction choices.
auto originalOrdering = [&]() {
for (unsigned long i = 0, e = Choices.size(); i != e; ++i)
Ordering.push_back(i);
PartitionBeginning.push_back(0);
};
if (!TC.getLangOpts().SolverEnableOperatorDesignatedTypes ||
!isOperatorBindOverload(Choices[0])) {
originalOrdering();
return;
}
SmallSet<Constraint *, 16> taken;
// Local function used to iterate over the untaken choices from the
@@ -1937,33 +1923,45 @@ void ConstraintSystem::partitionDisjunction(
if (taken.count(constraint))
continue;
assert(constraint->getKind() == ConstraintKind::BindOverload);
assert(constraint->getOverloadChoice().isDecl());
if (fn(index, constraint))
taken.insert(constraint);
}
};
// First collect some things that we'll generally put near the end
// of the partitioning.
// First collect some things that we'll generally put near the beginning or
// end of the partitioning.
SmallVector<unsigned, 4> favored;
SmallVector<unsigned, 4> disabled;
SmallVector<unsigned, 4> unavailable;
// First collect disabled constraints.
// First collect disabled and favored constraints.
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
if (!constraint->isDisabled())
return false;
disabled.push_back(index);
return true;
if (constraint->isDisabled()) {
disabled.push_back(index);
return true;
}
if (constraint->isFavored()) {
favored.push_back(index);
return true;
}
return false;
});
// Then unavailable constraints if we're skipping them.
if (!shouldAttemptFixes()) {
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
if (constraint->getKind() != ConstraintKind::BindOverload)
return false;
if (!constraint->getOverloadChoice().isDecl())
return false;
auto *decl = constraint->getOverloadChoice().getDecl();
auto *funcDecl = cast<FuncDecl>(decl);
auto *funcDecl = dyn_cast<FuncDecl>(decl);
if (!funcDecl)
return false;
if (!funcDecl->getAttrs().isUnavailable(getASTContext()))
return false;
@@ -1983,7 +1981,10 @@ void ConstraintSystem::partitionDisjunction(
}
};
partitionForDesignatedTypes(Choices, forEachChoice, appendPartition);
if (TC.getLangOpts().SolverEnableOperatorDesignatedTypes &&
isOperatorBindOverload(Choices[0])) {
partitionForDesignatedTypes(Choices, forEachChoice, appendPartition);
}
SmallVector<unsigned, 4> everythingElse;
// Gather the remaining options.
@@ -1991,6 +1992,7 @@ void ConstraintSystem::partitionDisjunction(
everythingElse.push_back(index);
return true;
});
appendPartition(favored);
appendPartition(everythingElse);
// Now create the remaining partitions from what we previously collected.

View File

@@ -444,7 +444,7 @@ public:
}
/// Mark or retrieve whether this constraint should be favored in the system.
void setFavored() { IsFavored = true; }
void setFavored(bool favored = true) { IsFavored = favored; }
bool isFavored() const { return IsFavored; }
/// Whether the solver should remember which choice was taken for

View File

@@ -1373,17 +1373,16 @@ ConstraintSystem::getTypeOfMemberReference(
// Performance hack: if there are two generic overloads, and one is
// more specialized than the other, prefer the more-specialized one.
static void tryOptimizeGenericDisjunction(ConstraintSystem &cs,
ArrayRef<OverloadChoice> choices,
OverloadChoice *&favoredChoice) {
if (favoredChoice || choices.size() != 2)
return;
OverloadChoice *ConstraintSystem::tryOptimizeGenericDisjunction(
ArrayRef<OverloadChoice> choices) {
if (choices.size() != 2)
return nullptr;
const auto &choiceA = choices[0];
const auto &choiceB = choices[1];
if (!choiceA.isDecl() || !choiceB.isDecl())
return;
return nullptr;
auto isViable = [](ValueDecl *decl) -> bool {
assert(decl);
@@ -1410,22 +1409,17 @@ static void tryOptimizeGenericDisjunction(ConstraintSystem &cs,
auto *declB = choiceB.getDecl();
if (!isViable(declA) || !isViable(declB))
return;
auto &TC = cs.TC;
auto *DC = cs.DC;
return nullptr;
switch (TC.compareDeclarations(DC, declA, declB)) {
case Comparison::Better:
favoredChoice = const_cast<OverloadChoice *>(&choiceA);
break;
return const_cast<OverloadChoice *>(&choiceA);
case Comparison::Worse:
favoredChoice = const_cast<OverloadChoice *>(&choiceB);
break;
return const_cast<OverloadChoice *>(&choiceB);
case Comparison::Unordered:
break;
return nullptr;
}
}
@@ -1582,7 +1576,8 @@ void ConstraintSystem::addOverloadSet(Type boundType,
return;
}
tryOptimizeGenericDisjunction(*this, choices, favoredChoice);
if (!favoredChoice)
favoredChoice = tryOptimizeGenericDisjunction(choices);
SmallVector<OverloadChoice, 4> scratchChoices;
choices = partitionSIMDOperators(choices, scratchChoices);

View File

@@ -1317,6 +1317,15 @@ private:
disabledConstraints.erase(
disabledConstraints.begin() + scope->numDisabledConstraints,
disabledConstraints.end());
for (unsigned constraintIdx :
range(scope->numFavoredConstraints, favoredConstraints.size())) {
if (favoredConstraints[constraintIdx]->isFavored())
favoredConstraints[constraintIdx]->setFavored(false);
}
favoredConstraints.erase(
favoredConstraints.begin() + scope->numFavoredConstraints,
favoredConstraints.end());
}
/// Check whether constraint system is allowed to form solutions
@@ -1336,6 +1345,19 @@ private:
disabledConstraints.push_back(constraint);
}
unsigned getNumFavoredConstraints() const {
return favoredConstraints.size();
}
/// Favor the given constraint; this change will be rolled back
/// when we exit the current solver scope.
void favorConstraint(Constraint *constraint) {
if (!constraint->isFavored()) {
constraint->setFavored();
favoredConstraints.push_back(constraint);
}
}
private:
/// The list of constraints that have been retired along the
/// current path, this list is used in LIFO fashion when constraints
@@ -1358,6 +1380,7 @@ private:
std::tuple<SolverScope *, ConstraintList::iterator, unsigned>, 4> scopes;
SmallVector<Constraint *, 4> disabledConstraints;
SmallVector<Constraint *, 4> favoredConstraints;
};
class CacheExprTypes : public ASTWalker {
@@ -1518,6 +1541,8 @@ public:
unsigned numDisabledConstraints;
unsigned numFavoredConstraints;
/// The previous score.
Score PreviousScore;
@@ -3122,12 +3147,14 @@ private:
bool restoreOnFail,
llvm::function_ref<bool(Constraint *)> pred);
public:
// Given a type variable, attempt to find the disjunction of
// bind overloads associated with it. This may return null in cases where
// the disjunction has either not been created or binds the type variable
// in some manner other than by binding overloads.
Constraint *getUnboundBindOverloadDisjunction(TypeVariableType *tyvar);
private:
/// Given a type variable that might represent an overload set, retrieve
///
/// \returns the set of overload choices to which this type variable
@@ -3160,6 +3187,11 @@ private:
Constraint *selectApplyDisjunction();
/// Look at the set of overload choices to determine if there is a best
/// generic overload to favor.
OverloadChoice *tryOptimizeGenericDisjunction(
ArrayRef<OverloadChoice> choices);
/// Solve the system of constraints generated from provided expression.
///
/// \param expr The expression to generate constraints from.
@@ -3751,7 +3783,7 @@ private:
/// easy to work with disjunction and encapsulates
/// some other important information such as locator.
class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
// The disjunciton choices that this producer will iterate through.
// The disjunction choices that this producer will iterate through.
ArrayRef<Constraint *> Choices;
// The ordering of disjunction choices. We index into Choices

View File

@@ -3,5 +3,5 @@
// SR-139:
// Infinite recursion parsing bitwise operators
let x = UInt32(0x1FF)&0xFF << 24 | UInt32(0x1FF)&0xFF << 16 | UInt32(0x1FF)&0xFF << 8 | (UInt32(0x1FF)&0xFF); // expected-error {{reasonable time}}
let x = UInt32(0x1FF)&0xFF << 24 | UInt32(0x1FF)&0xFF << 16 | UInt32(0x1FF)&0xFF << 8 | (UInt32(0x1FF)&0xFF)