[CSOptimizer] Favor SIMD related arithmetic operator choices if argument is SIMD<N> type

This commit is contained in:
Pavel Yaskevich
2023-02-10 10:50:03 -08:00
parent 672ae3d252
commit c2f7451c7b

View File

@@ -85,6 +85,34 @@ void forEachDisjunctionChoice(
}
}
static bool isSIMDType(Type type) {
auto *NTD = dyn_cast_or_null<StructDecl>(type->getAnyNominal());
if (!NTD)
return false;
auto typeName = NTD->getName().str();
if (!typeName.startswith("SIMD"))
return false;
return NTD->getParentModule()->getName().is("Swift");
}
static bool isArithmeticOperatorOnSIMDProtocol(ValueDecl *decl) {
if (!isSIMDOperator(decl))
return false;
if (!decl->getBaseIdentifier().isArithmeticOperator())
return false;
auto *DC = decl->getDeclContext();
if (auto *P = DC->getSelfProtocolDecl()) {
if (auto knownKind = P->getKnownProtocolKind())
return *knownKind == KnownProtocolKind::SIMD;
}
return false;
}
} // end anonymous namespace
/// Given a set of disjunctions, attempt to determine
@@ -152,6 +180,23 @@ static void determineBestChoicesInContext(
resultTypes.push_back(resultType);
}
auto isViableOverload = [&](ValueDecl *decl) {
// Allow standard arithmetic operator overloads on SIMD protocol
// to be considered because we can favor them when then argument
// is a known SIMD<N> type.
if (isArithmeticOperatorOnSIMDProtocol(decl))
return true;
// Don't consider generic overloads because we need conformance
// checking functionality to determine best favoring, preferring
// such overloads based only on concrete types leads to subpar
// choices due to missed information.
if (decl->getInterfaceType()->is<GenericFunctionType>())
return false;
return true;
};
// The choice with the best score.
double bestScore = 0.0;
SmallVector<std::pair<Constraint *, double>, 2> favoredChoices;
@@ -159,11 +204,7 @@ static void determineBestChoicesInContext(
forEachDisjunctionChoice(
cs, disjunction,
[&](Constraint *choice, ValueDecl *decl, FunctionType *overloadType) {
// Don't consider generic overloads because we need conformance
// checking functionality to determine best favoring, preferring
// such overloads based only on concrete types leads to subpar
// choices due to missed information.
if (decl->getInterfaceType()->is<GenericFunctionType>())
if (!isViableOverload(decl))
return;
if (overloadType->getNumParams() != argFuncType->getNumParams())
@@ -204,6 +245,16 @@ static void determineBestChoicesInContext(
if (candidateType->isEqual(paramType)) {
argScore = std::max(
argScore, /*fromLiteral=*/candidate.second ? 0.3 : 1.0);
continue;
}
// If argument is SIMD<N> type i.e. SIMD1<...> it's appropriate
// to favor of the overloads that are declared on SIMD protocol
// and expect a particular `Scalar` if it's known.
if (isSIMDType(candidateType) &&
isArithmeticOperatorOnSIMDProtocol(decl)) {
argScore = std::max(argScore, 1.0);
continue;
}
}