[CSOptimizer] Account for the fact that sometimes all initializer choices are failable

If all of the viable initializer overloads are failable,
the only valid inference choice is an optional candidate type.
This commit is contained in:
Pavel Yaskevich
2025-03-11 14:07:00 -07:00
parent 9e97b8e3a1
commit db0a9de996
2 changed files with 99 additions and 38 deletions

View File

@@ -344,6 +344,43 @@ static bool isSupportedDisjunction(Constraint *disjunction) {
});
}
/// Determine whether the given overload choice constitutes a
/// valid choice that would be attempted during normal solving
/// without any score increases.
static ValueDecl *isViableOverloadChoice(ConstraintSystem &cs,
Constraint *constraint,
ConstraintLocator *locator) {
if (constraint->isDisabled())
return nullptr;
if (constraint->getKind() != ConstraintKind::BindOverload)
return nullptr;
auto choice = constraint->getOverloadChoice();
auto *decl = choice.getDeclOrNull();
if (!decl)
return nullptr;
// Ignore declarations that come from implicitly imported modules
// when `MemberImportVisibility` feature is enabled otherwise
// we might end up favoring an overload that would be diagnosed
// as unavailable later.
if (cs.getASTContext().LangOpts.hasFeature(Feature::MemberImportVisibility)) {
if (auto *useDC = constraint->getOverloadUseDC()) {
if (!useDC->isDeclImported(decl))
return nullptr;
}
}
// If disjunction choice is unavailable or disfavored we cannot
// do anything with it.
if (decl->getAttrs().hasAttribute<DisfavoredOverloadAttr>() ||
cs.isDeclUnavailable(decl, locator))
return nullptr;
return decl;
}
/// Given the type variable that represents a result type of a
/// function call, check whether that call is to an initializer
/// and based on that deduce possible type for the result.
@@ -389,16 +426,30 @@ inferTypeFromInitializerResultType(ConstraintSystem &cs,
if (initRef == disjunctions.end())
return {};
bool hasFailable =
llvm::any_of((*initRef)->getNestedConstraints(), [](Constraint *choice) {
if (choice->isDisabled())
return false;
auto *decl =
dyn_cast_or_null<ConstructorDecl>(getOverloadChoiceDecl(choice));
return decl && decl->isFailable();
});
unsigned numFailable = 0;
unsigned total = 0;
for (auto *choice : (*initRef)->getNestedConstraints()) {
auto *decl = isViableOverloadChoice(cs, choice, ctorLocator);
if (!decl || !isa<ConstructorDecl>(decl))
continue;
return {instanceTy, hasFailable};
auto *ctor = cast<ConstructorDecl>(decl);
if (ctor->isFailable())
++numFailable;
++total;
}
if (numFailable > 0) {
// If all of the active choices are failable, produce an optional
// type only.
if (numFailable == total)
return {instanceTy->wrapInOptionalType(), /*hasFailable=*/false};
// Otherwise there are two options.
return {instanceTy, /*hasFailable*/ true};
}
return {instanceTy, /*hasFailable=*/false};
}
/// If the given expression represents a chain of operators that only have
@@ -502,38 +553,14 @@ void forEachDisjunctionChoice(
llvm::function_ref<void(Constraint *, ValueDecl *decl, FunctionType *)>
callback) {
for (auto constraint : disjunction->getNestedConstraints()) {
if (constraint->isDisabled())
continue;
if (constraint->getKind() != ConstraintKind::BindOverload)
continue;
auto choice = constraint->getOverloadChoice();
auto *decl = choice.getDeclOrNull();
auto *decl =
isViableOverloadChoice(cs, constraint, disjunction->getLocator());
if (!decl)
continue;
// Ignore declarations that come from implicitly imported modules
// when `MemberImportVisibility` feature is enabled otherwise
// we might end up favoring an overload that would be diagnosed
// as unavailable later.
if (cs.getASTContext().LangOpts.hasFeature(
Feature::MemberImportVisibility)) {
if (auto *useDC = constraint->getDeclContext()) {
if (!useDC->isDeclImported(decl))
continue;
}
}
// If disjunction choice is unavailable or disfavored we cannot
// do anything with it.
if (decl->getAttrs().hasAttribute<DisfavoredOverloadAttr>() ||
cs.isDeclUnavailable(decl, disjunction->getLocator()))
continue;
Type overloadType =
cs.getEffectiveOverloadType(disjunction->getLocator(), choice,
/*allowMembers=*/true, cs.DC);
Type overloadType = cs.getEffectiveOverloadType(
disjunction->getLocator(), constraint->getOverloadChoice(),
/*allowMembers=*/true, constraint->getOverloadUseDC());
if (!overloadType || !overloadType->is<FunctionType>())
continue;

View File

@@ -326,3 +326,37 @@ struct TestUnary {
}
}
}
// Prevent non-optional overload of `??` to be favored when all initializers are failable.
class A {}
class B {}
protocol P {
init()
}
extension P {
init?(v: A) { self.init() }
}
struct V : P {
init() {}
@_disfavoredOverload
init?(v: B?) {}
// Important to keep this to make sure that disabled constraints
// are handled properly.
init<T: Collection>(other: T) where T.Element == Character {}
}
class TestFailableOnly {
var v: V?
func test(defaultB: B) {
guard let _ = self.v ?? V(v: defaultB) else { // OK (no warnings)
return
}
}
}