Merge pull request #85967 from slavapestov/dont-copy-that-floppy

Sema: Avoid copying BindingSets
This commit is contained in:
Slava Pestov
2025-12-12 00:18:47 -05:00
committed by GitHub
11 changed files with 28 additions and 33 deletions

View File

@@ -385,6 +385,10 @@ public:
BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar, BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar,
const PotentialBindings &info); const PotentialBindings &info);
BindingSet(BindingSet &&other) = default;
BindingSet(const BindingSet &other) = delete;
ConstraintSystem &getConstraintSystem() const { return CS; } ConstraintSystem &getConstraintSystem() const { return CS; }
TypeVariableType *getTypeVariable() const { return TypeVar; } TypeVariableType *getTypeVariable() const { return TypeVar; }
@@ -568,8 +572,11 @@ public:
/// requirements down the subtype or equivalence chain. /// requirements down the subtype or equivalence chain.
void inferTransitiveProtocolRequirements(); void inferTransitiveProtocolRequirements();
/// Check whether the given binding set covers any of the /// Check whether the given binding set covers any of the literal protocols
/// literal protocols associated with this type variable. /// associated with this type variable. The idea is that if a type variable
/// has a binding like Int and also it has a conformance requirement to
/// ExpressibleByIntegerLitral, we can avoid attempting the default type of
/// that literal literal if we already attempted Int.
void determineLiteralCoverage(); void determineLiteralCoverage();
/// Finalize binding computation for key path type variables. /// Finalize binding computation for key path type variables.

View File

@@ -95,10 +95,6 @@ Type typeCheckParameterDefault(Expr *&, DeclContext *, Type, bool, bool);
} // end namespace swift } // end namespace swift
/// Allocate memory within the given constraint system.
void *operator new(size_t bytes, swift::constraints::ConstraintSystem& cs,
size_t alignment = 8);
namespace swift { namespace swift {
/// Specify how we handle the binding of underconstrained (free) type variables /// Specify how we handle the binding of underconstrained (free) type variables
@@ -5270,7 +5266,7 @@ public:
/// Determine whether given type variable with its set of bindings is viable /// Determine whether given type variable with its set of bindings is viable
/// to be attempted on the next step of the solver. /// to be attempted on the next step of the solver.
std::optional<BindingSet> determineBestBindings( const BindingSet *determineBestBindings(
llvm::function_ref<void(const BindingSet &)> onCandidate); llvm::function_ref<void(const BindingSet &)> onCandidate);
/// Get bindings for the given type variable based on current /// Get bindings for the given type variable based on current
@@ -6200,7 +6196,7 @@ class TypeVarBindingProducer : public BindingProducer<TypeVariableBinding> {
public: public:
using Element = TypeVariableBinding; using Element = TypeVariableBinding;
TypeVarBindingProducer(BindingSet &bindings); TypeVarBindingProducer(const BindingSet &bindings);
/// Retrieve a set of bindings available in the current state. /// Retrieve a set of bindings available in the current state.
ArrayRef<Binding> getCurrentBindings() const { return Bindings; } ArrayRef<Binding> getCurrentBindings() const { return Bindings; }

View File

@@ -1180,7 +1180,7 @@ bool BindingSet::operator<(const BindingSet &other) {
return isPotentiallyIncomplete() < other.isPotentiallyIncomplete(); return isPotentiallyIncomplete() < other.isPotentiallyIncomplete();
} }
std::optional<BindingSet> ConstraintSystem::determineBestBindings( const BindingSet *ConstraintSystem::determineBestBindings(
llvm::function_ref<void(const BindingSet &)> onCandidate) { llvm::function_ref<void(const BindingSet &)> onCandidate) {
// Look for potential type variable bindings. // Look for potential type variable bindings.
BindingSet *bestBindings = nullptr; BindingSet *bestBindings = nullptr;
@@ -1238,10 +1238,7 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
bestBindings = &bindings; bestBindings = &bindings;
} }
if (!bestBindings) return bestBindings;
return std::nullopt;
return std::optional(*bestBindings);
} }
/// Find the set of type variables that are inferable from the given type. /// Find the set of type variables that are inferable from the given type.

View File

@@ -263,7 +263,7 @@ StepResult ComponentStep::take(bool prevFailed) {
SmallString<64> potentialBindings; SmallString<64> potentialBindings;
llvm::raw_svector_ostream bos(potentialBindings); llvm::raw_svector_ostream bos(potentialBindings);
auto bestBindings = CS.determineBestBindings([&](const BindingSet &bindings) { const auto *bestBindings = CS.determineBestBindings([&](const BindingSet &bindings) {
if (CS.isDebugMode() && bindings.hasViableBindings()) { if (CS.isDebugMode() && bindings.hasViableBindings()) {
bos.indent(CS.solverState->getCurrentIndent() + 2); bos.indent(CS.solverState->getCurrentIndent() + 2);
bos << "("; bos << "(";

View File

@@ -540,7 +540,7 @@ class TypeVariableStep final : public BindingStep<TypeVarBindingProducer> {
bool SawFirstLiteralConstraint = false; bool SawFirstLiteralConstraint = false;
public: public:
TypeVariableStep(BindingContainer &bindings, TypeVariableStep(const BindingContainer &bindings,
SmallVectorImpl<Solution> &solutions) SmallVectorImpl<Solution> &solutions)
: BindingStep(bindings.getConstraintSystem(), {bindings}, solutions), : BindingStep(bindings.getConstraintSystem(), {bindings}, solutions),
TypeVar(bindings.getTypeVariable()) {} TypeVar(bindings.getTypeVariable()) {}

View File

@@ -1141,7 +1141,7 @@ Constraint::getTrailingClosureMatching() const {
void *Constraint::operator new(size_t bytes, ConstraintSystem& cs, void *Constraint::operator new(size_t bytes, ConstraintSystem& cs,
size_t alignment) { size_t alignment) {
return ::operator new (bytes, cs, alignment); return cs.getAllocator().Allocate(bytes, alignment);
} }
// FIXME: Perhaps we should store the Constraint -> PreparedOverload mapping // FIXME: Perhaps we should store the Constraint -> PreparedOverload mapping

View File

@@ -5323,7 +5323,7 @@ ConstraintSystem::inferKeyPathLiteralCapability(KeyPathExpr *keyPath) {
return success(mutability, isSendable); return success(mutability, isSendable);
} }
TypeVarBindingProducer::TypeVarBindingProducer(BindingSet &bindings) TypeVarBindingProducer::TypeVarBindingProducer(const BindingSet &bindings)
: BindingProducer(bindings.getConstraintSystem(), : BindingProducer(bindings.getConstraintSystem(),
bindings.getTypeVariable()->getImpl().getLocator()), bindings.getTypeVariable()->getImpl().getLocator()),
TypeVar(bindings.getTypeVariable()), CanBeNil(bindings.canBeNil()) { TypeVar(bindings.getTypeVariable()), CanBeNil(bindings.canBeNil()) {

View File

@@ -214,11 +214,6 @@ bool TypeVariableType::Implementation::isTernary() const {
return locator && locator->directlyAt<TernaryExpr>(); return locator && locator->directlyAt<TernaryExpr>();
} }
void *operator new(size_t bytes, ConstraintSystem& cs,
size_t alignment) {
return cs.getAllocator().Allocate(bytes, alignment);
}
bool constraints::computeTupleShuffle(TupleType *fromTuple, bool constraints::computeTupleShuffle(TupleType *fromTuple,
TupleType *toTuple, TupleType *toTuple,
SmallVectorImpl<unsigned> &sources) { SmallVectorImpl<unsigned> &sources) {

View File

@@ -196,7 +196,7 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) {
cs.getConstraintLocator({}, LocatorPathElt::ContextualType( cs.getConstraintLocator({}, LocatorPathElt::ContextualType(
CTP_Initialization))); CTP_Initialization)));
auto bindings = inferBindings(cs, typeVar); auto &bindings = inferBindings(cs, typeVar);
ASSERT_TRUE(bindings.getConformanceRequirements().empty()); ASSERT_TRUE(bindings.getConformanceRequirements().empty());
ASSERT_TRUE(bool(bindings.TransitiveProtocols)); ASSERT_TRUE(bool(bindings.TransitiveProtocols));
verifyProtocolInferenceResults(*bindings.TransitiveProtocols, verifyProtocolInferenceResults(*bindings.TransitiveProtocols,
@@ -218,7 +218,7 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) {
cs.addConstraint(ConstraintKind::Conversion, typeVar, GPT1, cs.addConstraint(ConstraintKind::Conversion, typeVar, GPT1,
cs.getConstraintLocator({})); cs.getConstraintLocator({}));
auto bindings = inferBindings(cs, typeVar); auto &bindings = inferBindings(cs, typeVar);
ASSERT_TRUE(bindings.getConformanceRequirements().empty()); ASSERT_TRUE(bindings.getConformanceRequirements().empty());
ASSERT_TRUE(bool(bindings.TransitiveProtocols)); ASSERT_TRUE(bool(bindings.TransitiveProtocols));
verifyProtocolInferenceResults(*bindings.TransitiveProtocols, verifyProtocolInferenceResults(*bindings.TransitiveProtocols,
@@ -281,10 +281,10 @@ TEST_F(SemaTest, TestComplexTransitiveProtocolInference) {
cs.addConstraint(ConstraintKind::Equal, typeVar1, typeVar5, nilLocator); cs.addConstraint(ConstraintKind::Equal, typeVar1, typeVar5, nilLocator);
cs.addConstraint(ConstraintKind::Conversion, typeVar5, typeVar6, nilLocator); cs.addConstraint(ConstraintKind::Conversion, typeVar5, typeVar6, nilLocator);
auto bindingsForT1 = inferBindings(cs, typeVar1); auto &bindingsForT1 = inferBindings(cs, typeVar1);
auto bindingsForT2 = inferBindings(cs, typeVar2); auto &bindingsForT2 = inferBindings(cs, typeVar2);
auto bindingsForT3 = inferBindings(cs, typeVar3); auto &bindingsForT3 = inferBindings(cs, typeVar3);
auto bindingsForT5 = inferBindings(cs, typeVar5); auto &bindingsForT5 = inferBindings(cs, typeVar5);
ASSERT_TRUE(bool(bindingsForT1.TransitiveProtocols)); ASSERT_TRUE(bool(bindingsForT1.TransitiveProtocols));
verifyProtocolInferenceResults(*bindingsForT1.TransitiveProtocols, verifyProtocolInferenceResults(*bindingsForT1.TransitiveProtocols,
@@ -335,7 +335,7 @@ TEST_F(SemaTest, TestTransitiveProtocolInferenceThroughEquivalenceChains) {
cs.addConstraint(ConstraintKind::ConformsTo, typeVar2, protocolTy0, nilLocator); cs.addConstraint(ConstraintKind::ConformsTo, typeVar2, protocolTy0, nilLocator);
cs.addConstraint(ConstraintKind::ConformsTo, typeVar3, protocolTy1, nilLocator); cs.addConstraint(ConstraintKind::ConformsTo, typeVar3, protocolTy1, nilLocator);
auto bindings = inferBindings(cs, typeVar0); auto &bindings = inferBindings(cs, typeVar0);
ASSERT_TRUE(bool(bindings.TransitiveProtocols)); ASSERT_TRUE(bool(bindings.TransitiveProtocols));
verifyProtocolInferenceResults(*bindings.TransitiveProtocols, verifyProtocolInferenceResults(*bindings.TransitiveProtocols,

View File

@@ -124,8 +124,8 @@ ProtocolType *SemaTest::createProtocol(llvm::StringRef protocolName,
return ProtocolType::get(PD, parent, Context); return ProtocolType::get(PD, parent, Context);
} }
BindingSet SemaTest::inferBindings(ConstraintSystem &cs, const BindingSet &SemaTest::inferBindings(ConstraintSystem &cs,
TypeVariableType *typeVar) { TypeVariableType *typeVar) {
for (auto *typeVar : cs.getTypeVariables()) { for (auto *typeVar : cs.getTypeVariables()) {
auto &node = cs.getConstraintGraph()[typeVar]; auto &node = cs.getConstraintGraph()[typeVar];
node.resetBindingSet(); node.resetBindingSet();

View File

@@ -80,8 +80,8 @@ protected:
ProtocolType *createProtocol(llvm::StringRef protocolName, ProtocolType *createProtocol(llvm::StringRef protocolName,
Type parent = Type()); Type parent = Type());
static BindingSet inferBindings(ConstraintSystem &cs, static const BindingSet &inferBindings(ConstraintSystem &cs,
TypeVariableType *typeVar); TypeVariableType *typeVar);
}; };
} // end namespace unittest } // end namespace unittest