Merge pull request #85513 from slavapestov/refactor-csbindings

Refactor BindingSet a little bit
This commit is contained in:
Slava Pestov
2025-11-17 17:52:16 -05:00
committed by GitHub
6 changed files with 119 additions and 108 deletions

View File

@@ -469,10 +469,6 @@ public:
/// checking.
bool isViable(PotentialBinding &binding, bool isTransitive);
explicit operator bool() const {
return hasViableBindings() || isDirectHole();
}
/// Determine whether this set has any "viable" (or non-hole) bindings.
///
/// A viable binding could be - a direct or transitive binding
@@ -486,6 +482,12 @@ public:
!Defaults.empty();
}
/// Determine whether this set can be chosen as the next binding set
/// to attempt.
bool isViable() const {
return hasViableBindings() || isDirectHole();
}
ArrayRef<Constraint *> getConformanceRequirements() const {
return Protocols;
}
@@ -544,6 +546,8 @@ public:
/// Check if this binding is favored over a conjunction.
bool favoredOverConjunction(Constraint *conjunction) const;
void inferTransitiveKeyPathBindings();
/// Detect `subtype` relationship between two type variables and
/// attempt to infer supertype bindings transitively e.g.
///
@@ -553,19 +557,27 @@ public:
///
/// \param inferredBindings The set of all bindings inferred for type
/// variables in the workset.
void inferTransitiveBindings();
void inferTransitiveSupertypeBindings();
void inferTransitiveUnresolvedMemberRefBindings();
/// Detect subtype, conversion or equivalence relationship
/// between two type variables and attempt to propagate protocol
/// requirements down the subtype or equivalence chain.
void inferTransitiveProtocolRequirements();
/// Finalize binding computation for this type variable by
/// inferring bindings from context e.g. transitive bindings.
/// Check whether the given binding set covers any of the
/// literal protocols associated with this type variable.
void determineLiteralCoverage();
/// Finalize binding computation for key path type variables.
///
/// \returns true if finalization successful (which makes binding set viable),
/// and false otherwise.
bool finalize(bool transitive);
bool finalizeKeyPathBindings();
/// Handle diagnostics of unresolved member chains.
void finalizeUnresolvedMemberChainResult();
static BindingScore formBindingScore(const BindingSet &b);
@@ -590,10 +602,6 @@ private:
void addDefault(Constraint *constraint);
/// Check whether the given binding set covers any of the
/// literal protocols associated with this type variable.
void determineLiteralCoverage();
StringRef getLiteralBindingKind(LiteralBindingKind K) const {
#define ENTRY(Kind, String) \
case LiteralBindingKind::Kind: \

View File

@@ -103,8 +103,7 @@ bool BindingSet::isDirectHole() const {
if (!CS.shouldAttemptFixes())
return false;
return Bindings.empty() && getNumViableLiteralBindings() == 0 &&
Defaults.empty() && TypeVar->getImpl().canBindToHole();
return !hasViableBindings() && TypeVar->getImpl().canBindToHole();
}
static bool isGenericParameter(TypeVariableType *TypeVar) {
@@ -494,9 +493,7 @@ void BindingSet::inferTransitiveProtocolRequirements() {
} while (!workList.empty());
}
void BindingSet::inferTransitiveBindings() {
using BindingKind = AllowedBindingKind;
void BindingSet::inferTransitiveKeyPathBindings() {
// If the current type variable represents a key path root type
// let's try to transitively infer its type through bindings of
// a key path type.
@@ -551,7 +548,7 @@ void BindingSet::inferTransitiveBindings() {
}
} else {
addBinding(
binding.withSameSource(inferredRootTy, BindingKind::Exact),
binding.withSameSource(inferredRootTy, AllowedBindingKind::Exact),
/*isTransitive=*/true);
}
}
@@ -559,7 +556,9 @@ void BindingSet::inferTransitiveBindings() {
}
}
}
}
void BindingSet::inferTransitiveSupertypeBindings() {
for (const auto &entry : Info.SupertypeOf) {
auto &node = CS.getConstraintGraph()[entry.first];
if (!node.hasBindingSet())
@@ -609,8 +608,8 @@ void BindingSet::inferTransitiveBindings() {
// either be Exact or Supertypes in order for it to make sense
// to add Supertype bindings based on the relationship between
// our type variables.
if (binding.Kind != BindingKind::Exact &&
binding.Kind != BindingKind::Supertypes)
if (binding.Kind != AllowedBindingKind::Exact &&
binding.Kind != AllowedBindingKind::Supertypes)
continue;
auto type = binding.BindingType;
@@ -621,12 +620,49 @@ void BindingSet::inferTransitiveBindings() {
if (ConstraintSystem::typeVarOccursInType(TypeVar, type))
continue;
addBinding(binding.withSameSource(type, BindingKind::Supertypes),
addBinding(binding.withSameSource(type, AllowedBindingKind::Supertypes),
/*isTransitive=*/true);
}
}
}
void BindingSet::inferTransitiveUnresolvedMemberRefBindings() {
if (!hasViableBindings()) {
if (auto *locator = TypeVar->getImpl().getLocator()) {
if (locator->isLastElement<LocatorPathElt::MemberRefBase>()) {
// If this is a base of an unresolved member chain, as a last
// resort effort let's infer base to be a protocol type based
// on contextual conformance requirements.
//
// This allows us to find solutions in cases like this:
//
// \code
// func foo<T: P>(_: T) {}
// foo(.bar) <- `.bar` should be a static member of `P`.
// \endcode
inferTransitiveProtocolRequirements();
if (TransitiveProtocols.has_value()) {
for (auto *constraint : *TransitiveProtocols) {
Type protocolTy = constraint->getSecondType();
// Compiler-known marker protocols cannot be extended with members,
// so do not consider them.
if (auto p = protocolTy->getAs<ProtocolType>()) {
if (ProtocolDecl *decl = p->getDecl())
if (decl->getKnownProtocolKind() && decl->isMarkerProtocol())
continue;
}
addBinding({protocolTy, AllowedBindingKind::Exact, constraint},
/*isTransitive=*/false);
}
}
}
}
}
}
static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
Type rootType, Type valueType) {
KeyPathMutability mutability;
@@ -664,51 +700,11 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
return keyPathTy;
}
bool BindingSet::finalize(bool transitive) {
if (transitive)
inferTransitiveBindings();
determineLiteralCoverage();
bool BindingSet::finalizeKeyPathBindings() {
if (auto *locator = TypeVar->getImpl().getLocator()) {
if (locator->isLastElement<LocatorPathElt::MemberRefBase>()) {
// If this is a base of an unresolved member chain, as a last
// resort effort let's infer base to be a protocol type based
// on contextual conformance requirements.
//
// This allows us to find solutions in cases like this:
//
// \code
// func foo<T: P>(_: T) {}
// foo(.bar) <- `.bar` should be a static member of `P`.
// \endcode
if (transitive && !hasViableBindings()) {
inferTransitiveProtocolRequirements();
if (TransitiveProtocols.has_value()) {
for (auto *constraint : *TransitiveProtocols) {
Type protocolTy = constraint->getSecondType();
// Compiler-known marker protocols cannot be extended with members,
// so do not consider them.
if (auto p = protocolTy->getAs<ProtocolType>()) {
if (ProtocolDecl *decl = p->getDecl())
if (decl->getKnownProtocolKind() && decl->isMarkerProtocol())
continue;
}
addBinding({protocolTy, AllowedBindingKind::Exact, constraint},
/*isTransitive=*/false);
}
}
}
}
if (TypeVar->getImpl().isKeyPathType()) {
auto &ctx = CS.getASTContext();
auto *keyPathLoc = TypeVar->getImpl().getLocator();
auto *keyPath = castToExpr<KeyPathExpr>(keyPathLoc->getAnchor());
auto *keyPath = castToExpr<KeyPathExpr>(locator->getAnchor());
bool isValid;
std::optional<KeyPathCapability> capability;
@@ -775,7 +771,7 @@ bool BindingSet::finalize(bool transitive) {
auto keyPathTy = getKeyPathType(ctx, *capability, rootTy,
CS.getKeyPathValueType(keyPath));
updatedBindings.insert(
{keyPathTy, AllowedBindingKind::Exact, keyPathLoc});
{keyPathTy, AllowedBindingKind::Exact, locator});
} else if (CS.shouldAttemptFixes()) {
auto fixedRootTy = CS.getFixedType(rootTy);
// If key path is structurally correct and has a resolved root
@@ -802,10 +798,14 @@ bool BindingSet::finalize(bool transitive) {
Bindings = std::move(updatedBindings);
Defaults.clear();
return true;
}
}
return true;
}
void BindingSet::finalizeUnresolvedMemberChainResult() {
if (auto *locator = TypeVar->getImpl().getLocator()) {
if (CS.shouldAttemptFixes() &&
locator->isLastElement<LocatorPathElt::UnresolvedMemberChainResult>()) {
// Let's see whether this chain is valid, if it isn't then to avoid
@@ -828,8 +828,6 @@ bool BindingSet::finalize(bool transitive) {
}
}
}
return true;
}
void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
@@ -1143,37 +1141,6 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
node.initBindingSet();
}
// Determine whether given type variable with its set of bindings is
// viable to be attempted on the next step of the solver. If type variable
// has no "direct" bindings of any kind e.g. direct bindings to concrete
// types, default types from "defaultable" constraints or literal
// conformances, such type variable is not viable to be evaluated to be
// attempted next.
auto isViableForRanking = [this](const BindingSet &bindings) -> bool {
auto *typeVar = bindings.getTypeVariable();
// Key path root type variable is always viable because it can be
// transitively inferred from key path type during binding set
// finalization.
if (typeVar->getImpl().isKeyPathRoot())
return true;
// Type variable representing a base of unresolved member chain should
// always be considered viable for ranking since it's allow to infer
// types from transitive protocol requirements.
if (auto *locator = typeVar->getImpl().getLocator()) {
if (locator->isLastElement<LocatorPathElt::MemberRefBase>())
return true;
}
// If type variable is marked as a potential hole there is always going
// to be at least one binding available for it.
if (shouldAttemptFixes() && typeVar->getImpl().canBindToHole())
return true;
return bool(bindings);
};
// Now let's see if we could infer something for related type
// variables based on other bindings.
for (auto *typeVar : getTypeVariables()) {
@@ -1183,6 +1150,16 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
auto &bindings = node.getBindingSet();
// Special handling for key paths.
bindings.inferTransitiveKeyPathBindings();
if (!bindings.finalizeKeyPathBindings())
continue;
// Special handling for "leading-dot" unresolved member references,
// like .foo.
bindings.inferTransitiveUnresolvedMemberRefBindings();
bindings.finalizeUnresolvedMemberChainResult();
// Before attempting to infer transitive bindings let's check
// whether there are any viable "direct" bindings associated with
// current type variable, if there are none - it means that this type
@@ -1193,12 +1170,12 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
// associated with given type variable, any default constraints,
// or any conformance requirements to literal protocols with can
// produce a default type.
bool isViable = isViableForRanking(bindings);
bool isViable = bindings.isViable();
if (!bindings.finalize(true))
continue;
bindings.inferTransitiveSupertypeBindings();
bindings.determineLiteralCoverage();
if (!bindings || !isViable)
if (!isViable)
continue;
onCandidate(bindings);
@@ -1591,7 +1568,10 @@ BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) {
assert(!typeVar->getImpl().getFixedType(nullptr) && "has a fixed type");
BindingSet bindings(*this, typeVar, CG[typeVar].getPotentialBindings());
bindings.finalize(false);
(void) bindings.finalizeKeyPathBindings();
bindings.finalizeUnresolvedMemberChainResult();
bindings.determineLiteralCoverage();
return bindings;
}

View File

@@ -1105,7 +1105,9 @@ static void determineBestChoicesInContext(
// Simply adding it as a binding won't work because if the second argument
// is non-optional the overload that returns `T?` would still have a lower
// score.
if (!bindingSet && isNilCoalescingOperator(disjunction)) {
if (!bindingSet.hasViableBindings() &&
!bindingSet.isDirectHole() &&
isNilCoalescingOperator(disjunction)) {
auto &cg = cs.getConstraintGraph();
if (llvm::any_of(cg[typeVar].getConstraints(),
[&typeVar](Constraint *constraint) {

View File

@@ -921,7 +921,8 @@ bool ConstraintGraph::contractEdges() {
// us enough information to decided on l-valueness.
if (tyvar1->getImpl().canBindToInOut()) {
bool isNotContractable = true;
if (auto bindings = CS.getBindingsFor(tyvar1)) {
auto bindings = CS.getBindingsFor(tyvar1);
if (bindings.isViable()) {
// Holes can't be contracted.
if (bindings.isHole())
continue;

View File

@@ -125,7 +125,15 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) {
cs.getConstraintGraph()[floatLiteralTy].initBindingSet();
bindings.finalize(/*transitive=*/true);
bindings.inferTransitiveKeyPathBindings();
(void) bindings.finalizeKeyPathBindings();
bindings.inferTransitiveUnresolvedMemberRefBindings();
bindings.finalizeUnresolvedMemberChainResult();
bindings.inferTransitiveSupertypeBindings();
bindings.determineLiteralCoverage();
// Inferred a single transitive binding through `$T_float`.
ASSERT_EQ(bindings.Bindings.size(), (unsigned)1);

View File

@@ -140,8 +140,20 @@ BindingSet SemaTest::inferBindings(ConstraintSystem &cs,
continue;
auto &bindings = node.getBindingSet();
// FIXME: This is also called in inferTransitiveUnresolvedMemberRefBindings(),
// why do we need to call it here too?
bindings.inferTransitiveProtocolRequirements();
bindings.finalize(/*transitive=*/true);
bindings.inferTransitiveKeyPathBindings();
(void) bindings.finalizeKeyPathBindings();
bindings.inferTransitiveUnresolvedMemberRefBindings();
bindings.finalizeUnresolvedMemberChainResult();
bindings.inferTransitiveSupertypeBindings();
bindings.determineLiteralCoverage();
}
auto &node = cs.getConstraintGraph()[typeVar];