Merge pull request #78267 from slavapestov/incremental-bindings-part-0

Sema: Some small cleanups in CSBindings.cpp and related code
This commit is contained in:
Slava Pestov
2025-01-16 09:03:32 -05:00
committed by GitHub
15 changed files with 327 additions and 222 deletions

View File

@@ -913,13 +913,6 @@ namespace swift {
/// is for testing purposes.
std::vector<std::string> DebugForbidTypecheckPrefixes;
/// The upper bound to number of sub-expressions unsolved
/// before termination of the shrink phrase of the constraint solver.
unsigned SolverShrinkUnsolvedThreshold = 10;
/// Disable the shrink phase of the expression type checker.
bool SolverDisableShrink = false;
/// Enable experimental operator designated types feature.
bool EnableOperatorDesignatedTypes = false;
@@ -935,6 +928,9 @@ namespace swift {
/// Allow request evalutation to perform type checking lazily, instead of
/// eagerly typechecking source files after parsing.
bool EnableLazyTypecheck = false;
/// Disable the component splitter phase of the expression type checker.
bool SolverDisableSplitter = false;
};
/// Options for controlling the behavior of the Clang importer.

View File

@@ -825,15 +825,17 @@ def downgrade_typecheck_interface_error : Flag<["-"], "downgrade-typecheck-inter
def enable_volatile_modules : Flag<["-"], "enable-volatile-modules">,
HelpText<"Load Swift modules in memory">;
def solver_expression_time_threshold_EQ : Joined<["-"], "solver-expression-time-threshold=">;
def solver_expression_time_threshold_EQ : Joined<["-"], "solver-expression-time-threshold=">,
HelpText<"Expression type checking timeout, in seconds">;
def solver_scope_threshold_EQ : Joined<["-"], "solver-scope-threshold=">;
def solver_scope_threshold_EQ : Joined<["-"], "solver-scope-threshold=">,
HelpText<"Expression type checking scope limit">;
def solver_trail_threshold_EQ : Joined<["-"], "solver-trail-threshold=">;
def solver_trail_threshold_EQ : Joined<["-"], "solver-trail-threshold=">,
HelpText<"Expression type checking trail change limit">;
def solver_disable_shrink :
Flag<["-"], "solver-disable-shrink">,
HelpText<"Disable the shrink phase of expression type checking">;
def solver_disable_splitter : Flag<["-"], "solver-disable-splitter">,
HelpText<"Disable the component splitter phase of expression type checking">;
def disable_constraint_solver_performance_hacks : Flag<["-"], "disable-constraint-solver-performance-hacks">,
HelpText<"Disable all the hacks in the constraint solver">;

View File

@@ -301,6 +301,11 @@ public:
Constraint *constraint);
void reset();
void dump(ConstraintSystem &CS,
TypeVariableType *TypeVar,
llvm::raw_ostream &out,
unsigned indent) const;
};
@@ -567,64 +572,27 @@ public:
///
/// \param inferredBindings The set of all bindings inferred for type
/// variables in the workset.
void inferTransitiveBindings(
const llvm::SmallDenseMap<TypeVariableType *, BindingSet>
&inferredBindings);
void inferTransitiveBindings();
/// Detect subtype, conversion or equivalence relationship
/// between two type variables and attempt to propagate protocol
/// requirements down the subtype or equivalence chain.
void inferTransitiveProtocolRequirements(
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings);
void inferTransitiveProtocolRequirements();
/// Finalize binding computation for this type variable by
/// inferring bindings from context e.g. transitive bindings.
///
/// \returns true if finalization successful (which makes binding set viable),
/// and false otherwise.
bool finalize(
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings);
bool finalize(bool transitive);
static BindingScore formBindingScore(const BindingSet &b);
/// Compare two sets of bindings, where \c x < y indicates that
/// \c x is a better set of bindings that \c y.
friend bool operator<(const BindingSet &x, const BindingSet &y) {
auto xScore = formBindingScore(x);
auto yScore = formBindingScore(y);
bool operator==(const BindingSet &other);
if (xScore < yScore)
return true;
if (yScore < xScore)
return false;
auto xDefaults = x.getNumViableDefaultableBindings();
auto yDefaults = y.getNumViableDefaultableBindings();
// If there is a difference in number of default types,
// prioritize bindings with fewer of them.
if (xDefaults != yDefaults)
return xDefaults < yDefaults;
// If neither type variable is a "hole" let's check whether
// there is a subtype relationship between them and prefer
// type variable which represents superclass first in order
// for "subtype" type variable to attempt more bindings later.
// This is required because algorithm can't currently infer
// bindings for subtype transitively through superclass ones.
if (!(std::get<0>(xScore) && std::get<0>(yScore))) {
if (x.Info.isSubtypeOf(y.getTypeVariable()))
return false;
if (y.Info.isSubtypeOf(x.getTypeVariable()))
return true;
}
// As a last resort, let's check if the bindings are
// potentially incomplete, and if so, let's de-prioritize them.
return x.isPotentiallyIncomplete() < y.isPotentiallyIncomplete();
}
/// Compare two sets of bindings, where \c this < other indicates that
/// \c this is a better set of bindings that \c other.
bool operator<(const BindingSet &other);
void dump(llvm::raw_ostream &out, unsigned indent) const;

View File

@@ -84,9 +84,24 @@ public:
/// as this type variable.
ArrayRef<TypeVariableType *> getEquivalenceClass() const;
inference::PotentialBindings &getCurrentBindings() {
assert(forRepresentativeVar());
return Bindings;
inference::PotentialBindings &getPotentialBindings() {
DEBUG_ASSERT(forRepresentativeVar());
return Potential;
}
void initBindingSet();
inference::BindingSet &getBindingSet() {
ASSERT(hasBindingSet());
return *Set;
}
bool hasBindingSet() const {
return Set.has_value();
}
void resetBindingSet() {
Set.reset();
}
private:
@@ -131,15 +146,6 @@ private:
/// Binding Inference {
/// Infer bindings from the given constraint and notify referenced variables
/// about its arrival (if requested). This happens every time a new constraint
/// gets added to a constraint graph node.
void introduceToInference(Constraint *constraint);
/// Forget about the given constraint. This happens every time a constraint
/// gets removed for a constraint graph.
void retractFromInference(Constraint *constraint);
/// Perform graph updates that must be undone after we bind a fixed type
/// to a type variable.
void retractFromInference(Type fixedType);
@@ -182,8 +188,13 @@ private:
/// The type variable this node represents.
TypeVariableType *TypeVar;
/// The set of bindings associated with this type variable.
inference::PotentialBindings Bindings;
/// The potential bindings for this type variable, updated incrementally by
/// the constraint graph.
inference::PotentialBindings Potential;
/// The binding set for this type variable, computed by
/// determineBestBindings().
std::optional<inference::BindingSet> Set;
/// The vector of constraints that mention this type variable, in a stable
/// order for iteration.

View File

@@ -5170,7 +5170,9 @@ public:
/// Get bindings for the given type variable based on current
/// state of the constraint system.
BindingSet getBindingsFor(TypeVariableType *typeVar, bool finalize = true);
///
/// FIXME: Remove this.
BindingSet getBindingsFor(TypeVariableType *typeVar);
private:
/// Add a constraint to the constraint system.

View File

@@ -1772,8 +1772,6 @@ static bool ParseTypeCheckerArgs(TypeCheckerOptions &Opts, ArgList &Args,
Opts.SolverScopeThreshold);
setUnsignedIntegerArgument(OPT_solver_trail_threshold_EQ,
Opts.SolverTrailThreshold);
setUnsignedIntegerArgument(OPT_solver_shrink_unsolved_threshold,
Opts.SolverShrinkUnsolvedThreshold);
Opts.DebugTimeFunctionBodies |= Args.hasArg(OPT_debug_time_function_bodies);
Opts.DebugTimeExpressions |=
@@ -1862,8 +1860,8 @@ static bool ParseTypeCheckerArgs(TypeCheckerOptions &Opts, ArgList &Args,
Opts.DebugForbidTypecheckPrefixes.push_back(A);
}
if (Args.getLastArg(OPT_solver_disable_shrink))
Opts.SolverDisableShrink = true;
if (Args.getLastArg(OPT_solver_disable_splitter))
Opts.SolverDisableSplitter = true;
if (FrontendOpts.RequestedAction == FrontendOptions::ActionType::Immediate)
Opts.DeferToRuntime = true;

View File

@@ -31,6 +31,13 @@ using namespace swift;
using namespace constraints;
using namespace inference;
void ConstraintGraphNode::initBindingSet() {
ASSERT(!hasBindingSet());
ASSERT(forRepresentativeVar());
Set.emplace(CG.getConstraintSystem(), TypeVar, Potential);
}
/// Check whether there exists a type that could be implicitly converted
/// to a given type i.e. is the given type is Double or Optional<..> this
/// function is going to return true because CGFloat could be converted
@@ -278,8 +285,7 @@ bool BindingSet::isPotentiallyIncomplete() const {
return false;
}
void BindingSet::inferTransitiveProtocolRequirements(
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
void BindingSet::inferTransitiveProtocolRequirements() {
if (TransitiveProtocols)
return;
@@ -314,13 +320,13 @@ void BindingSet::inferTransitiveProtocolRequirements(
do {
auto *currentVar = workList.back().second;
auto cachedBindings = inferredBindings.find(currentVar);
if (cachedBindings == inferredBindings.end()) {
auto &node = CS.getConstraintGraph()[currentVar];
if (!node.hasBindingSet()) {
workList.pop_back();
continue;
}
auto &bindings = cachedBindings->getSecond();
auto &bindings = node.getBindingSet();
// If current variable already has transitive protocol
// conformances inferred, there is no need to look deeper
@@ -352,11 +358,11 @@ void BindingSet::inferTransitiveProtocolRequirements(
if (!equivalenceClass.insert(typeVar))
continue;
auto bindingSet = inferredBindings.find(typeVar);
if (bindingSet == inferredBindings.end())
auto &node = CS.getConstraintGraph()[typeVar];
if (!node.hasBindingSet())
continue;
auto &equivalences = bindingSet->getSecond().Info.EquivalentTo;
auto &equivalences = node.getBindingSet().Info.EquivalentTo;
for (const auto &eqVar : equivalences) {
workList.push_back(eqVar.first);
}
@@ -367,11 +373,11 @@ void BindingSet::inferTransitiveProtocolRequirements(
if (memberVar == currentVar)
continue;
auto eqBindings = inferredBindings.find(memberVar);
if (eqBindings == inferredBindings.end())
auto &node = CS.getConstraintGraph()[memberVar];
if (!node.hasBindingSet())
continue;
const auto &bindings = eqBindings->getSecond();
const auto &bindings = node.getBindingSet();
llvm::SmallPtrSet<Constraint *, 2> placeholder;
// Add any direct protocols from members of the
@@ -423,9 +429,9 @@ void BindingSet::inferTransitiveProtocolRequirements(
// Propagate inferred protocols to all of the members of the
// equivalence class.
for (const auto &equivalence : bindings.Info.EquivalentTo) {
auto eqBindings = inferredBindings.find(equivalence.first);
if (eqBindings != inferredBindings.end()) {
auto &bindings = eqBindings->getSecond();
auto &node = CS.getConstraintGraph()[equivalence.first];
if (node.hasBindingSet()) {
auto &bindings = node.getBindingSet();
bindings.TransitiveProtocols.emplace(protocolsForEquivalence.begin(),
protocolsForEquivalence.end());
}
@@ -438,9 +444,7 @@ void BindingSet::inferTransitiveProtocolRequirements(
} while (!workList.empty());
}
void BindingSet::inferTransitiveBindings(
const llvm::SmallDenseMap<TypeVariableType *, BindingSet>
&inferredBindings) {
void BindingSet::inferTransitiveBindings() {
using BindingKind = AllowedBindingKind;
// If the current type variable represents a key path root type
@@ -450,9 +454,9 @@ void BindingSet::inferTransitiveBindings(
auto *locator = TypeVar->getImpl().getLocator();
if (auto *keyPathTy =
CS.getType(locator->getAnchor())->getAs<TypeVariableType>()) {
auto keyPathBindings = inferredBindings.find(keyPathTy);
if (keyPathBindings != inferredBindings.end()) {
auto &bindings = keyPathBindings->getSecond();
auto &node = CS.getConstraintGraph()[keyPathTy];
if (node.hasBindingSet()) {
auto &bindings = node.getBindingSet();
for (auto &binding : bindings.Bindings) {
auto bindingTy = binding.BindingType->lookThroughAllOptionalTypes();
@@ -476,9 +480,9 @@ void BindingSet::inferTransitiveBindings(
// transitively used because conversions between generic arguments
// are not allowed.
if (auto *contextualRootVar = inferredRootTy->getAs<TypeVariableType>()) {
auto rootBindings = inferredBindings.find(contextualRootVar);
if (rootBindings != inferredBindings.end()) {
auto &bindings = rootBindings->getSecond();
auto &node = CS.getConstraintGraph()[contextualRootVar];
if (node.hasBindingSet()) {
auto &bindings = node.getBindingSet();
// Don't infer if root is not yet fully resolved.
if (bindings.isDelayed())
@@ -507,11 +511,11 @@ void BindingSet::inferTransitiveBindings(
}
for (const auto &entry : Info.SupertypeOf) {
auto relatedBindings = inferredBindings.find(entry.first);
if (relatedBindings == inferredBindings.end())
auto &node = CS.getConstraintGraph()[entry.first];
if (!node.hasBindingSet())
continue;
auto &bindings = relatedBindings->getSecond();
auto &bindings = node.getBindingSet();
// FIXME: This is a workaround necessary because solver doesn't filter
// bindings based on protocol requirements placed on a type variable.
@@ -610,9 +614,9 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
return keyPathTy;
}
bool BindingSet::finalize(
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
inferTransitiveBindings(inferredBindings);
bool BindingSet::finalize(bool transitive) {
if (transitive)
inferTransitiveBindings();
determineLiteralCoverage();
@@ -628,8 +632,8 @@ bool BindingSet::finalize(
// func foo<T: P>(_: T) {}
// foo(.bar) <- `.bar` should be a static member of `P`.
// \endcode
if (!hasViableBindings()) {
inferTransitiveProtocolRequirements(inferredBindings);
if (transitive && !hasViableBindings()) {
inferTransitiveProtocolRequirements();
if (TransitiveProtocols.has_value()) {
for (auto *constraint : *TransitiveProtocols) {
@@ -956,6 +960,56 @@ void BindingSet::addLiteralRequirement(Constraint *constraint) {
Literals.insert({protocol, std::move(literal)});
}
bool BindingSet::operator==(const BindingSet &other) {
if (AdjacentVars != other.AdjacentVars)
return false;
if (Bindings.size() != other.Bindings.size())
return false;
for (auto i : indices(Bindings)) {
const auto &x = Bindings[i];
const auto &y = other.Bindings[i];
if (x.BindingType.getPointer() != y.BindingType.getPointer() ||
x.Kind != y.Kind)
return false;
}
if (Literals.size() != other.Literals.size())
return false;
for (auto pair : Literals) {
auto found = other.Literals.find(pair.first);
if (found == other.Literals.end())
return false;
const auto &x = pair.second;
const auto &y = found->second;
if (x.Source != y.Source ||
x.DefaultType.getPointer() != y.DefaultType.getPointer() ||
x.IsDirectRequirement != y.IsDirectRequirement) {
return false;
}
}
if (Defaults.size() != other.Defaults.size())
return false;
for (auto pair : Defaults) {
auto found = other.Defaults.find(pair.first);
if (found == other.Defaults.end() ||
pair.second != found->second)
return false;
}
if (TransitiveProtocols != other.TransitiveProtocols)
return false;
return true;
}
BindingSet::BindingScore BindingSet::formBindingScore(const BindingSet &b) {
// If there are no bindings available but this type
// variable represents a closure - let's consider it
@@ -976,17 +1030,54 @@ BindingSet::BindingScore BindingSet::formBindingScore(const BindingSet &b) {
-numNonDefaultableBindings);
}
bool BindingSet::operator<(const BindingSet &other) {
auto xScore = formBindingScore(*this);
auto yScore = formBindingScore(other);
if (xScore < yScore)
return true;
if (yScore < xScore)
return false;
auto xDefaults = getNumViableDefaultableBindings();
auto yDefaults = other.getNumViableDefaultableBindings();
// If there is a difference in number of default types,
// prioritize bindings with fewer of them.
if (xDefaults != yDefaults)
return xDefaults < yDefaults;
// If neither type variable is a "hole" let's check whether
// there is a subtype relationship between them and prefer
// type variable which represents superclass first in order
// for "subtype" type variable to attempt more bindings later.
// This is required because algorithm can't currently infer
// bindings for subtype transitively through superclass ones.
if (!(std::get<0>(xScore) && std::get<0>(yScore))) {
if (Info.isSubtypeOf(other.getTypeVariable()))
return false;
if (other.Info.isSubtypeOf(getTypeVariable()))
return true;
}
// As a last resort, let's check if the bindings are
// potentially incomplete, and if so, let's de-prioritize them.
return isPotentiallyIncomplete() < other.isPotentiallyIncomplete();
}
std::optional<BindingSet> ConstraintSystem::determineBestBindings(
llvm::function_ref<void(const BindingSet &)> onCandidate) {
// Look for potential type variable bindings.
std::optional<BindingSet> bestBindings;
llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
BindingSet *bestBindings = nullptr;
// First, let's collect all of the possible bindings.
for (auto *typeVar : getTypeVariables()) {
if (!typeVar->getImpl().hasRepresentativeOrFixed()) {
cache.insert({typeVar, getBindingsFor(typeVar, /*finalize=*/false)});
}
auto &node = CG[typeVar];
node.resetBindingSet();
if (!typeVar->getImpl().hasRepresentativeOrFixed())
node.initBindingSet();
}
// Determine whether given type variable with its set of bindings is
@@ -1023,11 +1114,12 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
// Now let's see if we could infer something for related type
// variables based on other bindings.
for (auto *typeVar : getTypeVariables()) {
auto cachedBindings = cache.find(typeVar);
if (cachedBindings == cache.end())
auto &node = CG[typeVar];
if (!node.hasBindingSet())
continue;
auto &bindings = cachedBindings->getSecond();
auto &bindings = node.getBindingSet();
// Before attempting to infer transitive bindings let's check
// whether there are any viable "direct" bindings associated with
// current type variable, if there are none - it means that this type
@@ -1040,7 +1132,7 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
// produce a default type.
bool isViable = isViableForRanking(bindings);
if (!bindings.finalize(cache))
if (!bindings.finalize(true))
continue;
if (!bindings || !isViable)
@@ -1051,10 +1143,13 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
// If these are the first bindings, or they are better than what
// we saw before, use them instead.
if (!bestBindings || bindings < *bestBindings)
bestBindings.emplace(bindings);
bestBindings = &bindings;
}
return bestBindings;
if (!bestBindings)
return std::nullopt;
return std::optional(*bestBindings);
}
/// Find the set of type variables that are inferable from the given type.
@@ -1435,18 +1530,13 @@ bool BindingSet::favoredOverConjunction(Constraint *conjunction) const {
return true;
}
BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar,
bool finalize) {
BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) {
assert(typeVar->getImpl().getRepresentative(nullptr) == typeVar &&
"not a representative");
assert(!typeVar->getImpl().getFixedType(nullptr) && "has a fixed type");
BindingSet bindings(*this, typeVar, CG[typeVar].getCurrentBindings());
if (finalize) {
llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
bindings.finalize(cache);
}
BindingSet bindings(*this, typeVar, CG[typeVar].getPotentialBindings());
bindings.finalize(false);
return bindings;
}
@@ -2042,6 +2132,48 @@ void PotentialBindings::reset() {
AssociatedCodeCompletionToken = ASTNode();
}
void PotentialBindings::dump(ConstraintSystem &cs,
TypeVariableType *typeVar,
llvm::raw_ostream &out,
unsigned indent) const {
PrintOptions PO;
PO.PrintTypesForDebugging = true;
out << "Potential bindings for ";
typeVar->getImpl().print(out);
out << "\n";
out << "[constraints: ";
interleave(Constraints,
[&](Constraint *constraint) {
constraint->print(out, &cs.getASTContext().SourceMgr, indent,
/*skipLocator=*/true);
},
[&out]() { out << ", "; });
out << "] ";
if (!AdjacentVars.empty()) {
out << "[adjacent to: ";
SmallVector<std::pair<TypeVariableType *, Constraint *>> adjacentVars(
AdjacentVars.begin(), AdjacentVars.end());
llvm::sort(adjacentVars,
[](auto lhs, auto rhs) {
return lhs.first->getID() < rhs.first->getID();
});
interleave(adjacentVars,
[&](std::pair<TypeVariableType *, Constraint *> pair) {
out << pair.first->getString(PO);
if (pair.first->getImpl().getFixedType(/*record=*/nullptr))
out << " (fixed)";
out << " via ";
pair.second->print(out, &cs.getASTContext().SourceMgr, indent,
/*skipLocator=*/true);
},
[&out]() { out << ", "; });
out << "] ";
}
}
void BindingSet::forEachLiteralRequirement(
llvm::function_ref<void(KnownProtocolKind)> callback) const {
for (const auto &literal : Literals) {
@@ -2181,22 +2313,21 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const {
if (!attributes.empty())
out << "] ";
if (involvesTypeVariables()) {
if (!AdjacentVars.empty()) {
out << "[adjacent to: ";
if (AdjacentVars.empty()) {
out << "<none>";
} else {
SmallVector<TypeVariableType *> adjacentVars(AdjacentVars.begin(),
AdjacentVars.end());
llvm::sort(adjacentVars,
[](const TypeVariableType *lhs, const TypeVariableType *rhs) {
SmallVector<TypeVariableType *> adjacentVars(AdjacentVars.begin(),
AdjacentVars.end());
llvm::sort(adjacentVars,
[](const TypeVariableType *lhs, const TypeVariableType *rhs) {
return lhs->getID() < rhs->getID();
});
interleave(
adjacentVars,
[&](const auto *typeVar) { out << typeVar->getString(PO); },
[&out]() { out << ", "; });
}
});
interleave(adjacentVars,
[&](auto *typeVar) {
out << typeVar->getString(PO);
if (typeVar->getImpl().getFixedType(/*record=*/nullptr))
out << " (fixed)";
},
[&out]() { out << ", "; });
out << "] ";
}
@@ -2209,24 +2340,25 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const {
enum class BindingKind { Exact, Subtypes, Supertypes, Literal };
BindingKind Kind;
Type BindingType;
PrintableBinding(BindingKind kind, Type bindingType)
: Kind(kind), BindingType(bindingType) {}
bool Viable;
PrintableBinding(BindingKind kind, Type bindingType, bool viable)
: Kind(kind), BindingType(bindingType), Viable(viable) {}
public:
static PrintableBinding supertypesOf(Type binding) {
return PrintableBinding{BindingKind::Supertypes, binding};
return PrintableBinding{BindingKind::Supertypes, binding, true};
}
static PrintableBinding subtypesOf(Type binding) {
return PrintableBinding{BindingKind::Subtypes, binding};
return PrintableBinding{BindingKind::Subtypes, binding, true};
}
static PrintableBinding exact(Type binding) {
return PrintableBinding{BindingKind::Exact, binding};
return PrintableBinding{BindingKind::Exact, binding, true};
}
static PrintableBinding literalDefaultType(Type binding) {
return PrintableBinding{BindingKind::Literal, binding};
static PrintableBinding literalDefaultType(Type binding, bool viable) {
return PrintableBinding{BindingKind::Literal, binding, viable};
}
void print(llvm::raw_ostream &out, const PrintOptions &PO,
@@ -2244,7 +2376,10 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const {
out << "(default type of literal) ";
break;
}
BindingType.print(out, PO);
if (BindingType)
BindingType.print(out, PO);
if (!Viable)
out << " [literal not viable]";
}
};
@@ -2266,10 +2401,11 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const {
}
}
for (const auto &literal : Literals) {
if (literal.second.viableAsBinding()) {
potentialBindings.push_back(PrintableBinding::literalDefaultType(
literal.second.getDefaultType()));
}
potentialBindings.push_back(PrintableBinding::literalDefaultType(
literal.second.hasDefaultType()
? literal.second.getDefaultType()
: Type(),
literal.second.viableAsBinding()));
}
if (potentialBindings.empty()) {
out << "<none>";

View File

@@ -382,7 +382,7 @@ static void determineBestChoicesInContext(
SmallVector<std::pair<Type, bool>, 2> types;
if (auto *typeVar = argType->getAs<TypeVariableType>()) {
auto bindingSet = cs.getBindingsFor(typeVar, /*finalize=*/true);
auto bindingSet = cs.getBindingsFor(typeVar);
for (const auto &binding : bindingSet.Bindings) {
types.push_back({binding.BindingType, /*fromLiteral=*/false});
@@ -421,7 +421,7 @@ static void determineBestChoicesInContext(
auto resultType = cs.simplifyType(argFuncType->getResult());
if (auto *typeVar = resultType->getAs<TypeVariableType>()) {
auto bindingSet = cs.getBindingsFor(typeVar, /*finalize=*/true);
auto bindingSet = cs.getBindingsFor(typeVar);
for (const auto &binding : bindingSet.Bindings) {
resultTypes.push_back(binding.BindingType);

View File

@@ -95,6 +95,12 @@ void SplitterStep::computeFollowupSteps(
// Contract the edges of the constraint graph.
CG.optimize();
if (CS.getASTContext().TypeCheckerOpts.SolverDisableSplitter) {
steps.push_back(std::make_unique<ComponentStep>(
CS, 0, &CS.InactiveConstraints, Solutions));
return;
}
// Compute the connected components of the constraint graph.
auto components = CG.computeConnectedComponents(CS.getTypeVariables());
unsigned numComponents = components.size();

View File

@@ -97,7 +97,8 @@ void ConstraintGraphNode::reset() {
TypeVar = nullptr;
EquivalenceClass.clear();
Bindings.reset();
Potential.reset();
Set.reset();
}
bool ConstraintGraphNode::forRepresentativeVar() const {
@@ -229,8 +230,10 @@ void ConstraintGraphNode::notifyReferencingVars(
void ConstraintGraphNode::notifyReferencedVars(
llvm::function_ref<void(ConstraintGraphNode &)> notification) const {
for (auto *fixedBinding : getReferencedVars()) {
notification(CG[fixedBinding]);
for (auto *referencedVar : getReferencedVars()) {
auto *repr = referencedVar->getImpl().getRepresentative(/*record=*/nullptr);
if (!repr->getImpl().getFixedType(/*record=*/nullptr))
notification(CG[repr]);
}
}
@@ -284,30 +287,6 @@ void ConstraintGraphNode::removeReferencedBy(TypeVariableType *typeVar) {
}
}
void ConstraintGraphNode::introduceToInference(Constraint *constraint) {
if (forRepresentativeVar()) {
auto fixedType = TypeVar->getImpl().getFixedType(/*record=*/nullptr);
if (!fixedType)
getCurrentBindings().infer(CG.getConstraintSystem(), TypeVar, constraint);
} else {
auto *repr =
getTypeVariable()->getImpl().getRepresentative(/*record=*/nullptr);
CG[repr].introduceToInference(constraint);
}
}
void ConstraintGraphNode::retractFromInference(Constraint *constraint) {
if (forRepresentativeVar()) {
auto fixedType = TypeVar->getImpl().getFixedType(/*record=*/nullptr);
if (!fixedType)
getCurrentBindings().retract(CG.getConstraintSystem(), TypeVar,constraint);
} else {
auto *repr =
getTypeVariable()->getImpl().getRepresentative(/*record=*/nullptr);
CG[repr].retractFromInference(constraint);
}
}
void ConstraintGraphNode::updateFixedType(
Type fixedType,
llvm::function_ref<void (ConstraintGraphNode &,
@@ -327,7 +306,11 @@ void ConstraintGraphNode::updateFixedType(
fixedType->getTypeVariables(referencedVars);
for (auto *referencedVar : referencedVars) {
auto &node = CG[referencedVar];
auto *repr = referencedVar->getImpl().getRepresentative(/*record=*/nullptr);
if (repr->getImpl().getFixedType(/*record=*/nullptr))
continue;
auto &node = CG[repr];
// Newly referred vars need to re-introduce all constraints associated
// with this type variable since they are now going to be used in
@@ -340,18 +323,20 @@ void ConstraintGraphNode::updateFixedType(
}
void ConstraintGraphNode::retractFromInference(Type fixedType) {
auto &cs = CG.getConstraintSystem();
return updateFixedType(
fixedType,
[](ConstraintGraphNode &node, Constraint *constraint) {
node.retractFromInference(constraint);
[&cs](ConstraintGraphNode &node, Constraint *constraint) {
node.getPotentialBindings().retract(cs, node.getTypeVariable(), constraint);
});
}
void ConstraintGraphNode::introduceToInference(Type fixedType) {
auto &cs = CG.getConstraintSystem();
return updateFixedType(
fixedType,
[](ConstraintGraphNode &node, Constraint *constraint) {
node.introduceToInference(constraint);
[&cs](ConstraintGraphNode &node, Constraint *constraint) {
node.getPotentialBindings().infer(cs, node.getTypeVariable(), constraint);
});
}
@@ -376,13 +361,13 @@ void ConstraintGraph::addConstraint(Constraint *constraint) {
addConstraint(typeVar, constraint);
auto &node = (*this)[typeVar];
node.introduceToInference(constraint);
auto *repr = typeVar->getImpl().getRepresentative(/*record=*/nullptr);
if (!repr->getImpl().getFixedType(/*record=*/nullptr))
(*this)[repr].getPotentialBindings().infer(CS, repr, constraint);
if (isUsefulForReferencedVars(constraint)) {
node.notifyReferencedVars([&](ConstraintGraphNode &referencedVar) {
referencedVar.introduceToInference(constraint);
(*this)[typeVar].notifyReferencedVars([&](ConstraintGraphNode &node) {
node.getPotentialBindings().infer(CS, node.getTypeVariable(), constraint);
});
}
}
@@ -414,14 +399,13 @@ void ConstraintGraph::removeConstraint(Constraint *constraint) {
// For the nodes corresponding to each type variable...
auto referencedTypeVars = constraint->getTypeVariables();
for (auto typeVar : referencedTypeVars) {
// Find the node for this type variable.
auto &node = (*this)[typeVar];
node.retractFromInference(constraint);
auto *repr = typeVar->getImpl().getRepresentative(/*record=*/nullptr);
if (!repr->getImpl().getFixedType(/*record=*/nullptr))
(*this)[repr].getPotentialBindings().retract(CS, repr, constraint);
if (isUsefulForReferencedVars(constraint)) {
node.notifyReferencedVars([&](ConstraintGraphNode &referencedVar) {
referencedVar.retractFromInference(constraint);
(*this)[typeVar].notifyReferencedVars([&](ConstraintGraphNode &node) {
node.getPotentialBindings().retract(CS, node.getTypeVariable(), constraint);
});
}
@@ -467,7 +451,7 @@ void ConstraintGraph::mergeNodesPre(TypeVariableType *typeVar2) {
node.notifyReferencingVars(
[&](ConstraintGraphNode &node, Constraint *constraint) {
node.retractFromInference(constraint);
node.getPotentialBindings().retract(CS, node.getTypeVariable(), constraint);
});
}
}
@@ -497,19 +481,20 @@ void ConstraintGraph::mergeNodes(TypeVariableType *typeVar1,
auto &node = (*this)[newMember];
for (auto *constraint : node.getConstraints()) {
repNode.introduceToInference(constraint);
if (!typeVar1->getImpl().getFixedType(/*record=*/nullptr))
repNode.getPotentialBindings().infer(CS, typeVar1, constraint);
if (!isUsefulForReferencedVars(constraint))
continue;
repNode.notifyReferencedVars([&](ConstraintGraphNode &referencedVar) {
referencedVar.introduceToInference(constraint);
repNode.notifyReferencedVars([&](ConstraintGraphNode &node) {
node.getPotentialBindings().infer(CS, node.getTypeVariable(), constraint);
});
}
node.notifyReferencingVars(
[&](ConstraintGraphNode &node, Constraint *constraint) {
node.introduceToInference(constraint);
node.getPotentialBindings().infer(CS, node.getTypeVariable(), constraint);
});
}
}
@@ -557,12 +542,12 @@ void ConstraintGraph::unrelateTypeVariables(TypeVariableType *typeVar,
void ConstraintGraph::inferBindings(TypeVariableType *typeVar,
Constraint *constraint) {
(*this)[typeVar].getCurrentBindings().infer(CS, typeVar, constraint);
(*this)[typeVar].getPotentialBindings().infer(CS, typeVar, constraint);
}
void ConstraintGraph::retractBindings(TypeVariableType *typeVar,
Constraint *constraint) {
(*this)[typeVar].getCurrentBindings().retract(CS, typeVar, constraint);
(*this)[typeVar].getPotentialBindings().retract(CS, typeVar, constraint);
}
#pragma mark Algorithms

View File

@@ -118,17 +118,17 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) {
cs.getConstraintLocator({}));
{
auto bindings = cs.getBindingsFor(otherTy);
cs.getConstraintGraph()[otherTy].initBindingSet();
auto &bindings = cs.getConstraintGraph()[otherTy].getBindingSet();
// Make sure that there are no direct bindings or protocol requirements.
ASSERT_EQ(bindings.Bindings.size(), (unsigned)0);
ASSERT_EQ(bindings.Literals.size(), (unsigned)0);
llvm::SmallDenseMap<TypeVariableType *, BindingSet> env;
env.insert({floatLiteralTy, cs.getBindingsFor(floatLiteralTy)});
cs.getConstraintGraph()[floatLiteralTy].initBindingSet();
bindings.finalize(env);
bindings.finalize(/*transitive=*/true);
// Inferred a single transitive binding through `$T_float`.
ASSERT_EQ(bindings.Bindings.size(), (unsigned)1);

View File

@@ -126,24 +126,25 @@ ProtocolType *SemaTest::createProtocol(llvm::StringRef protocolName,
BindingSet SemaTest::inferBindings(ConstraintSystem &cs,
TypeVariableType *typeVar) {
llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
for (auto *typeVar : cs.getTypeVariables()) {
auto &node = cs.getConstraintGraph()[typeVar];
node.resetBindingSet();
if (!typeVar->getImpl().hasRepresentativeOrFixed())
cache.insert({typeVar, cs.getBindingsFor(typeVar, /*finalize=*/false)});
node.initBindingSet();
}
for (auto *typeVar : cs.getTypeVariables()) {
auto cachedBindings = cache.find(typeVar);
if (cachedBindings == cache.end())
auto &node = cs.getConstraintGraph()[typeVar];
if (!node.hasBindingSet())
continue;
auto &bindings = cachedBindings->getSecond();
bindings.inferTransitiveProtocolRequirements(cache);
bindings.finalize(cache);
auto &bindings = node.getBindingSet();
bindings.inferTransitiveProtocolRequirements();
bindings.finalize(/*transitive=*/true);
}
auto result = cache.find(typeVar);
assert(result != cache.end());
return result->second;
auto &node = cs.getConstraintGraph()[typeVar];
ASSERT(node.hasBindingSet());
return node.getBindingSet();
}

View File

@@ -1,4 +1,4 @@
// RUN: %target-typecheck-verify-swift -solver-disable-shrink
// RUN: %target-typecheck-verify-swift
// Self-contained test case
protocol P1 {}; func f<T: P1>(_: T, _: T) -> T { fatalError() }

View File

@@ -1,4 +1,4 @@
// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 -solver-disable-shrink
// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1
// REQUIRES: tools-release,no_asan
struct Date {

View File

@@ -1,4 +1,4 @@
// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 -solver-disable-shrink
// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1
// REQUIRES: tools-release,no_asan
// UNSUPPORTED: swift_test_mode_optimize_none && OS=linux-gnu