Sema: Split up BindingSet::finalize() into finalizeKeyPathBindings() and finalizeUnresolvedMemberChain()

This commit is contained in:
Slava Pestov
2025-11-13 20:16:12 -05:00
parent 02129f2b1f
commit 45d03e152c
4 changed files with 27 additions and 16 deletions

View File

@@ -560,12 +560,14 @@ public:
/// literal protocols associated with this type variable.
void determineLiteralCoverage();
/// Finalize binding computation for this type variable by
/// inferring bindings from context e.g. transitive bindings.
/// Finalize binding computation for key path type variables.
///
/// \returns true if finalization successful (which makes binding set viable),
/// and false otherwise.
bool finalize();
bool finalizeKeyPathBindings();
/// Handle diagnostics of unresolved member chains.
void finalizeUnresolvedMemberChainResult();
static BindingScore formBindingScore(const BindingSet &b);

View File

@@ -699,13 +699,11 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
return keyPathTy;
}
bool BindingSet::finalize() {
bool BindingSet::finalizeKeyPathBindings() {
if (auto *locator = TypeVar->getImpl().getLocator()) {
if (TypeVar->getImpl().isKeyPathType()) {
auto &ctx = CS.getASTContext();
auto *keyPathLoc = TypeVar->getImpl().getLocator();
auto *keyPath = castToExpr<KeyPathExpr>(keyPathLoc->getAnchor());
auto *keyPath = castToExpr<KeyPathExpr>(locator->getAnchor());
bool isValid;
std::optional<KeyPathCapability> capability;
@@ -772,7 +770,7 @@ bool BindingSet::finalize() {
auto keyPathTy = getKeyPathType(ctx, *capability, rootTy,
CS.getKeyPathValueType(keyPath));
updatedBindings.insert(
{keyPathTy, AllowedBindingKind::Exact, keyPathLoc});
{keyPathTy, AllowedBindingKind::Exact, locator});
} else if (CS.shouldAttemptFixes()) {
auto fixedRootTy = CS.getFixedType(rootTy);
// If key path is structurally correct and has a resolved root
@@ -799,10 +797,14 @@ bool BindingSet::finalize() {
Bindings = std::move(updatedBindings);
Defaults.clear();
return true;
}
}
return true;
}
void BindingSet::finalizeUnresolvedMemberChainResult() {
if (auto *locator = TypeVar->getImpl().getLocator()) {
if (CS.shouldAttemptFixes() &&
locator->isLastElement<LocatorPathElt::UnresolvedMemberChainResult>()) {
// Let's see whether this chain is valid, if it isn't then to avoid
@@ -825,8 +827,6 @@ bool BindingSet::finalize() {
}
}
}
return true;
}
void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
@@ -1194,9 +1194,10 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
bindings.inferTransitiveBindings();
if (!bindings.finalize())
if (!bindings.finalizeKeyPathBindings())
continue;
bindings.finalizeUnresolvedMemberChainResult();
bindings.determineLiteralCoverage();
if (!bindings.hasViableBindings() && !bindings.isDirectHole())
@@ -1595,7 +1596,9 @@ BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) {
assert(!typeVar->getImpl().getFixedType(nullptr) && "has a fixed type");
BindingSet bindings(*this, typeVar, CG[typeVar].getPotentialBindings());
bindings.finalize();
(void) bindings.finalizeKeyPathBindings();
bindings.finalizeUnresolvedMemberChainResult();
bindings.determineLiteralCoverage();
return bindings;

View File

@@ -126,7 +126,8 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) {
cs.getConstraintGraph()[floatLiteralTy].initBindingSet();
bindings.inferTransitiveBindings();
bindings.finalize();
(void) bindings.finalizeKeyPathBindings();
bindings.finalizeUnresolvedMemberChainResult();
bindings.determineLiteralCoverage();
// Inferred a single transitive binding through `$T_float`.

View File

@@ -140,9 +140,14 @@ BindingSet SemaTest::inferBindings(ConstraintSystem &cs,
continue;
auto &bindings = node.getBindingSet();
// FIXME: This is also called in inferTransitiveBindings(), why do we need
// to call it again?
bindings.inferTransitiveProtocolRequirements();
bindings.inferTransitiveBindings();
bindings.finalize();
(void) bindings.finalizeKeyPathBindings();
bindings.finalizeUnresolvedMemberChainResult();
bindings.determineLiteralCoverage();
}