Files
swift-mirror/lib/Sema/CSOptimizer.cpp

848 lines
30 KiB
C++

//===--- CSOptimizer.cpp - Constraint Optimizer ---------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements disjunction and other constraint optimizations.
//
//===----------------------------------------------------------------------===//
#include "TypeChecker.h"
#include "swift/AST/ExistentialLayout.h"
#include "swift/AST/GenericSignature.h"
#include "swift/Basic/OptionSet.h"
#include "swift/Sema/ConstraintGraph.h"
#include "swift/Sema/ConstraintSystem.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TinyPtrVector.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/raw_ostream.h"
#include <cstddef>
#include <functional>
using namespace swift;
using namespace constraints;
namespace {
// TODO: both `isIntegerType` and `isFloatType` should be available on Type
// as `isStdlib{Integer, Float}Type`.

/// Determine whether the given type is one of the signed or unsigned
/// integer types recognized by the `Type::is{U}Int*` predicates.
static bool isIntegerType(Type type) {
  // Signed flavors: `Int`, `Int8`, `Int16`, `Int32`, `Int64`.
  if (type->isInt() || type->isInt8() || type->isInt16() || type->isInt32() ||
      type->isInt64())
    return true;

  // Unsigned flavors: `UInt`, `UInt8`, `UInt16`, `UInt32`, `UInt64`.
  return type->isUInt() || type->isUInt8() || type->isUInt16() ||
         type->isUInt32() || type->isUInt64();
}
/// Determine whether the given type is `Float`, `Double` or `Float80`.
static bool isFloatType(Type type) {
  if (type->isFloat())
    return true;

  if (type->isDouble())
    return true;

  return type->isFloat80();
}
/// Determine whether the given disjunction is an operator disjunction
/// this optimizer knows how to rank.
///
/// The operator name is taken from the first choice only; SIMD operators
/// are detected by inspecting every choice.
static bool isSupportedOperator(Constraint *disjunction) {
  if (!isOperatorDisjunction(disjunction))
    return false;

  auto overloads = disjunction->getNestedConstraints();
  auto operatorName =
      getOverloadChoiceDecl(overloads.front())->getBaseIdentifier();

  const bool isWellKnownOperator =
      operatorName.isArithmeticOperator() ||
      operatorName.isStandardComparisonOperator() || operatorName.is("^");
  if (isWellKnownOperator)
    return true;

  // Operators like &<<, &>>, &+, .== etc. are accepted as soon as
  // one of the choices is a SIMD operator.
  return llvm::any_of(overloads, [](Constraint *overload) {
    return isSIMDOperator(getOverloadChoiceDecl(overload));
  });
}
/// Determine whether the given constructor belongs to one of the
/// integer or floating-point types accepted by `isIntegerType` /
/// `isFloatType`.
///
/// `Int*`, `Float*` and `Double` initializers are supported since their
/// generic overloads are not too complicated.
static bool isSupportedSpecialConstructor(ConstructorDecl *ctor) {
  auto *implicitSelf = ctor->getImplicitSelfDecl();
  if (!implicitSelf)
    return false;

  auto selfTy = implicitSelf->getInterfaceType();
  if (!selfTy)
    return false;

  return isIntegerType(selfTy) || isFloatType(selfTy);
}
/// Determine whether the declaration is an operator whose base name is
/// one of the standard comparison operators.
static bool isStandardComparisonOperator(ValueDecl *decl) {
  if (!decl->isOperator())
    return false;

  return decl->getBaseIdentifier().isStandardComparisonOperator();
}
/// Determine whether the declaration is an operator whose base name is
/// an arithmetic operator.
static bool isArithmeticOperator(ValueDecl *decl) {
  if (!decl->isOperator())
    return false;

  return decl->getBaseIdentifier().isArithmeticOperator();
}
/// Determine whether the optimizer is able to rank the choices of the
/// given disjunction.
///
/// Supported are: operator disjunctions accepted by `isSupportedOperator`,
/// disjunctions whose first choice is a supported special constructor,
/// and disjunctions where every choice is a `BindOverload` of a declaration
/// with a (non-generic) `FunctionType` interface type.
static bool isSupportedDisjunction(Constraint *disjunction) {
  if (isSupportedOperator(disjunction))
    return true;

  auto overloads = disjunction->getNestedConstraints();

  // The first choice is enough to recognize the special initializers.
  auto *firstDecl = getOverloadChoiceDecl(overloads.front());
  if (auto *ctor = dyn_cast_or_null<ConstructorDecl>(firstDecl)) {
    if (isSupportedSpecialConstructor(ctor))
      return true;
  }

  // Non-operator disjunctions are supported only if they don't
  // have any generic choices.
  for (Constraint *overload : overloads) {
    if (overload->getKind() != ConstraintKind::BindOverload)
      return false;

    auto *decl = getOverloadChoiceDecl(overload);
    if (!decl || !decl->getInterfaceType()->is<FunctionType>())
      return false;
  }

  return true;
}
/// Find the applicable function constraint that applies the given
/// disjunction.
///
/// \returns The single `ApplicableFunction` constraint found in the
/// equivalence class of the type variable bound by the disjunction, provided
/// that type variable is the constraint's second type; `nullptr` otherwise.
NullablePtr<Constraint> getApplicableFnConstraint(ConstraintGraph &CG,
                                                  Constraint *disjunction) {
  auto firstChoiceTy = disjunction->getNestedConstraints()[0]->getFirstType();
  auto *overloadVar = firstChoiceTy->getAs<TypeVariableType>();
  if (!overloadVar)
    return nullptr;

  auto isApplicableFn = [](Constraint *candidate) {
    return candidate->getKind() == ConstraintKind::ApplicableFunction;
  };

  auto applications = CG.gatherConstraints(
      overloadVar, ConstraintGraph::GatheringKind::EquivalenceClass,
      isApplicableFn);

  // Only an unambiguous application is useful.
  if (applications.size() != 1)
    return nullptr;

  auto *application = applications.front();

  // An unapplied disjunction could appear as an argument to the applicable
  // function; such uses are of no interest here.
  if (!application->getSecondType()->isEqual(overloadVar))
    return nullptr;

  return application;
}
/// Invoke \p callback for every choice of the disjunction that could be
/// ranked.
///
/// A choice is skipped when it is disabled, is not a `BindOverload`, has no
/// declaration, is marked as disfavored, is unavailable at the disjunction's
/// locator, or when its effective overload type is missing or is not a
/// `FunctionType`.
void forEachDisjunctionChoice(
    ConstraintSystem &cs, Constraint *disjunction,
    llvm::function_ref<void(Constraint *, ValueDecl *decl, FunctionType *)>
        callback) {
  auto *locator = disjunction->getLocator();

  for (auto *nested : disjunction->getNestedConstraints()) {
    if (nested->isDisabled() ||
        nested->getKind() != ConstraintKind::BindOverload)
      continue;

    auto overload = nested->getOverloadChoice();
    auto *decl = overload.getDeclOrNull();
    if (!decl)
      continue;

    // Nothing can be done with a choice that is disfavored or unavailable.
    const bool isDisfavored =
        decl->getAttrs().hasAttribute<DisfavoredOverloadAttr>();
    if (isDisfavored || cs.isDeclUnavailable(decl, locator))
      continue;

    Type effectiveTy = cs.getEffectiveOverloadType(locator, overload,
                                                   /*allowMembers=*/true,
                                                   cs.DC);
    if (!effectiveTy)
      continue;

    auto *fnType = effectiveTy->getAs<FunctionType>();
    if (!fnType)
      continue;

    callback(nested, decl, fnType);
  }
}
/// Determine whether the locator of the given disjunction is directly at
/// an `OverloadedDeclRefExpr`.
static bool isOverloadedDeclRef(Constraint *disjunction) {
  assert(disjunction->getKind() == ConstraintKind::Disjunction);

  auto *locator = disjunction->getLocator();
  return locator->directlyAt<OverloadedDeclRefExpr>();
}
} // end anonymous namespace
/// Given a set of disjunctions, attempt to determine
/// favored choices in the current context.
///
/// Every supported disjunction that is applied through exactly one
/// applicable function constraint is scored by comparing the candidate
/// argument (and, for some operators, result) types against the parameter
/// and result types of each viable overload choice.
///
/// \param cs The constraint system the disjunctions belong to.
/// \param disjunctions The disjunctions to consider.
/// \param favorings Output map. For every disjunction that received a
/// non-zero score it is assigned the best score together with the choices
/// that achieved exactly that score. It is not written to when the function
/// bails out early or when no disjunction scored above zero.
///
/// \returns The only disjunction whose score equals the best overall score,
/// or \c nullptr when no disjunction scored above zero, when several
/// disjunctions share the best score, when an argument list could not be
/// found, or when an argument involves code completion.
static Constraint *determineBestChoicesInContext(
ConstraintSystem &cs, SmallVectorImpl<Constraint *> &disjunctions,
llvm::DenseMap<Constraint *,
std::pair<double, llvm::TinyPtrVector<Constraint *>>>
&favorings) {
double bestOverallScore = 0.0;
// Top scores across all of the disjunctions.
llvm::DenseMap<Constraint *, double> disjunctionScores;
// Choices that reached the top score of their disjunction.
llvm::DenseMap<Constraint *, llvm::TinyPtrVector<Constraint *>>
favoredChoicesPerDisjunction;
for (auto *disjunction : disjunctions) {
if (!isSupportedDisjunction(disjunction))
continue;
auto applicableFn =
getApplicableFnConstraint(cs.getConstraintGraph(), disjunction);
if (applicableFn.isNull())
continue;
// The first type of the applicable function constraint describes the
// arguments and the result of the call.
auto argFuncType =
applicableFn.get()->getFirstType()->getAs<FunctionType>();
auto argumentList = cs.getArgumentList(applicableFn.get()->getLocator());
// Note that a missing argument list abandons the whole analysis, not
// just the current disjunction.
if (!argumentList)
return nullptr;
for (const auto &argument : *argumentList) {
if (auto *expr = argument.getExpr()) {
// Directly `<#...#>` or has one inside.
if (isa<CodeCompletionExpr>(expr) ||
cs.containsIDEInspectionTarget(expr))
return nullptr;
}
}
// Argument types relabeled according to the argument list.
SmallVector<FunctionType::Param, 8> argsWithLabels;
{
argsWithLabels.append(argFuncType->getParams().begin(),
argFuncType->getParams().end());
FunctionType::relabelParams(argsWithLabels, argumentList);
}
// For each argument position, the set of types the argument could
// assume together with a flag that tells whether the type came from
// a literal default.
SmallVector<SmallVector<std::pair<Type, /*fromLiteral=*/bool>, 2>, 2>
candidateArgumentTypes;
candidateArgumentTypes.resize(argFuncType->getNumParams());
llvm::TinyPtrVector<Type> resultTypes;
for (unsigned i = 0, n = argFuncType->getNumParams(); i != n; ++i) {
const auto &param = argFuncType->getParams()[i];
auto argType = cs.simplifyType(param.getPlainType());
SmallVector<std::pair<Type, bool>, 2> types;
if (auto *typeVar = argType->getAs<TypeVariableType>()) {
// The argument is still a type variable, collect its current
// bindings and literal defaults as candidates.
auto bindingSet = cs.getBindingsFor(typeVar, /*finalize=*/true);
for (const auto &binding : bindingSet.Bindings) {
types.push_back({binding.BindingType, /*fromLiteral=*/false});
}
for (const auto &literal : bindingSet.Literals) {
if (literal.second.hasDefaultType()) {
// Add primary default type
types.push_back(
{literal.second.getDefaultType(), /*fromLiteral=*/true});
}
}
} else {
types.push_back({argType, /*fromLiteral=*/false});
}
candidateArgumentTypes[i].append(types);
}
// Candidate result types are gathered the same way, except that
// literal defaults are not considered.
auto resultType = cs.simplifyType(argFuncType->getResult());
if (auto *typeVar = resultType->getAs<TypeVariableType>()) {
auto bindingSet = cs.getBindingsFor(typeVar, /*finalize=*/true);
for (const auto &binding : bindingSet.Bindings) {
resultTypes.push_back(binding.BindingType);
}
} else {
resultTypes.push_back(resultType);
}
// Match arguments to the given overload choice.
auto matchArguments = [&](OverloadChoice choice, FunctionType *overloadType)
-> std::optional<MatchCallArgumentResult> {
auto *decl = choice.getDeclOrNull();
assert(decl);
auto hasAppliedSelf =
decl->hasCurriedSelf() &&
doesMemberRefApplyCurriedSelf(choice.getBaseType(), decl);
ParameterListInfo paramListInfo(overloadType->getParams(), decl,
hasAppliedSelf);
MatchCallArgumentListener listener;
return matchCallArguments(argsWithLabels, overloadType->getParams(),
paramListInfo,
argumentList->getFirstTrailingClosureIndex(),
/*allow fixes*/ false, listener, std::nullopt);
};
// Determine whether the candidate type is a subclass of the superclass
// type.
//
// Declared as `std::function` because the lambda refers to itself.
std::function<bool(Type, Type)> isSubclassOf = [&](Type candidateType,
Type superclassType) {
// Conversion from a concrete type to its existential value.
if (superclassType->isExistentialType() && !superclassType->isAny()) {
auto layout = superclassType->getExistentialLayout();
if (auto layoutConstraint = layout.getLayoutConstraint()) {
if (layoutConstraint->isClass() &&
!(candidateType->isClassExistentialType() ||
candidateType->mayHaveSuperclass()))
return false;
}
if (layout.explicitSuperclass &&
!isSubclassOf(candidateType, layout.explicitSuperclass))
return false;
// The candidate has to satisfy every protocol of the existential
// as well as the superclass bound of each protocol, if any.
return llvm::all_of(layout.getProtocols(), [&](ProtocolDecl *P) {
if (auto superclass = P->getSuperclassDecl()) {
if (!isSubclassOf(candidateType,
superclass->getDeclaredInterfaceType()))
return false;
}
return bool(TypeChecker::containsProtocol(candidateType, P,
/*allowMissing=*/false));
});
}
auto *subclassDecl = candidateType->getClassOrBoundGenericClass();
auto *superclassDecl = superclassType->getClassOrBoundGenericClass();
if (!(subclassDecl && superclassDecl))
return false;
return superclassDecl->isSuperclassOf(subclassDecl);
};
// Options that adjust how a candidate is scored against a parameter:
// - OnParam: the match is performed at a parameter position, which
//   enables the array-to-pointer check.
// - Literal: the candidate is a default type of a literal.
// - ExactOnly: only exact equality produces a non-zero score.
enum class MatchFlag {
OnParam = 0x01,
Literal = 0x02,
ExactOnly = 0x04,
};
using MatchOptions = OptionSet<MatchFlag>;
// Perform a limited set of checks to determine whether the candidate
// could possibly match the parameter type:
//
// - Equality
// - Protocol conformance(s)
// - Optional injection
// - Superclass conversion
// - Array-to-pointer conversion
// - Value to existential conversion
// - Exact match on top-level types
//
// The result is a score between 0 (no match) and 1 (best match).
std::function<double(GenericSignature, Type, Type, MatchOptions)>
scoreCandidateMatch = [&](GenericSignature genericSig,
Type candidateType, Type paramType,
MatchOptions options) -> double {
if (options.contains(MatchFlag::ExactOnly))
return candidateType->isEqual(paramType) ? 1 : 0;
// Exact match between candidate and parameter types.
//
// Literal defaults score lower than established types, and never
// match through any of the conversions below.
if (candidateType->isEqual(paramType))
return options.contains(MatchFlag::Literal) ? 0.3 : 1;
if (options.contains(MatchFlag::Literal))
return 0;
// Check whether match would require optional injection.
{
SmallVector<Type, 2> candidateOptionals;
SmallVector<Type, 2> paramOptionals;
candidateType =
candidateType->lookThroughAllOptionalTypes(candidateOptionals);
paramType = paramType->lookThroughAllOptionalTypes(paramOptionals);
if (!candidateOptionals.empty() || !paramOptionals.empty()) {
// The parameter has at least as many levels of optionality as
// the candidate, score the underlying types instead.
if (paramOptionals.size() >= candidateOptionals.size()) {
return scoreCandidateMatch(genericSig, candidateType, paramType,
options);
}
// Optionality mismatch.
return 0;
}
}
// Candidate could be converted to a superclass.
if (isSubclassOf(candidateType, paramType))
return 1;
// Possible Array<T> -> Unsafe*Pointer conversion.
if (options.contains(MatchFlag::OnParam)) {
if (candidateType->isArrayType() &&
paramType->getAnyPointerElementType())
return 1;
}
// If both argument and parameter are tuples of the same arity,
// it's a match.
{
if (auto *candidateTuple = candidateType->getAs<TupleType>()) {
auto *paramTuple = paramType->getAs<TupleType>();
if (paramTuple &&
candidateTuple->getNumElements() == paramTuple->getNumElements())
return 1;
}
}
// Check protocol requirement(s) if this parameter is a
// generic parameter type.
if (genericSig && paramType->isTypeParameter()) {
auto protocolRequirements = genericSig->getRequiredProtocols(paramType);
// It's a generic parameter or dependent member which might
// be connected via same-type constraints to other generic
// parameters or dependent member but we cannot check that here,
// so let's add a tiny score just to acknowledge that it could
// possibly match.
if (protocolRequirements.empty())
return 0.01;
if (llvm::all_of(protocolRequirements, [&](ProtocolDecl *protocol) {
return bool(cs.lookupConformance(candidateType, protocol));
})) {
// A generic parameter whose declaration is an opaque type is
// scored as high as an exact match.
if (auto *GP = paramType->getAs<GenericTypeParamType>()) {
auto *paramDecl = GP->getDecl();
if (paramDecl && paramDecl->isOpaqueType())
return 1.0;
}
return 0.7;
}
}
// Parameter is generic, let's check whether top-level
// types match i.e. Array<Element> as a parameter.
//
// This is slightly better than all of the conformances matching
// because the parameter is concrete and could split the graph.
if (paramType->hasTypeParameter()) {
auto *candidateDecl = candidateType->getAnyNominal();
auto *paramDecl = paramType->getAnyNominal();
if (candidateDecl && paramDecl && candidateDecl == paramDecl)
return 0.8;
}
return 0;
};
// The choice with the best score.
double bestScore = 0.0;
SmallVector<std::pair<Constraint *, double>, 2> favoredChoices;
bool isOverloadedDeclRefDisjunction = isOverloadedDeclRef(disjunction);
forEachDisjunctionChoice(
cs, disjunction,
[&](Constraint *choice, ValueDecl *decl, FunctionType *overloadType) {
GenericSignature genericSig;
{
if (auto *GF = dyn_cast<AbstractFunctionDecl>(decl)) {
genericSig = GF->getGenericSignature();
} else if (auto *SD = dyn_cast<SubscriptDecl>(decl)) {
genericSig = SD->getGenericSignature();
}
}
// Choices that the arguments cannot be matched to are ignored.
auto matchings =
matchArguments(choice->getOverloadChoice(), overloadType);
if (!matchings)
return;
bool favorExactMatchesOnly = false;
// Preserves old behavior where for unary calls to members
// the solver would not consider choices that didn't match on
// the number of parameters (regardless of defaults) and only
// exact matches were favored.
if (!isOverloadedDeclRefDisjunction && argumentList->size() == 1) {
// Old behavior completely disregarded the fact that some of
// the parameters could be defaulted.
if (overloadType->getNumParams() != 1)
return;
favorExactMatchesOnly = true;
}
double score = 0.0;
unsigned numDefaulted = 0;
for (unsigned paramIdx = 0, n = overloadType->getNumParams();
paramIdx != n; ++paramIdx) {
const auto &param = overloadType->getParams()[paramIdx];
auto argIndices = matchings->parameterBindings[paramIdx];
switch (argIndices.size()) {
case 0:
// Current parameter is defaulted, mark and continue.
++numDefaulted;
continue;
case 1:
// One-to-one match between argument and parameter.
break;
default:
// Cannot deal with multiple possible matchings at the moment.
return;
}
auto argIdx = argIndices.front();
// Looks like there is nothing known about the argument.
if (candidateArgumentTypes[argIdx].empty())
continue;
const auto paramFlags = param.getParameterFlags();
// If parameter is variadic we cannot compare because we don't know
// real arity.
if (paramFlags.isVariadic())
continue;
auto paramType = param.getPlainType();
// FIXME: Let's skip matching function types for now
// because they have special rules for e.g. Concurrency
// (around @Sendable) and @convention(c).
if (paramType->is<FunctionType>())
continue;
// The idea here is to match the parameter type against
// all of the argument candidate types and pick the best
// match (i.e. exact equality one).
//
// If none of the candidates match exactly and they are
// all bound concrete types, we consider this is mismatch
// at this parameter position and remove the overload choice
// from consideration.
double bestCandidateScore = 0;
// One bit per candidate, set when the candidate is an established
// mismatch for this parameter.
llvm::BitVector mismatches(candidateArgumentTypes[argIdx].size());
for (unsigned candidateIdx :
indices(candidateArgumentTypes[argIdx])) {
// If one of the candidates matched exactly there is no reason
// to continue checking.
if (bestCandidateScore == 1)
break;
Type candidateType;
bool isLiteralDefault;
std::tie(candidateType, isLiteralDefault) =
candidateArgumentTypes[argIdx][candidateIdx];
// `inout` parameter accepts only l-value argument.
if (paramFlags.isInOut() && !candidateType->is<LValueType>()) {
mismatches.set(candidateIdx);
continue;
}
// The specifier only matters for `inout` check.
candidateType = candidateType->getWithoutSpecifierType();
MatchOptions options(MatchFlag::OnParam);
if (isLiteralDefault)
options |= MatchFlag::Literal;
if (favorExactMatchesOnly)
options |= MatchFlag::ExactOnly;
auto score = scoreCandidateMatch(genericSig, candidateType,
paramType, options);
if (score > 0) {
bestCandidateScore = std::max(bestCandidateScore, score);
continue;
}
// Only established arguments could be considered mismatches,
// literal default types should be regarded as holes if they
// didn't match.
if (!isLiteralDefault && !candidateType->hasTypeVariable())
mismatches.set(candidateIdx);
}
// If none of the candidates for this parameter matched, let's
// drop this overload from any further consideration.
if (mismatches.all())
return;
score += bestCandidateScore;
}
// An overload where all of the parameters are defaulted
// that's called without arguments.
if (numDefaulted == overloadType->getNumParams())
return;
// Average the score to avoid disfavoring disjunctions with fewer
// parameters.
score /= (overloadType->getNumParams() - numDefaulted);
// If one of the result types matches exactly, that's a good
// indication that overload choice should be favored.
//
// If nothing is known about the arguments it's only safe to
// check result for operators (except to standard comparison
// ones that all have the same result type), regular
// functions/methods and especially initializers could end up
// with a lot of favored overloads based on the result type alone.
if (decl->isOperator() && !isStandardComparisonOperator(decl)) {
if (llvm::any_of(resultTypes, [&](const Type candidateResultTy) {
return scoreCandidateMatch(genericSig,
overloadType->getResult(),
candidateResultTy,
/*options=*/{}) > 0;
})) {
score += 1.0;
}
}
if (score > 0) {
// Nudge the score slightly to prefer concrete homogeneous
// arithmetic operators.
//
// This is an opportunistic optimization based on the operator
// use patterns where homogeneous operators are the most
// heavily used ones.
if (isArithmeticOperator(decl) &&
overloadType->getNumParams() == 2) {
auto resultTy = overloadType->getResult();
if (!resultTy->hasTypeParameter() &&
llvm::all_of(overloadType->getParams(),
[&resultTy](const auto &param) {
return param.getPlainType()->isEqual(resultTy);
}))
score += 0.001;
}
favoredChoices.push_back({choice, score});
bestScore = std::max(bestScore, score);
}
});
if (cs.isDebugMode()) {
PrintOptions PO;
PO.PrintTypesForDebugging = true;
llvm::errs().indent(cs.solverState->getCurrentIndent())
<< "<<< Disjunction "
<< disjunction->getNestedConstraints()[0]->getFirstType()->getString(
PO)
<< " with score " << bestScore << "\n";
}
// No matching overload choices to favor.
if (bestScore == 0.0)
continue;
bestOverallScore = std::max(bestOverallScore, bestScore);
disjunctionScores[disjunction] = bestScore;
// Only the choices that reached the best score are favored.
for (const auto &choice : favoredChoices) {
if (choice.second == bestScore)
favoredChoicesPerDisjunction[disjunction].push_back(choice.first);
}
}
if (cs.isDebugMode() && bestOverallScore > 0) {
PrintOptions PO;
PO.PrintTypesForDebugging = true;
auto getLogger = [&](unsigned extraIndent = 0) -> llvm::raw_ostream & {
return llvm::errs().indent(cs.solverState->getCurrentIndent() +
extraIndent);
};
{
auto &log = getLogger();
log << "(Optimizing disjunctions: [";
interleave(
disjunctions,
[&](const auto *disjunction) {
log << disjunction->getNestedConstraints()[0]
->getFirstType()
->getString(PO);
},
[&]() { log << ", "; });
log << "]\n";
}
getLogger(/*extraIndent=*/4)
<< "Best overall score = " << bestOverallScore << '\n';
for (const auto &entry : disjunctionScores) {
getLogger(/*extraIndent=*/4)
<< "[Disjunction '"
<< entry.first->getNestedConstraints()[0]->getFirstType()->getString(
PO)
<< "' with score = " << entry.second << '\n';
for (const auto *choice : favoredChoicesPerDisjunction[entry.first]) {
auto &log = getLogger(/*extraIndent=*/6);
log << "- ";
choice->print(log, &cs.getASTContext().SourceMgr);
log << '\n';
}
getLogger(/*extraIndent=*/4) << "]\n";
}
getLogger() << ")\n";
}
// Nothing has been scored, leave `favorings` untouched.
if (bestOverallScore == 0)
return nullptr;
// Publish the score and the favored choices of every scored disjunction.
for (auto &entry : disjunctionScores) {
TinyPtrVector<Constraint *> favoredChoices;
for (auto *choice : favoredChoicesPerDisjunction[entry.first])
favoredChoices.push_back(choice);
favorings[entry.first] = std::make_pair(entry.second, favoredChoices);
}
// The best disjunction is returned only if it is unambiguous.
//
// Note that `operator[]` default-inserts a score of zero for
// disjunctions that have not been scored; `bestOverallScore` is
// known to be non-zero at this point.
Constraint *bestDisjunction = nullptr;
for (auto *disjunction : disjunctions) {
if (disjunctionScores[disjunction] != bestOverallScore)
continue;
if (!bestDisjunction)
bestDisjunction = disjunction;
else // Multiple disjunctions with the same score.
return nullptr;
}
return bestDisjunction;
}
// Look for a disjunction of bind constraints whose options all bind
// the same type variable.
//
// A disjunction whose bound type variable also appears as the
// right-hand side of a conversion constraint is preferred and returned
// right away, since having a concrete type to convert to can make it
// possible to split the constraint system into multiple ones. Otherwise
// the first suitable binding disjunction (if any) is returned.
static Constraint *
selectBestBindingDisjunction(ConstraintSystem &cs,
                             SmallVectorImpl<Constraint *> &disjunctions) {
  if (disjunctions.empty())
    return nullptr;

  // Simplify the type and look through l-values down to a type variable.
  auto getAsTypeVar = [&cs](Type type) {
    return cs.simplifyType(type)->getRValueType()->getAs<TypeVariableType>();
  };

  auto isConversion = [](Constraint *candidate) {
    return candidate->getKind() == ConstraintKind::Conversion;
  };

  Constraint *fallback = nullptr;

  for (auto *disjunction : disjunctions) {
    auto options = disjunction->getNestedConstraints();
    assert(!options.empty());

    // The first option is representative of the whole disjunction
    // because all of the choices (of a bind overload set) should have
    // the same left-hand side.
    auto *representative = options.front();
    if (representative->getKind() != ConstraintKind::Bind)
      continue;

    // Only simple type variable bindings are of interest, not
    // bindings like: ($T1) -> $T2 bind String -> Int
    auto *boundVar = getAsTypeVar(representative->getFirstType());
    if (!boundVar)
      continue;

    if (!fallback)
      fallback = disjunction;

    auto conversions = cs.getConstraintGraph().gatherConstraints(
        boundVar, ConstraintGraph::GatheringKind::EquivalenceClass,
        isConversion);

    const bool isConversionTarget =
        llvm::any_of(conversions, [&](Constraint *conversion) {
          return boundVar == getAsTypeVar(conversion->getSecondType());
        });

    if (isConversionTarget)
      return disjunction;
  }

  // If there were any binding disjunctions, return the first of
  // those. These ensure that types are bound earlier than the elements
  // of other disjunctions are attempted, which can often mean the solver
  // fails faster.
  return fallback;
}
std::optional<std::pair<Constraint *, llvm::TinyPtrVector<Constraint *>>>
ConstraintSystem::selectDisjunction() {
SmallVector<Constraint *, 4> disjunctions;
collectDisjunctions(disjunctions);
if (disjunctions.empty())
return std::nullopt;
if (auto *disjunction = selectBestBindingDisjunction(*this, disjunctions))
return std::make_pair(disjunction, llvm::TinyPtrVector<Constraint *>());
llvm::DenseMap<Constraint *,
std::pair</*bestScore=*/double, llvm::TinyPtrVector<Constraint *>>>
favorings;
if (auto *bestDisjunction =
determineBestChoicesInContext(*this, disjunctions, favorings))
return std::make_pair(bestDisjunction, favorings[bestDisjunction].second);
// Pick the disjunction with the smallest number of favored, then active
// choices.
auto bestDisjunction = std::min_element(
disjunctions.begin(), disjunctions.end(),
[&](Constraint *first, Constraint *second) -> bool {
unsigned firstActive = first->countActiveNestedConstraints();
unsigned secondActive = second->countActiveNestedConstraints();
auto &[firstScore, firstFavoredChoices] = favorings[first];
auto &[secondScore, secondFavoredChoices] = favorings[second];
// Rank based on scores only if both disjunctions are supported.
if (isSupportedDisjunction(first) && isSupportedDisjunction(second)) {
// If both disjunctions have the same score they should be ranked
// based on number of favored/active choices.
if (firstScore != secondScore)
return firstScore > secondScore;
}
unsigned numFirstFavored = firstFavoredChoices.size();
unsigned numSecondFavored = secondFavoredChoices.size();
if (numFirstFavored == numSecondFavored) {
if (firstActive != secondActive)
return firstActive < secondActive;
}
numFirstFavored = numFirstFavored ? numFirstFavored : firstActive;
numSecondFavored = numSecondFavored ? numSecondFavored : secondActive;
return numFirstFavored < numSecondFavored;
});
if (bestDisjunction != disjunctions.end())
return std::make_pair(*bestDisjunction, favorings[*bestDisjunction].second);
return std::nullopt;
}