//===--- CSOptimizer.cpp - Constraint Optimizer ---------------------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2025 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements disjunction and other constraint optimizations. // //===----------------------------------------------------------------------===// #include "TypeChecker.h" #include "OpenedExistentials.h" #include "swift/AST/ConformanceLookup.h" #include "swift/AST/ExistentialLayout.h" #include "swift/AST/GenericSignature.h" #include "swift/Basic/Defer.h" #include "swift/Basic/OptionSet.h" #include "swift/Sema/ConstraintGraph.h" #include "swift/Sema/ConstraintSystem.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/raw_ostream.h" #include #include using namespace swift; using namespace constraints; namespace { struct DisjunctionInfo { /// The score of the disjunction is the highest score from its choices. /// If the score is nullopt it means that the disjunction is not optimizable. std::optional Score; /// The highest scoring choices that could be favored when disjunction /// is attempted. llvm::TinyPtrVector FavoredChoices; /// Whether the decisions were based on speculative information /// i.e. literal argument candidates or initializer type inference. 
bool IsSpeculative; DisjunctionInfo() = default; DisjunctionInfo(std::optional score, ArrayRef favoredChoices, bool speculative) : Score(score), FavoredChoices(favoredChoices), IsSpeculative(speculative) {} static DisjunctionInfo none() { return {std::nullopt, {}, false}; } }; class DisjunctionInfoBuilder { std::optional Score; SmallVector FavoredChoices; bool IsSpeculative; public: DisjunctionInfoBuilder(std::optional score) : DisjunctionInfoBuilder(score, {}) {} DisjunctionInfoBuilder(std::optional score, ArrayRef favoredChoices) : Score(score), FavoredChoices(favoredChoices.begin(), favoredChoices.end()), IsSpeculative(false) {} void setFavoredChoices(ArrayRef choices) { FavoredChoices.clear(); FavoredChoices.append(choices.begin(), choices.end()); } void addFavoredChoice(Constraint *constraint) { FavoredChoices.push_back(constraint); } void setSpeculative(bool value = true) { IsSpeculative = value; } DisjunctionInfo build() { return {Score, FavoredChoices, IsSpeculative}; } }; static DeclContext *getDisjunctionDC(Constraint *disjunction) { auto *choice = disjunction->getNestedConstraints()[0]; switch (choice->getKind()) { case ConstraintKind::BindOverload: case ConstraintKind::ValueMember: case ConstraintKind::UnresolvedValueMember: case ConstraintKind::ValueWitness: return choice->getDeclContext(); default: return nullptr; } } /// Determine whether the given disjunction appears in a context /// transformed by a result builder. static bool isInResultBuilderContext(ConstraintSystem &cs, Constraint *disjunction) { auto *DC = getDisjunctionDC(disjunction); if (!DC) return false; do { auto fnContext = AnyFunctionRef::fromDeclContext(DC); if (!fnContext) return false; if (cs.getAppliedResultBuilderTransform(*fnContext)) return true; } while ((DC = DC->getParent())); return false; } /// If the given operator disjunction appears in some position // inside of a not yet resolved call i.e. `a.b(1 + c(4) - 1)` // both `+` and `-` are "in" argument context of `b`. 
static bool isOperatorPassedToUnresolvedCall(ConstraintSystem &cs, Constraint *disjunction) { ASSERT(isOperatorDisjunction(disjunction)); auto *curr = castToExpr(disjunction->getLocator()->getAnchor()); while (auto *parent = cs.getParentExpr(curr)) { SWIFT_DEFER { curr = parent; }; switch (parent->getKind()) { case ExprKind::OptionalEvaluation: case ExprKind::Paren: case ExprKind::Binary: case ExprKind::PrefixUnary: case ExprKind::PostfixUnary: continue; // a.b(<> ? <> : <<...>>) case ExprKind::Ternary: { auto *T = cast(parent); // If the operator is located in the condition it's // not tied to the context. if (T->getCondExpr() == curr) return false; // But the branches are connected to the context. continue; } // Handles `a(<>), `a[<>]`, // `.a(<>)` etc. case ExprKind::Call: { auto *call = cast(parent); // Type(...) if (isa(call->getFn())) { auto *ctorLoc = cs.getConstraintLocator( call, {LocatorPathElt::ApplyFunction(), LocatorPathElt::ConstructorMember()}); return !cs.findSelectedOverloadFor(ctorLoc); } // Ignore injected result builder methods like `buildExpression` // and `buildBlock`. if (auto *UDE = dyn_cast(call->getFn())) { if (isResultBuilderMethodReference(cs.getASTContext(), UDE)) return false; } return !cs.findSelectedOverloadFor(call->getFn()); } default: return false; } } return false; } // TODO: both `isIntegerType` and `isFloatType` should be available on Type // as `isStdlib{Integer, Float}Type`. 
static bool isIntegerType(Type type) { return type->isInt() || type->isInt8() || type->isInt16() || type->isInt32() || type->isInt64() || type->isUInt() || type->isUInt8() || type->isUInt16() || type->isUInt32() || type->isUInt64(); } static bool isFloatType(Type type) { return type->isFloat() || type->isDouble() || type->isFloat80(); } static bool isUnboundArrayType(Type type) { if (auto *UGT = type->getAs()) return UGT->getDecl() == type->getASTContext().getArrayDecl(); return false; } static bool isUnboundDictionaryType(Type type) { if (auto *UGT = type->getAs()) return UGT->getDecl() == type->getASTContext().getDictionaryDecl(); return false; } static bool isSupportedOperator(Constraint *disjunction) { if (!isOperatorDisjunction(disjunction)) return false; auto choices = disjunction->getNestedConstraints(); auto *decl = getOverloadChoiceDecl(choices.front()); auto name = decl->getBaseIdentifier(); if (name.isArithmeticOperator() || name.isStandardComparisonOperator() || name.isBitwiseOperator() || name.isNilCoalescingOperator()) { return true; } // Operators like &<<, &>>, &+, .== etc. if (llvm::any_of(choices, [](Constraint *choice) { return isSIMDOperator(getOverloadChoiceDecl(choice)); })) { return true; } return false; } static bool isSupportedSpecialConstructor(ConstructorDecl *ctor) { if (auto *selfDecl = ctor->getImplicitSelfDecl()) { auto selfTy = selfDecl->getInterfaceType(); /// Support `Int*`, `Float*` and `Double` initializers since their generic /// overloads are not too complicated. 
return selfTy && (isIntegerType(selfTy) || isFloatType(selfTy)); } return false; } static bool isStandardComparisonOperator(Constraint *disjunction) { auto *choice = disjunction->getNestedConstraints()[0]; if (auto *decl = getOverloadChoiceDecl(choice)) return decl->isOperator() && decl->getBaseIdentifier().isStandardComparisonOperator(); return false; } static bool isStandardInfixLogicalOperator(Constraint *disjunction) { auto *choice = disjunction->getNestedConstraints()[0]; if (auto *decl = getOverloadChoiceDecl(choice)) return decl->isOperator() && decl->getBaseIdentifier().isStandardInfixLogicalOperator(); return false; } static bool isOperatorNamed(Constraint *disjunction, StringRef name) { auto *choice = disjunction->getNestedConstraints()[0]; if (auto *decl = getOverloadChoiceDecl(choice)) return decl->isOperator() && decl->getBaseIdentifier().is(name); return false; } static bool isArithmeticOperator(ValueDecl *decl) { return decl->isOperator() && decl->getBaseIdentifier().isArithmeticOperator(); } /// Generic choices are supported only if they are not complex enough /// that would they'd require solving to figure out whether they are a /// potential match or not. static bool isSupportedGenericOverloadChoice(ValueDecl *decl, GenericFunctionType *choiceType) { // Same type requirements cannot be handled because each // candidate-parameter pair is (currently) considered in isolation. if (llvm::any_of(choiceType->getRequirements(), [](const Requirement &req) { switch (req.getKind()) { case RequirementKind::SameType: case RequirementKind::SameShape: return true; case RequirementKind::Conformance: case RequirementKind::Superclass: case RequirementKind::Layout: return false; } })) return false; // If there are no same-type requirements, allow signatures // that use only concrete types or generic parameters directly // in their parameter positions i.e. `(T, Int)`. 
auto *paramList = decl->getParameterList(); if (!paramList) return false; return llvm::all_of(paramList->getArray(), [](const ParamDecl *P) { auto paramType = P->getInterfaceType(); return paramType->is() || !paramType->hasTypeParameter(); }); } static bool isSupportedDisjunction(Constraint *disjunction) { auto choices = disjunction->getNestedConstraints(); if (isOperatorDisjunction(disjunction)) return isSupportedOperator(disjunction); if (auto *ctor = dyn_cast_or_null( getOverloadChoiceDecl(choices.front()))) { if (isSupportedSpecialConstructor(ctor)) return true; } // Non-operator disjunctions are supported only if they don't // have any complex generic choices. return llvm::all_of(choices, [&](Constraint *choice) { if (choice->isDisabled()) return true; if (choice->getKind() != ConstraintKind::BindOverload) return false; if (auto *decl = getOverloadChoiceDecl(choice)) { // Cannot optimize declarations that return IUO because // they form a disjunction over a result type once attempted. if (decl->isImplicitlyUnwrappedOptional()) return false; auto choiceType = decl->getInterfaceType()->getAs(); if (!choiceType || choiceType->hasError()) return false; // Non-generic choices are always supported. if (choiceType->is()) return true; if (auto *genericFn = choiceType->getAs()) return isSupportedGenericOverloadChoice(decl, genericFn); return false; } return false; }); } /// Determine whether the given overload choice constitutes a /// valid choice that would be attempted during normal solving /// without any score increases. 
///
/// \param cs The constraint system used for feature and availability checks.
/// \param constraint A nested constraint of a disjunction.
/// \param locator The locator passed to the availability check.
///
/// \returns The declaration referenced by the choice, or nullptr when the
/// choice is disabled, is not a `BindOverload` constraint, doesn't reference
/// a declaration, is not imported into the use context (checked only when
/// `MemberImportVisibility` is enabled), or is unavailable.
static ValueDecl *isViableOverloadChoice(ConstraintSystem &cs,
                                         Constraint *constraint,
                                         ConstraintLocator *locator) {
  if (constraint->isDisabled())
    return nullptr;

  if (constraint->getKind() != ConstraintKind::BindOverload)
    return nullptr;

  auto choice = constraint->getOverloadChoice();
  auto *decl = choice.getDeclOrNull();
  if (!decl)
    return nullptr;

  // Ignore declarations that come from implicitly imported modules
  // when `MemberImportVisibility` feature is enabled otherwise
  // we might end up favoring an overload that would be diagnosed
  // as unavailable later.
  if (cs.getASTContext().LangOpts.hasFeature(Feature::MemberImportVisibility)) {
    if (auto *useDC = constraint->getDeclContext()) {
      if (!useDC->isDeclImported(decl))
        return nullptr;
    }
  }

  // If disjunction choice is unavailable we cannot
  // do anything with it.
  if (cs.isDeclUnavailable(decl, locator))
    return nullptr;

  return decl;
}

/// Given the type variable that represents a result type of a
/// function call, check whether that call is to an initializer
/// and based on that deduce possible type for the result.
///
/// @return A type and a flag that indicates whether there
/// are any viable failable overloads and empty pair if the
/// type variable isn't a result of an initializer call.
///
/// Only `Type(...)` and `<base>.init(...)` callees are recognized, and the
/// constructor reference must be one of the given \p disjunctions.
///
/// NOTE(review): the template arguments of the return type were lost from
/// the text this function was recovered from; `<Type, 1, bool>` matches how
/// the result is built here and consumed by the caller - confirm.
static llvm::PointerIntPair<Type, 1, bool>
inferTypeFromInitializerResultType(ConstraintSystem &cs,
                                   TypeVariableType *typeVar,
                                   ArrayRef<Constraint *> disjunctions) {
  assert(typeVar->getImpl().isFunctionResult());

  auto *resultLoc = typeVar->getImpl().getLocator();
  auto *call = getAsExpr<CallExpr>(resultLoc->getAnchor());
  if (!call)
    return {};

  auto *fn = call->getFn()->getSemanticsProvidingExpr();

  Type instanceTy;
  ConstraintLocator *ctorLocator = nullptr;
  if (auto *typeExpr = getAsExpr<TypeExpr>(fn)) {
    // `Type(...)`
    instanceTy = cs.getType(typeExpr)->getMetatypeInstanceType();
    ctorLocator =
        cs.getConstraintLocator(call, {LocatorPathElt::ApplyFunction(),
                                       LocatorPathElt::ConstructorMember()});
  } else if (auto *UDE = getAsExpr<UnresolvedDotExpr>(fn)) {
    // `<base>.init(...)`
    if (!UDE->getName().getBaseName().isConstructor())
      return {};
    instanceTy = cs.getType(UDE->getBase())->getMetatypeInstanceType();
    ctorLocator = cs.getConstraintLocator(UDE, LocatorPathElt::Member());
  }

  if (!instanceTy || !ctorLocator)
    return {};

  // Find the disjunction that represents the constructor reference.
  auto initRef =
      llvm::find_if(disjunctions, [&ctorLocator](Constraint *disjunction) {
        return disjunction->getLocator() == ctorLocator;
      });

  if (initRef == disjunctions.end())
    return {};

  // Count viable constructor choices and how many of them are failable.
  unsigned numFailable = 0;
  unsigned total = 0;
  for (auto *choice : (*initRef)->getNestedConstraints()) {
    auto *decl = isViableOverloadChoice(cs, choice, ctorLocator);
    if (!decl || !isa<ConstructorDecl>(decl))
      continue;

    auto *ctor = cast<ConstructorDecl>(decl);
    if (ctor->isFailable())
      ++numFailable;

    ++total;
  }

  if (numFailable > 0) {
    // If all of the active choices are failable, produce an optional
    // type only.
    if (numFailable == total)
      return {instanceTy->wrapInOptionalType(), /*hasFailable=*/false};

    // Otherwise there are two options.
    return {instanceTy, /*hasFailable*/ true};
  }

  return {instanceTy, /*hasFailable=*/false};
}

/// If the given expression represents a chain of operators that have
/// only declaration/member references and/or literals as arguments,
/// attempt to deduce a potential type of the chain.
For example if /// chain has only integral literals it's going to be `Int`, if there /// are some floating-point literals mixed in - it's going to be `Double`. static Type inferTypeOfArithmeticOperatorChain(ConstraintSystem &cs, ASTNode node) { class OperatorChainAnalyzer : public ASTWalker { ASTContext &C; DeclContext *DC; ConstraintSystem &CS; llvm::SmallPtrSet, 2> candidates; bool unsupported = false; PreWalkResult walkToExprPre(Expr *expr) override { if (isa(expr)) return Action::Continue(expr); if (isa(expr) || isa(expr)) return Action::Continue(expr); if (isa(expr)) return Action::Continue(expr); // This inference works only with arithmetic operators // because we know the structure of their overloads. if (auto *ODRE = dyn_cast(expr)) { if (auto *choice = ODRE->getDecls().front()) { if (choice->getBaseIdentifier().isArithmeticOperator()) return Action::Continue(expr); } } if (auto *LE = dyn_cast(expr)) { if (auto *P = TypeChecker::getLiteralProtocol(C, LE)) { if (auto defaultTy = TypeChecker::getDefaultType(P, DC)) { addCandidateType(defaultTy, /*literal=*/true); // String interpolation expressions have `TapExpr` // as their children, no reason to walk them. return Action::SkipChildren(expr); } } } if (auto *UDE = dyn_cast(expr)) { auto memberTy = CS.getType(UDE); if (!memberTy->hasTypeVariable()) { addCandidateType(memberTy, /*literal=*/false); return Action::SkipChildren(expr); } } if (auto *DRE = dyn_cast(expr)) { auto declTy = CS.getType(DRE); if (!declTy->hasTypeVariable()) { addCandidateType(declTy, /*literal=*/false); return Action::SkipChildren(expr); } } unsupported = true; return Action::Stop(); } void addCandidateType(Type type, bool literal) { if (literal) { if (type->isInt()) { // Floating-point types always subsume Int in operator chains. 
if (llvm::any_of(candidates, [](const auto &candidate) { auto ty = candidate.getPointer(); return isFloatType(ty) || ty->isCGFloat(); })) return; } else if (isFloatType(type) || type->isCGFloat()) { // A single use of a floating-point literal flips the // type of the entire chain to it. (void)candidates.erase({C.getIntType(), /*literal=*/true}); } } candidates.insert({type, literal}); } public: OperatorChainAnalyzer(ConstraintSystem &CS) : C(CS.getASTContext()), DC(CS.DC), CS(CS) {} Type chainType() const { if (unsupported) return Type(); return candidates.size() != 1 ? Type() : (*candidates.begin()).getPointer(); } }; OperatorChainAnalyzer analyzer(cs); node.walk(analyzer); return analyzer.chainType(); } NullablePtr getApplicableFnConstraint(ConstraintGraph &CG, Constraint *disjunction) { auto *boundVar = disjunction->getNestedConstraints()[0] ->getFirstType() ->getAs(); if (!boundVar) return nullptr; auto constraints = CG.gatherNearbyConstraints(boundVar, [](Constraint *constraint) { return constraint->getKind() == ConstraintKind::ApplicableFunction; }); if (constraints.size() != 1) return nullptr; auto *applicableFn = constraints.front(); // Unapplied disjunction could appear as a argument to applicable function, // we are not interested in that. return applicableFn->getSecondType()->isEqual(boundVar) ? 
applicableFn : nullptr; } void forEachDisjunctionChoice( ConstraintSystem &cs, Constraint *disjunction, llvm::function_ref callback) { for (auto constraint : disjunction->getNestedConstraints()) { auto *decl = isViableOverloadChoice(cs, constraint, disjunction->getLocator()); if (!decl) continue; Type overloadType = cs.getEffectiveOverloadType( disjunction->getLocator(), constraint->getOverloadChoice(), /*allowMembers=*/true, constraint->getDeclContext()); if (!overloadType || !overloadType->is()) continue; callback(constraint, decl, overloadType->castTo()); } } static OverloadedDeclRefExpr *isOverloadedDeclRef(Constraint *disjunction) { assert(disjunction->getKind() == ConstraintKind::Disjunction); auto *locator = disjunction->getLocator(); if (locator->getPath().empty()) return getAsExpr(locator->getAnchor()); return nullptr; } static unsigned numOverloadChoicesMatchingOnArity(OverloadedDeclRefExpr *ODRE, ArgumentList *arguments) { return llvm::count_if(ODRE->getDecls(), [&arguments](auto *choice) { if (auto *paramList = choice->getParameterList()) return arguments->size() == paramList->size(); return false; }); } static bool isVariadicGenericOverload(ValueDecl *choice) { auto genericContext = choice->getAsGenericContext(); if (!genericContext) return false; auto *GPL = genericContext->getGenericParams(); if (!GPL) return false; return llvm::any_of(GPL->getParams(), [&](const GenericTypeParamDecl *GP) { return GP->isParameterPack(); }); } /// This maintains an "old hack" behavior where overloads of some /// `OverloadedDeclRef` calls were favored purely based on number of /// argument and (non-defaulted) parameters matching. 
static void findFavoredChoicesBasedOnArity( ConstraintSystem &cs, Constraint *disjunction, ArgumentList *argumentList, llvm::function_ref favoredChoice) { auto *ODRE = isOverloadedDeclRef(disjunction); if (!ODRE) return; if (numOverloadChoicesMatchingOnArity(ODRE, argumentList) > 1) return; bool hasVariadicGenerics = false; SmallVector favored; forEachDisjunctionChoice( cs, disjunction, [&](Constraint *choice, ValueDecl *decl, FunctionType *overloadType) { if (decl->getAttrs().hasAttribute()) return; if (isVariadicGenericOverload(decl)) { hasVariadicGenerics = true; return; } if (overloadType->getNumParams() == argumentList->size() || llvm::count_if(*decl->getParameterList(), [](auto *param) { return !param->isDefaultArgument(); }) == argumentList->size()) favored.push_back(choice); }); if (hasVariadicGenerics) return; for (auto *choice : favored) favoredChoice(choice); } /// Preserves old behavior where, for unary calls, the solver would not previously /// consider choices that didn't match on the number of parameters (regardless of /// defaults and variadics) and only exact matches were favored. static std::optional preserveFavoringOfUnlabeledUnaryArgument( ConstraintSystem &cs, Constraint *disjunction, ArgumentList *argumentList) { if (!argumentList->isUnlabeledUnary()) return std::nullopt; if (!isExpr( cs.getParentExpr(argumentList->getUnlabeledUnaryExpr()))) return std::nullopt; // The hack rolled back favoring choices if one of the overloads was a // protocol requirement or variadic generic. // // Note that it doesn't matter whether such overload choices are viable // or not, their presence disabled this "optimization". 
if (llvm::any_of(disjunction->getNestedConstraints(), [](Constraint *choice) { auto *decl = getOverloadChoiceDecl(choice); if (!decl) return false; return isa(decl->getDeclContext()) || (!decl->getAttrs().hasAttribute() && isVariadicGenericOverload(decl)); })) return std::nullopt; auto ODRE = isOverloadedDeclRef(disjunction); bool preserveFavoringOfUnlabeledUnaryArgument = !ODRE || numOverloadChoicesMatchingOnArity(ODRE, argumentList) < 2; if (!preserveFavoringOfUnlabeledUnaryArgument) return std::nullopt; auto *argument = argumentList->getUnlabeledUnaryExpr()->getSemanticsProvidingExpr(); // The hack operated on "favored" types and only declaration references, // applications, and (dynamic) subscripts had them if they managed to // get an overload choice selected during constraint generation. // It's sometimes possible to infer a type of a literal and an operator // chain, so it should be allowed as well. if (!(isExpr(argument) || isExpr(argument) || isExpr(argument) || isExpr(argument) || isExpr(argument) || isExpr(argument))) return DisjunctionInfo::none(); auto argumentType = cs.getType(argument)->getRValueType(); // For chains like `1 + 2 * 3` it's easy to deduce the type because // we know what literal types are preferred. if (isa(argument)) { auto chainTy = inferTypeOfArithmeticOperatorChain(cs, argument); if (!chainTy) return DisjunctionInfo::none(); argumentType = chainTy; } // Use default type of a literal (when available) to make a guess. // This is what old hack used to do as well. 
if (auto *LE = dyn_cast(argument)) { auto *P = TypeChecker::getLiteralProtocol(cs.getASTContext(), LE); if (!P) return DisjunctionInfo::none(); auto defaultTy = TypeChecker::getDefaultType(P, cs.DC); if (!defaultTy) return DisjunctionInfo::none(); argumentType = defaultTy; } ASSERT(argumentType); if (argumentType->hasTypeVariable() || argumentType->hasDependentMember()) return DisjunctionInfo::none(); SmallVector favoredChoices; forEachDisjunctionChoice( cs, disjunction, [&argumentType, &favoredChoices, &argument]( Constraint *choice, ValueDecl *decl, FunctionType *overloadType) { if (decl->getAttrs().hasAttribute()) return; if (overloadType->getNumParams() != 1) return; auto ¶m = overloadType->getParams()[0]; // Literals are speculative, let's not attempt to apply them too // eagerly. if (!param.getParameterFlags().isNone() && (isa(argument) || isa(argument))) return; if (argumentType->isEqual(param.getPlainType())) favoredChoices.push_back(choice); }); return DisjunctionInfoBuilder(/*score=*/favoredChoices.empty() ? 0 : 1, favoredChoices) .build(); } } // end anonymous namespace /// Given a set of disjunctions, attempt to determine /// favored choices in the current context. static void determineBestChoicesInContext( ConstraintSystem &cs, SmallVectorImpl &disjunctions, llvm::DenseMap &result) { double bestOverallScore = 0.0; auto recordResult = [&bestOverallScore, &result](Constraint *disjunction, DisjunctionInfo &&info) { bestOverallScore = std::max(bestOverallScore, info.Score.value_or(0)); result.try_emplace(disjunction, info); }; for (auto *disjunction : disjunctions) { // If this is a compiler synthesized disjunction, mark it as supported // and record all of the previously favored choices. Such disjunctions // include - explicit coercions, IUO references,injected implicit // initializers for CGFloat<->Double conversions and restrictions with // multiple choices. 
if (disjunction->countFavoredNestedConstraints() > 0) { DisjunctionInfoBuilder info(/*score=*/2.0); for (auto *choice : disjunction->getNestedConstraints()) { if (choice->isFavored()) info.addFavoredChoice(choice); } recordResult(disjunction, info.build()); continue; } auto applicableFn = getApplicableFnConstraint(cs.getConstraintGraph(), disjunction); if (applicableFn.isNull()) { auto *locator = disjunction->getLocator(); if (auto expr = getAsExpr(locator->getAnchor())) { auto *parentExpr = cs.getParentExpr(expr); // Look through optional evaluation, so // we can cover expressions like `a?.b + 2`. if (isExpr(parentExpr)) parentExpr = cs.getParentExpr(parentExpr); if (parentExpr) { // If this is a chained member reference or a direct operator // argument it could be prioritized since it helps to establish // context for other calls i.e. `(a.)b + 2` if `a` and/or `b` // are disjunctions they should be preferred over `+`. switch (parentExpr->getKind()) { case ExprKind::Binary: case ExprKind::PrefixUnary: case ExprKind::PostfixUnary: case ExprKind::UnresolvedDot: { llvm::SmallVector favoredChoices; // Favor choices that don't require application. llvm::copy_if( disjunction->getNestedConstraints(), std::back_inserter(favoredChoices), [](Constraint *choice) { auto *decl = getOverloadChoiceDecl(choice); return decl && !decl->getInterfaceType()->is(); }); recordResult( disjunction, DisjunctionInfoBuilder(/*score=*/1.0, favoredChoices).build()); continue; } default: break; } } } continue; } auto argFuncType = applicableFn.get()->getFirstType()->getAs(); auto argumentList = cs.getArgumentList(applicableFn.get()->getLocator()); if (!argumentList) return; for (const auto &argument : *argumentList) { if (auto *expr = argument.getExpr()) { // Directly `<#...#>` or has one inside. 
if (isa(expr) || cs.containsIDEInspectionTarget(expr)) return; } } // This maintains an "old hack" behavior where overloads // of `OverloadedDeclRef` calls were favored purely // based on arity of arguments and parameters matching. { llvm::TinyPtrVector favoredChoices; findFavoredChoicesBasedOnArity(cs, disjunction, argumentList, [&favoredChoices](Constraint *choice) { favoredChoices.push_back(choice); }); if (!favoredChoices.empty()) { recordResult( disjunction, DisjunctionInfoBuilder(/*score=*/0.01, favoredChoices).build()); continue; } } // Preserves old behavior where, for unary calls, the solver // would not consider choices that didn't match on the number // of parameters (regardless of defaults) and only exact // matches were favored. if (auto info = preserveFavoringOfUnlabeledUnaryArgument(cs, disjunction, argumentList)) { recordResult(disjunction, std::move(info.value())); continue; } if (!isSupportedDisjunction(disjunction)) continue; SmallVector argsWithLabels; { argsWithLabels.append(argFuncType->getParams().begin(), argFuncType->getParams().end()); FunctionType::relabelParams(argsWithLabels, argumentList); } struct ArgumentCandidate { Type type; // The candidate type is derived from a literal expression. bool fromLiteral : 1; // The candidate type is derived from a call to an // initializer i.e. `Double(...)`. bool fromInitializerCall : 1; ArgumentCandidate(Type type, bool fromLiteral = false, bool fromInitializerCall = false) : type(type), fromLiteral(fromLiteral), fromInitializerCall(fromInitializerCall) {} }; // Determine whether there are any non-speculative choices // in the given set of candidates. Speculative choices are // literals or types inferred from initializer calls. auto anyNonSpeculativeCandidates = [&](ArrayRef candidates) { // If there is only one (non-CGFloat) candidate inferred from // an initializer call we don't consider this a speculation. 
// // CGFloat inference is always speculative because of the // implicit conversion between Double and CGFloat. if (llvm::count_if(candidates, [&](const auto &candidate) { return candidate.fromInitializerCall && !candidate.type->isCGFloat(); }) == 1) return true; // If there are no non-literal and non-initializer-inferred types // in the list, consider this is a speculation. return llvm::any_of(candidates, [&](const auto &candidate) { return !candidate.fromLiteral && !candidate.fromInitializerCall; }); }; auto anyNonSpeculativeResultTypes = [](ArrayRef results) { return llvm::any_of(results, [](Type resultTy) { // Double and CGFloat are considered speculative because // there exists an implicit conversion between them and // preference is based on score impact in the overall solution. return !(resultTy->isDouble() || resultTy->isCGFloat()); }); }; SmallVector, 2> argumentCandidates; argumentCandidates.resize(argFuncType->getNumParams()); llvm::TinyPtrVector resultTypes; bool hasArgumentCandidates = false; bool isOperator = isOperatorDisjunction(disjunction); for (unsigned i = 0, n = argFuncType->getNumParams(); i != n; ++i) { const auto ¶m = argFuncType->getParams()[i]; auto argType = cs.simplifyType(param.getPlainType()); SmallVector optionals; // i.e. `??` operator could produce an optional type // so `test(<> ?? 0) could result in an optional // argument that wraps a type variable. It should be possible // to infer bindings from underlying type variable and restore // optionality. if (argType->hasTypeVariable()) { if (auto *typeVar = argType->lookThroughAllOptionalTypes(optionals) ->getAs()) argType = typeVar; } SmallVector types; if (auto *typeVar = argType->getAs()) { auto bindingSet = cs.getBindingsFor(typeVar); // We need to have a notion of "complete" binding set before // we can allow inference from generic parameters and ternary, // otherwise we'd make a favoring decision that might not be // correct i.e. `v ?? (<> ? nil : o)` where `o` is `Int`. 
// `getBindingsFor` doesn't currently infer transitive bindings // which means that for a ternary we'd only have a single // binding - `Int` which could lead to favoring overload of // `??` and has non-optional parameter on the right-hand side. if (typeVar->getImpl().getGenericParameter() || typeVar->getImpl().isTernary()) continue; auto restoreOptionality = [](Type type, unsigned numOptionals) { for (unsigned i = 0; i != numOptionals; ++i) type = type->wrapInOptionalType(); return type; }; for (const auto &binding : bindingSet.Bindings) { auto type = restoreOptionality(binding.BindingType, optionals.size()); types.push_back({type}); } for (const auto &literal : bindingSet.Literals) { if (literal.second.hasDefaultType()) { // Add primary default type auto type = restoreOptionality(literal.second.getDefaultType(), optionals.size()); types.push_back({type, /*fromLiteral=*/true}); } else if (literal.first == cs.getASTContext().getProtocol( KnownProtocolKind::ExpressibleByNilLiteral) && literal.second.IsDirectRequirement) { // `==` and `!=` operators have special overloads that accept `nil` // as `_OptionalNilComparisonType` which is preferred over a // generic form `(T?, T?)`. if (isOperatorNamed(disjunction, "==") || isOperatorNamed(disjunction, "!=")) { auto nilComparisonTy = cs.getASTContext().get_OptionalNilComparisonTypeType(); types.push_back({nilComparisonTy, /*fromLiteral=*/true}); } } } // Help situations like `1 + {Double, CGFloat}(...)` by inferring // a type for the second operand of `+` based on a type being // constructed. 
if (typeVar->getImpl().isFunctionResult()) { auto *resultLoc = typeVar->getImpl().getLocator(); if (auto type = inferTypeOfArithmeticOperatorChain( cs, resultLoc->getAnchor())) { types.push_back({type, /*fromLiteral=*/true}); } auto binding = inferTypeFromInitializerResultType(cs, typeVar, disjunctions); if (auto instanceTy = binding.getPointer()) { types.push_back({instanceTy, /*fromLiteral=*/false, /*fromInitializerCall=*/true}); if (binding.getInt()) types.push_back({instanceTy->wrapInOptionalType(), /*fromLiteral=*/false, /*fromInitializerCall=*/true}); } } } else { types.push_back({argType, /*fromLiteral=*/false}); } argumentCandidates[i].append(types); hasArgumentCandidates |= !types.empty(); } auto resultType = cs.simplifyType(argFuncType->getResult()); if (auto *typeVar = resultType->getAs()) { auto bindingSet = cs.getBindingsFor(typeVar); for (const auto &binding : bindingSet.Bindings) { resultTypes.push_back(binding.BindingType); } // Infer bindings for each side of a ternary condition. bindingSet.forEachAdjacentVariable( [&cs, &resultTypes](TypeVariableType *adjacentVar) { auto *adjacentLoc = adjacentVar->getImpl().getLocator(); // This is one of the sides of a ternary operator. if (adjacentLoc->directlyAt()) { auto adjacentBindings = cs.getBindingsFor(adjacentVar); for (const auto &binding : adjacentBindings.Bindings) resultTypes.push_back(binding.BindingType); } }); } else { resultTypes.push_back(resultType); } // Determine whether all of the argument candidates are speculative (i.e. // literals). This information is going to be used later on when we need to // decide how to score a matching choice. bool onlySpeculativeArgumentCandidates = hasArgumentCandidates && llvm::none_of( indices(argFuncType->getParams()), [&](const unsigned argIdx) { return anyNonSpeculativeCandidates(argumentCandidates[argIdx]); }); bool canUseContextualResultTypes = isOperator && !isStandardComparisonOperator(disjunction); // Match arguments to the given overload choice. 
auto matchArguments = [&](OverloadChoice choice, FunctionType *overloadType) -> std::optional { auto *decl = choice.getDeclOrNull(); assert(decl); auto hasAppliedSelf = decl->hasCurriedSelf() && doesMemberRefApplyCurriedSelf(choice.getBaseType(), decl); ParameterListInfo paramListInfo(overloadType->getParams(), decl, hasAppliedSelf); MatchCallArgumentListener listener; return matchCallArguments(argsWithLabels, overloadType->getParams(), paramListInfo, argumentList->getFirstTrailingClosureIndex(), /*allow fixes*/ false, listener, std::nullopt); }; // Determine whether the candidate type is a subclass of the superclass // type. std::function isSubclassOf = [&](Type candidateType, Type superclassType) { // Conversion from a concrete type to its existential value. if (superclassType->isExistentialType() && !superclassType->isAny()) { auto layout = superclassType->getExistentialLayout(); if (auto layoutConstraint = layout.getLayoutConstraint()) { if (layoutConstraint->isClass() && !(candidateType->isClassExistentialType() || candidateType->mayHaveSuperclass())) return false; } if (layout.explicitSuperclass && !isSubclassOf(candidateType, layout.explicitSuperclass)) return false; return llvm::all_of(layout.getProtocols(), [&](ProtocolDecl *P) { if (auto superclass = P->getSuperclassDecl()) { if (!isSubclassOf(candidateType, superclass->getDeclaredInterfaceType())) return false; } auto result = TypeChecker::containsProtocol(candidateType, P, /*allowMissing=*/false); return result.first || result.second; }); } if (auto *selfType = candidateType->getAs()) { candidateType = selfType->getSelfType(); } if (auto *archetypeType = candidateType->getAs()) { candidateType = archetypeType->getSuperclass(); if (!candidateType) return false; } auto *subclassDecl = candidateType->getClassOrBoundGenericClass(); auto *superclassDecl = superclassType->getClassOrBoundGenericClass(); if (!(subclassDecl && superclassDecl)) return false; return superclassDecl->isSuperclassOf(subclassDecl); }; 
enum class MatchFlag { OnParam = 0x01, Literal = 0x02, ExactOnly = 0x04, DisableCGFloatDoubleConversion = 0x08, StringInterpolation = 0x10, }; using MatchOptions = OptionSet; // Perform a limited set of checks to determine whether the candidate // could possibly match the parameter type: // // - Equality // - Protocol conformance(s) // - Optional injection // - Superclass conversion // - Array-to-pointer conversion // - Value to existential conversion // - Existential opening // - Exact match on top-level types // // In situations when it's not possible to determine whether a candidate // type matches a parameter type (i.e. when partially resolved generic // types are matched) this function is going to produce \c std::nullopt // instead of `0` that indicates "not a match". std::function(GenericSignature, ValueDecl *, std::optional, Type, Type, MatchOptions)> scoreCandidateMatch = [&](GenericSignature genericSig, ValueDecl *choice, std::optional paramIdx, Type candidateType, Type paramType, MatchOptions options) -> std::optional { auto areEqual = [&](Type a, Type b) { return a->getDesugaredType()->isEqual(b->getDesugaredType()); }; auto isCGFloatDoubleConversionSupported = [&options]() { // CGFloat <-> Double conversion is supposed only while // match argument candidates to parameters. return options.contains(MatchFlag::OnParam) && !options.contains(MatchFlag::DisableCGFloatDoubleConversion); }; // Allow CGFloat -> Double widening conversions between // candidate argument types and parameter types. This would // make sure that Double is always preferred over CGFloat // when using literals and ranking supported disjunction // choices. Narrowing conversion (Double -> CGFloat) should // be delayed as much as possible. if (isCGFloatDoubleConversionSupported()) { if (candidateType->isCGFloat() && paramType->isDouble()) { return options.contains(MatchFlag::Literal) ? 
0.2 : 0.9; } } if (options.contains(MatchFlag::ExactOnly)) { // If an exact match is requested favor only `[...]` to `Array<...>` // since everything else is going to increase to score. if (options.contains(MatchFlag::Literal)) { if (isUnboundArrayType(candidateType)) return paramType->isArray() ? 0.3 : 0; if (isUnboundDictionaryType(candidateType)) return cs.isDictionaryType(paramType) ? 0.3 : 0; } if (!areEqual(candidateType, paramType)) return 0; return options.contains(MatchFlag::Literal) ? 0.3 : 1; } // Exact match between candidate and parameter types. if (areEqual(candidateType, paramType)) { return options.contains(MatchFlag::Literal) ? 0.3 : 1; } if (options.contains(MatchFlag::Literal)) { if (paramType->hasTypeParameter() || paramType->isAnyExistentialType()) { // Attempt to match literal default to generic parameter. // This helps to determine whether there are any generic // overloads that are a possible match. auto score = scoreCandidateMatch(genericSig, choice, paramIdx, candidateType, paramType, options - MatchFlag::Literal); if (score == 0) return 0; // Optional injection lowers the score for operators to match // pre-optimizer behavior. return choice->isOperator() && paramType->getOptionalObjectType() ? 0.2 : 0.3; } else { // Integer and floating-point literals can match any parameter // type that conforms to `ExpressibleBy{Integer, Float}Literal` // protocol. Since this assessment is done in isolation we don't // lower the score even though this would be a non-default binding // for a literal. 
if (candidateType->isInt() && TypeChecker::conformsToKnownProtocol( paramType, KnownProtocolKind::ExpressibleByIntegerLiteral)) return 0.3; if (candidateType->isDouble() && TypeChecker::conformsToKnownProtocol( paramType, KnownProtocolKind::ExpressibleByFloatLiteral)) return 0.3; if (candidateType->isBool() && TypeChecker::conformsToKnownProtocol( paramType, KnownProtocolKind::ExpressibleByBooleanLiteral)) return 0.3; if (candidateType->isString()) { auto literalProtocol = options.contains(MatchFlag::StringInterpolation) ? KnownProtocolKind::ExpressibleByStringInterpolation : KnownProtocolKind::ExpressibleByStringLiteral; if (TypeChecker::conformsToKnownProtocol(paramType, literalProtocol)) return 0.3; } auto &ctx = cs.getASTContext(); // Check if the other side conforms to `ExpressibleByArrayLiteral` // protocol (in some way). We want an overly optimistic result // here to avoid under-favoring. if (candidateType->isArray() && checkConformanceWithoutContext( paramType, ctx.getProtocol(KnownProtocolKind::ExpressibleByArrayLiteral), /*allowMissing=*/true)) return 0.3; // Check if the other side conforms to // `ExpressibleByDictionaryLiteral` protocol (in some way). // We want an overly optimistic result here to avoid under-favoring. if (candidateType->isDictionary() && checkConformanceWithoutContext( paramType, ctx.getProtocol( KnownProtocolKind::ExpressibleByDictionaryLiteral), /*allowMissing=*/true)) return 0.3; } return 0; } // Check whether match would require optional injection. { SmallVector candidateOptionals; SmallVector paramOptionals; candidateType = candidateType->lookThroughAllOptionalTypes(candidateOptionals); paramType = paramType->lookThroughAllOptionalTypes(paramOptionals); if (!candidateOptionals.empty() || !paramOptionals.empty()) { auto requiresOptionalInjection = [&]() { return paramOptionals.size() > candidateOptionals.size(); }; // Can match i.e. Int? to Int or T to Int? 
if ((paramOptionals.empty() && paramType->is()) || paramOptionals.size() >= candidateOptionals.size()) { auto score = scoreCandidateMatch(genericSig, choice, paramIdx, candidateType, paramType, options); if (score > 0) { // Injection lowers the score slightly to comply with // old behavior where exact matches on operator parameter // types were always preferred. if (choice->isOperator() && requiresOptionalInjection()) return score.value() - 0.1; } return score; } // Optionality mismatch. return 0; } } // Candidate could be converted to a superclass. if (isSubclassOf(candidateType, paramType)) return 1; // Possible Array -> Unsafe*Pointer conversion. if (options.contains(MatchFlag::OnParam)) { if (candidateType->isArray() && paramType->getAnyPointerElementType()) return 1; } // If both argument and parameter are tuples of the same arity, // it's a match. { if (auto *candidateTuple = candidateType->getAs()) { auto *paramTuple = paramType->getAs(); if (paramTuple && candidateTuple->getNumElements() == paramTuple->getNumElements()) return 1; } } // If the parameter is `Any` we assume that all candidates are // convertible to it, which makes it a perfect match. The solver // would then decide whether erasing to an existential is preferable. if (paramType->isAny()) return 1; // Check if a candidate could be matched to a parameter by // an existential opening. if (options.contains(MatchFlag::OnParam) && candidateType->getMetatypeInstanceType()->isExistentialType()) { if (auto *genericParam = paramType->getMetatypeInstanceType() ->getAs()) { if (canOpenExistentialAt(choice, *paramIdx, genericParam, candidateType->getMetatypeInstanceType())) { // Lower the score slightly for operators to make sure that // concrete overloads are always preferred over generic ones. return choice->isOperator() ? 0.9 : 1; } } } // Check protocol requirement(s) if this parameter is a // generic parameter type. 
if (genericSig && paramType->isTypeParameter()) { // Light-weight check if cases where `checkRequirements` is not // applicable. auto checkProtocolRequirementsOnly = [&]() -> double { auto protocolRequirements = genericSig->getRequiredProtocols(paramType); if (llvm::all_of(protocolRequirements, [&](ProtocolDecl *protocol) { return bool(cs.lookupConformance(candidateType, protocol)); })) { if (auto *GP = paramType->getAs()) { auto *paramDecl = GP->getDecl(); if (paramDecl && paramDecl->isOpaqueType()) return 1.0; } return 0.7; } return 0; }; // If candidate is not fully resolved or is matched against a // dependent member type (i.e. `Self.T`), let's check conformances // only and lower the score. if (candidateType->hasTypeVariable() || candidateType->hasUnboundGenericType() || paramType->is()) { return checkProtocolRequirementsOnly(); } // Cannot match anything but generic type parameters here. if (!paramType->is()) return std::nullopt; bool hasUnsatisfiableRequirements = false; SmallVector requirements; for (const auto &requirement : genericSig.getRequirements()) { if (hasUnsatisfiableRequirements) break; llvm::SmallPtrSet toExamine; auto recordReferencesGenericParams = [&toExamine](Type type) { type.visit([&toExamine](Type innerTy) { if (auto *GP = innerTy->getAs()) toExamine.insert(GP); }); }; recordReferencesGenericParams(requirement.getFirstType()); if (requirement.getKind() != RequirementKind::Layout) recordReferencesGenericParams(requirement.getSecondType()); if (llvm::any_of(toExamine, [&](GenericTypeParamType *GP) { return paramType->isEqual(GP); })) { requirements.push_back(requirement); // If requirement mentions other generic parameters // `checkRequirements` would because we don't have // candidate substitutions for anything but the current // parameter type. hasUnsatisfiableRequirements |= toExamine.size() > 1; } } // If there are no requirements associated with the generic // parameter or dependent member type it could match any type. 
if (requirements.empty()) return 0.7; // If some of the requirements cannot be satisfied, because // they reference other generic parameters, for example: // ``, let's perform a // light-weight check instead of skipping this overload choice. if (hasUnsatisfiableRequirements) return checkProtocolRequirementsOnly(); // If the candidate type is fully resolved, let's check all of // the requirements that are associated with the corresponding // parameter, if all of them are satisfied this candidate is // an exact match. auto result = checkRequirements( requirements, [¶mType, &candidateType](SubstitutableType *type) -> Type { if (type->isEqual(paramType)) return candidateType; return ErrorType::get(type); }, SubstOptions(std::nullopt)); // Concrete operator overloads are always more preferable to // generic ones if there are exact or subtype matches, for // everything else the solver should try both concrete and // generic and disambiguate during ranking. if (result == CheckRequirementsResult::Success) return choice->isOperator() ? 0.9 : 1.0; return 0; } // Parameter is generic, let's check whether top-level // types match i.e. Array as a parameter. // // This is slightly better than all of the conformances matching // because the parameter is concrete and could split the system. if (paramType->hasTypeParameter()) { auto *candidateDecl = candidateType->getAnyNominal(); auto *paramDecl = paramType->getAnyNominal(); // Conservatively we need to make sure that this is not worse // than matching against a generic parameter with or without // requirements. if (candidateDecl && paramDecl && candidateDecl == paramDecl) { // If the candidate it not yet fully resolved, let's lower the // score slightly to avoid over-favoring generic overload choices. if (candidateType->hasTypeVariable()) return 0.8; // If the candidate is fully resolved we need to treat this // as we would generic parameter otherwise there is a risk // of skipping some of the valid choices. 
return choice->isOperator() ? 0.9 : 1.0; } } return 0; }; // The choice with the best score. double bestScore = 0.0; SmallVector, 2> favoredChoices; forEachDisjunctionChoice( cs, disjunction, [&](Constraint *choice, ValueDecl *decl, FunctionType *overloadType) { GenericSignature genericSig; { if (auto *GF = dyn_cast(decl)) { genericSig = GF->getGenericSignature(); } else if (auto *SD = dyn_cast(decl)) { genericSig = SD->getGenericSignature(); } } auto matchings = matchArguments(choice->getOverloadChoice(), overloadType); if (!matchings) return; // Require exact matches only if all of the arguments // are literals and there are no usable contextual result // types that could help narrow favored choices. bool favorExactMatchesOnly = onlySpeculativeArgumentCandidates && (!canUseContextualResultTypes || resultTypes.empty()); // This is important for SIMD operators in particular because // a lot of their overloads have same-type requires to a concrete // type: `(_: SIMD*, ...) -> ...`. if (genericSig) { overloadType = overloadType->getReducedType(genericSig) ->castTo(); } double score = 0.0; unsigned numDefaulted = 0; for (unsigned paramIdx = 0, n = overloadType->getNumParams(); paramIdx != n; ++paramIdx) { const auto ¶m = overloadType->getParams()[paramIdx]; auto argIndices = matchings->parameterBindings[paramIdx]; switch (argIndices.size()) { case 0: // Current parameter is defaulted, mark and continue. ++numDefaulted; continue; case 1: // One-to-one match between argument and parameter. break; default: // Cannot deal with multiple possible matchings at the moment. return; } auto argIdx = argIndices.front(); // Looks like there is nothing know about the argument. if (argumentCandidates[argIdx].empty()) continue; const auto paramFlags = param.getParameterFlags(); // If parameter is variadic we cannot compare because we don't know // real arity. 
if (paramFlags.isVariadic()) continue; auto paramType = param.getPlainType(); if (paramFlags.isAutoClosure()) paramType = paramType->castTo()->getResult(); // FIXME: Let's skip matching function types for now // because they have special rules for e.g. Concurrency // (around @Sendable) and @convention(c). if (paramType->lookThroughAllOptionalTypes()->is()) continue; // The idea here is to match the parameter type against // all of the argument candidate types and pick the best // match (i.e. exact equality one). // // If none of the candidates match exactly and they are // all bound concrete types, we consider this is mismatch // at this parameter position and remove the overload choice // from consideration. double bestCandidateScore = 0; llvm::BitVector mismatches(argumentCandidates[argIdx].size()); for (unsigned candidateIdx : indices(argumentCandidates[argIdx])) { // If one of the candidates matched exactly there is no reason // to continue checking. if (bestCandidateScore == 1) break; auto candidate = argumentCandidates[argIdx][candidateIdx]; // `inout` parameter accepts only l-value argument. if (paramFlags.isInOut() && !candidate.type->is()) { mismatches.set(candidateIdx); continue; } MatchOptions options(MatchFlag::OnParam); if (candidate.fromLiteral) options |= MatchFlag::Literal; if (favorExactMatchesOnly) options |= MatchFlag::ExactOnly; // Disable CGFloat -> Double conversion for unary operators. // // Some of the unary operators, i.e. prefix `-`, don't have // CGFloat variants and expect generic `FloatingPoint` overload // to match CGFloat type. Let's not attempt `CGFloat` -> `Double` // conversion for unary operators because it always leads // to a worse solutions vs. generic overloads. if (n == 1 && decl->isOperator()) options |= MatchFlag::DisableCGFloatDoubleConversion; // Disable implicit CGFloat -> Double widening conversion if // argument is an explicit call to `CGFloat` initializer. 
if (candidate.type->isCGFloat() && candidate.fromInitializerCall) options |= MatchFlag::DisableCGFloatDoubleConversion; if (isExpr( argumentList->getExpr(argIdx) ->getSemanticsProvidingExpr())) options |= MatchFlag::StringInterpolation; // The specifier for a candidate only matters for `inout` check. auto candidateScore = scoreCandidateMatch(genericSig, decl, paramIdx, candidate.type->getWithoutSpecifierType(), paramType, options); if (!candidateScore) continue; if (candidateScore > 0) { bestCandidateScore = std::max(bestCandidateScore, candidateScore.value()); continue; } if (!candidate.type->hasTypeVariable()) mismatches.set(candidateIdx); } // If none of the candidates for this parameter matched, let's // drop this overload from any further consideration. if (mismatches.all()) return; score += bestCandidateScore; } // An overload whether all of the parameters are defaulted // that's called without arguments. if (numDefaulted == overloadType->getNumParams()) return; // Average the score to avoid disfavoring disjunctions with fewer // parameters. score /= (overloadType->getNumParams() - numDefaulted); // If one of the result types matches exactly, that's a good // indication that overload choice should be favored. // // It's only safe to match result types of operators // because regular functions/methods/subscripts and // especially initializers could end up with a lot of // favored overloads because on the result type alone. if (canUseContextualResultTypes && (score > 0 || !hasArgumentCandidates)) { if (llvm::any_of(resultTypes, [&](const Type candidateResultTy) { return scoreCandidateMatch(genericSig, decl, /*paramIdx=*/std::nullopt, overloadType->getResult(), candidateResultTy, /*options=*/{}) > 0; })) { score += 1.0; } } if (score > 0) { // Nudge the score slightly to prefer concrete homogeneous // arithmetic operators. 
// // This is an opportunistic optimization based on the operator // use patterns where homogeneous operators are the most // heavily used ones. if (isArithmeticOperator(decl) && overloadType->getNumParams() == 2) { auto resultTy = overloadType->getResult(); if (!resultTy->hasTypeParameter() && llvm::all_of(overloadType->getParams(), [&resultTy](const auto ¶m) { return param.getPlainType()->isEqual(resultTy); })) score += 0.01; } favoredChoices.push_back({choice, score}); bestScore = std::max(bestScore, score); } }); if (cs.isDebugMode()) { PrintOptions PO; PO.PrintTypesForDebugging = true; llvm::errs().indent(cs.solverState->getCurrentIndent()) << "<<< Disjunction " << disjunction->getNestedConstraints()[0]->getFirstType()->getString( PO) << " with score " << bestScore << "\n"; } bestOverallScore = std::max(bestOverallScore, bestScore); // Determine if the score and favoring decisions here are // based only on "speculative" sources i.e. inference from // literals. // // This information is going to be used by the disjunction // selection algorithm to prevent over-eager selection of // the operators over unsupported non-operator declarations. 
bool isSpeculative = onlySpeculativeArgumentCandidates && (!canUseContextualResultTypes || !anyNonSpeculativeResultTypes(resultTypes)); DisjunctionInfoBuilder info(/*score=*/bestScore); info.setSpeculative(isSpeculative); for (const auto &choice : favoredChoices) { if (choice.second == bestScore) info.addFavoredChoice(choice.first); } recordResult(disjunction, info.build()); } if (cs.isDebugMode() && bestOverallScore > 0) { PrintOptions PO; PO.PrintTypesForDebugging = true; auto getLogger = [&](unsigned extraIndent = 0) -> llvm::raw_ostream & { return llvm::errs().indent(cs.solverState->getCurrentIndent() + extraIndent); }; { auto &log = getLogger(); log << "(Optimizing disjunctions: ["; interleave( disjunctions, [&](const auto *disjunction) { log << disjunction->getNestedConstraints()[0] ->getFirstType() ->getString(PO); }, [&]() { log << ", "; }); log << "]\n"; } getLogger(/*extraIndent=*/4) << "Best overall score = " << bestOverallScore << '\n'; for (auto *disjunction : disjunctions) { auto &entry = result[disjunction]; getLogger(/*extraIndent=*/4) << "[Disjunction '" << disjunction->getNestedConstraints()[0]->getFirstType()->getString( PO) << "' with score = " << entry.Score.value_or(0) << '\n'; for (const auto *choice : entry.FavoredChoices) { auto &log = getLogger(/*extraIndent=*/6); log << "- "; choice->print(log, &cs.getASTContext().SourceMgr); log << '\n'; } getLogger(/*extraIdent=*/4) << "]\n"; } getLogger() << ")\n"; } } static std::optional isPreferable(ConstraintSystem &cs, Constraint *disjunctionA, Constraint *disjunctionB) { // Consider only operator vs. non-operator situations. if (isOperatorDisjunction(disjunctionA) == isOperatorDisjunction(disjunctionB)) return std::nullopt; // Prevent operator selection if its passed as an argument // to not-yet resolved call. 
This helps to make sure that // in result builder context chained members and other // non-operator disjunctions are always selected first, // because they provide the context and help to prune the system. if (isInResultBuilderContext(cs, disjunctionA)) { if (isOperatorDisjunction(disjunctionA)) { if (isOperatorPassedToUnresolvedCall(cs, disjunctionA)) return false; } else { if (isOperatorPassedToUnresolvedCall(cs, disjunctionB)) return true; } } return std::nullopt; } std::optional>> ConstraintSystem::selectDisjunction() { if (performanceHacksEnabled()) { if (auto *disjunction = selectDisjunctionWithHacks()) return std::make_pair(disjunction, llvm::TinyPtrVector()); return std::nullopt; } SmallVector disjunctions; collectDisjunctions(disjunctions); if (disjunctions.empty()) return std::nullopt; llvm::DenseMap favorings; determineBestChoicesInContext(*this, disjunctions, favorings); // Pick the disjunction with the smallest number of favored, then active // choices. auto bestDisjunction = std::min_element( disjunctions.begin(), disjunctions.end(), [&](Constraint *first, Constraint *second) -> bool { unsigned firstActive = first->countActiveNestedConstraints(); unsigned secondActive = second->countActiveNestedConstraints(); if (firstActive == 1 || secondActive == 1) return secondActive != 1; if (auto preference = isPreferable(*this, first, second)) return preference.value(); auto &[firstScore, firstFavoredChoices, isFirstSpeculative] = favorings[first]; auto &[secondScore, secondFavoredChoices, isSecondSpeculative] = favorings[second]; bool isFirstOperator = isOperatorDisjunction(first); bool isSecondOperator = isOperatorDisjunction(second); // Infix logical operators are usually not overloaded and don't // form disjunctions, but when they do, let's prefer them over // other operators when they have fewer choices because it helps // to split operator chains. 
if (isFirstOperator && isSecondOperator) { if (isStandardInfixLogicalOperator(first) != isStandardInfixLogicalOperator(second)) return firstActive < secondActive; } // Not all of the non-operator disjunctions are supported by the // ranking algorithm, so to prevent eager selection of operators // when nothing concrete is known about them, let's reset the score // and compare purely based on number of choices. if (isFirstOperator != isSecondOperator) { if (isFirstOperator && isFirstSpeculative) firstScore = 0; if (isSecondOperator && isSecondSpeculative) secondScore = 0; } // Rank based on scores only if both disjunctions are supported. if (firstScore && secondScore) { // If both disjunctions have the same score they should be ranked // based on number of favored/active choices. if (*firstScore != *secondScore) return *firstScore > *secondScore; // If the scores are the same and both disjunctions are operators // they could be ranked purely based on whether the candidates // were speculative or not. The one with more context always wins. // // Consider the following situation: // // func test(_: Int) { ... } // func test(_: String) { ... } // // test("a" + "b" + "c") // // In this case we should always prefer ... + "c" over "a" + "b" // because it would fail and prune the other overload if parameter // type (aka contextual type) is `Int`. if (isFirstOperator && isSecondOperator && isFirstSpeculative != isSecondSpeculative) return isSecondSpeculative; } // Use favored choices only if disjunction score is higher // than zero. This means that we can maintain favoring // choices without impacting selection decisions. unsigned numFirstFavored = firstScore.value_or(0) ? firstFavoredChoices.size() : 0; unsigned numSecondFavored = secondScore.value_or(0) ? secondFavoredChoices.size() : 0; if (numFirstFavored == numSecondFavored) { if (firstActive != secondActive) return firstActive < secondActive; } numFirstFavored = numFirstFavored ? 
numFirstFavored : firstActive; numSecondFavored = numSecondFavored ? numSecondFavored : secondActive; return numFirstFavored < numSecondFavored; }); if (bestDisjunction != disjunctions.end()) return std::make_pair(*bestDisjunction, favorings[*bestDisjunction].FavoredChoices); return std::nullopt; }