[CSOptimizer] Allow generic operator overloads without associated type parameters

This commit is contained in:
Pavel Yaskevich
2023-02-13 13:14:17 -08:00
parent 7c1c46d4e4
commit bc5f70a9a3
2 changed files with 59 additions and 59 deletions

View File

@@ -14,6 +14,8 @@
//
//===----------------------------------------------------------------------===//
#include "TypeChecker.h"
#include "swift/AST/GenericSignature.h"
#include "swift/Sema/ConstraintGraph.h"
#include "swift/Sema/ConstraintSystem.h"
#include "llvm/ADT/BitVector.h"
@@ -86,34 +88,6 @@ void forEachDisjunctionChoice(
}
}
static bool isSIMDType(Type type) {
auto *NTD = dyn_cast_or_null<StructDecl>(type->getAnyNominal());
if (!NTD)
return false;
auto typeName = NTD->getName().str();
if (!typeName.startswith("SIMD"))
return false;
return NTD->getParentModule()->getName().is("Swift");
}
static bool isArithmeticOperatorOnSIMDProtocol(ValueDecl *decl) {
if (!isSIMDOperator(decl))
return false;
if (!decl->getBaseIdentifier().isArithmeticOperator())
return false;
auto *DC = decl->getDeclContext();
if (auto *P = DC->getSelfProtocolDecl()) {
if (auto knownKind = P->getKnownProtocolKind())
return *knownKind == KnownProtocolKind::SIMD;
}
return false;
}
} // end anonymous namespace
/// Given a set of disjunctions, attempt to determine
@@ -181,23 +155,6 @@ static void determineBestChoicesInContext(
resultTypes.push_back(resultType);
}
auto isViableOverload = [&](ValueDecl *decl) {
// Allow standard arithmetic operator overloads on SIMD protocol
// to be considered because we can favor them when then argument
// is a known SIMD<N> type.
if (isArithmeticOperatorOnSIMDProtocol(decl))
return true;
// Don't consider generic overloads because we need conformance
// checking functionality to determine best favoring, preferring
// such overloads based only on concrete types leads to subpar
// choices due to missed information.
if (decl->getInterfaceType()->is<GenericFunctionType>())
return false;
return true;
};
// The choice with the best score.
double bestScore = 0.0;
SmallVector<std::pair<Constraint *, double>, 2> favoredChoices;
@@ -205,8 +162,21 @@ static void determineBestChoicesInContext(
forEachDisjunctionChoice(
cs, disjunction,
[&](Constraint *choice, ValueDecl *decl, FunctionType *overloadType) {
if (!isViableOverload(decl))
return;
GenericSignature genericSig;
{
if (auto *GF = dyn_cast<AbstractFunctionDecl>(decl)) {
genericSig = GF->getGenericSignature();
} else if (auto *SD = dyn_cast<SubscriptDecl>(decl)) {
genericSig = SD->getGenericSignature();
}
// Let's not consider non-operator generic overloads because we
// need conformance checking functionality to determine best
// favoring, preferring such overloads based on concrete types
// alone leads to subpar choices due to missed information.
if (genericSig && !decl->isOperator())
return;
}
ParameterListInfo paramListInfo(
overloadType->getParams(), decl,
@@ -249,6 +219,23 @@ static void determineBestChoicesInContext(
if (candidateArgumentTypes[i].empty())
continue;
// Check protocol requirement(s) if this parameter is a
// generic parameter type.
GenericSignature::RequiredProtocols protocolRequirements;
if (genericSig) {
if (auto *GP = paramType->getAs<GenericTypeParamType>()) {
protocolRequirements = genericSig->getRequiredProtocols(GP);
// It's a generic parameter which might be connected via
// same-type constraints to other generic parameters but
// we cannot check that here, so let's ignore it.
if (protocolRequirements.empty())
continue;
}
if (paramType->getAs<DependentMemberType>())
return;
}
// The idea here is to match the parameter type against
// all of the argument candidate types and pick the best
// match (i.e. exact equality one).
@@ -281,22 +268,36 @@ static void determineBestChoicesInContext(
// The specifier only matters for `inout` check.
candidateType = candidateType->getWithoutSpecifierType();
// Exact match on one of the candidate bindings.
if (candidateType->isEqual(paramType)) {
// We don't check generic requirements against literal default
// types because it creates more noise than signal for operators.
if (!protocolRequirements.empty() && !isLiteralDefault) {
if (llvm::all_of(
protocolRequirements, [&](ProtocolDecl *protocol) {
return TypeChecker::conformsToProtocol(
candidateType, protocol, cs.DC->getParentModule(),
/*allowMissing=*/false);
})) {
// Score is lower here because we still prefer concrete
// overloads over the generic ones when possible.
bestCandidateScore = std::max(bestCandidateScore, 0.7);
continue;
}
} else if (paramType->hasTypeParameter()) {
// i.e. Array<Element> or Optional<Wrapped> as a parameter.
// This is slightly better than all of the conformances matching
// because the parameter is concrete and could split the graph.
if (paramType->getAnyNominal() == candidateType->getAnyNominal()) {
bestCandidateScore = std::max(bestCandidateScore, 0.8);
continue;
}
} else if (candidateType->isEqual(paramType)) {
// Exact match on one of the candidate bindings.
bestCandidateScore =
std::max(bestCandidateScore, isLiteralDefault ? 0.3 : 1.0);
continue;
}
// If argument is SIMD<N> type i.e. SIMD1<...> it's appropriate
// to favor of the overloads that are declared on SIMD protocol
// and expect a particular `Scalar` if it's known.
if (isSIMDType(candidateType) &&
isArithmeticOperatorOnSIMDProtocol(decl)) {
bestCandidateScore = 1.0;
continue;
}
// Only established arguments could be considered mismatches,
// literal default types should be regarded as holes if they
// didn't match.

View File

@@ -5,7 +5,6 @@ let i: Int? = 1
let j: Int?
let k: Int? = 2
// expected-error@+1 {{the compiler is unable to type-check this expression in reasonable time}}
let _ = [i, j, k].reduce(0 as Int?) {
$0 != nil && $1 != nil ? $0! + $1! : ($0 != nil ? $0! : ($1 != nil ? $1! : nil))
}