diff --git a/lib/Sema/CSDisjunction.cpp b/lib/Sema/CSDisjunction.cpp index b8b8099fff2..64e1e919dea 100644 --- a/lib/Sema/CSDisjunction.cpp +++ b/lib/Sema/CSDisjunction.cpp @@ -60,6 +60,68 @@ STATISTIC(NumDisjunctionsPruned, "disjunction pruning rounds"); using namespace swift; using namespace constraints; +std::optional> +ConstraintSystem::findConstraintThroughOptionals( + TypeVariableType *typeVar, OptionalWrappingDirection optionalDirection, + llvm::function_ref predicate) { + unsigned numOptionals = 0; + auto *rep = getRepresentative(typeVar); + + SmallPtrSet visitedVars; + while (visitedVars.insert(rep).second) { + // Look for a disjunction that binds this type variable to an overload set. + TypeVariableType *optionalObjectTypeVar = nullptr; + auto constraints = getConstraintGraph().gatherNearbyConstraints( + rep, + [&](Constraint *match) { + // If we have an "optional object of" constraint, we may need to + // look through it to find the constraint we're looking for. + if (match->getKind() != ConstraintKind::OptionalObject) + return predicate(match, rep); + + switch (optionalDirection) { + case OptionalWrappingDirection::Promote: { + // We want to go from T to T?, so check if we're on the RHS, and + // move over to the LHS if we can. + auto rhsTypeVar = match->getSecondType()->getAs(); + if (rhsTypeVar && getRepresentative(rhsTypeVar) == rep) { + optionalObjectTypeVar = + match->getFirstType()->getAs(); + } + break; + } + case OptionalWrappingDirection::Unwrap: { + // We want to go from T? to T, so check if we're on the LHS, and + // move over to the RHS if we can. + auto lhsTypeVar = match->getFirstType()->getAs(); + if (lhsTypeVar && getRepresentative(lhsTypeVar) == rep) { + optionalObjectTypeVar = + match->getSecondType()->getAs(); + } + break; + } + } + // Don't include the optional constraint in the results. + return false; + }); + + // If we found a result, return it. + if (!constraints.empty()) + return std::make_pair(constraints[0], numOptionals); + + // If we found an "optional object of" constraint, follow it. + if (optionalObjectTypeVar && !getFixedType(optionalObjectTypeVar)) { + numOptionals += 1; + rep = getRepresentative(optionalObjectTypeVar); + continue; + } + + // Otherwise we're done. + return std::nullopt; + } + return std::nullopt; +} + ConstraintSystem::SolutionKind ConstraintSystem::filterDisjunction( Constraint *disjunction, bool restoreOnFail, @@ -562,6 +624,32 @@ bool ConstraintSystem::simplifyAppliedOverloads( applicableFn->getLocator()); } +Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction( + TypeVariableType *tyvar, unsigned *numOptionalUnwraps) { + assert(!getFixedType(tyvar)); + auto result = findConstraintThroughOptionals( + tyvar, OptionalWrappingDirection::Promote, + [&](Constraint *match, TypeVariableType *currentRep) { + // Check to see if we have a bind overload disjunction that binds the + // type var we need. + if (match->getKind() != ConstraintKind::Disjunction || + match->getNestedConstraints().front()->getKind() != + ConstraintKind::BindOverload) + return false; + + auto lhsTy = match->getNestedConstraints().front()->getFirstType(); + auto *lhsTyVar = lhsTy->getAs(); + return lhsTyVar && currentRep == getRepresentative(lhsTyVar); + }); + if (!result) + return nullptr; + + if (numOptionalUnwraps) + *numOptionalUnwraps = result->second; + + return result->first; +} + bool ConstraintSystem::simplifyAppliedOverloads( Type fnType, FunctionType *argFnType, ConstraintLocatorBuilder locator) { // If we've already bound the function type, bail. diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 34689f29999..27854cdbe48 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -1159,94 +1159,6 @@ bool ConstraintSystem::solveForCodeCompletion( return true; } -std::optional> -ConstraintSystem::findConstraintThroughOptionals( - TypeVariableType *typeVar, OptionalWrappingDirection optionalDirection, - llvm::function_ref predicate) { - unsigned numOptionals = 0; - auto *rep = getRepresentative(typeVar); - - SmallPtrSet visitedVars; - while (visitedVars.insert(rep).second) { - // Look for a disjunction that binds this type variable to an overload set. - TypeVariableType *optionalObjectTypeVar = nullptr; - auto constraints = getConstraintGraph().gatherNearbyConstraints( - rep, - [&](Constraint *match) { - // If we have an "optional object of" constraint, we may need to - // look through it to find the constraint we're looking for. - if (match->getKind() != ConstraintKind::OptionalObject) - return predicate(match, rep); - - switch (optionalDirection) { - case OptionalWrappingDirection::Promote: { - // We want to go from T to T?, so check if we're on the RHS, and - // move over to the LHS if we can. - auto rhsTypeVar = match->getSecondType()->getAs(); - if (rhsTypeVar && getRepresentative(rhsTypeVar) == rep) { - optionalObjectTypeVar = - match->getFirstType()->getAs(); - } - break; - } - case OptionalWrappingDirection::Unwrap: { - // We want to go from T? to T, so check if we're on the LHS, and - // move over to the RHS if we can. - auto lhsTypeVar = match->getFirstType()->getAs(); - if (lhsTypeVar && getRepresentative(lhsTypeVar) == rep) { - optionalObjectTypeVar = - match->getSecondType()->getAs(); - } - break; - } - } - // Don't include the optional constraint in the results. - return false; - }); - - // If we found a result, return it. - if (!constraints.empty()) - return std::make_pair(constraints[0], numOptionals); - - // If we found an "optional object of" constraint, follow it. - if (optionalObjectTypeVar && !getFixedType(optionalObjectTypeVar)) { - numOptionals += 1; - rep = getRepresentative(optionalObjectTypeVar); - continue; - } - - // Otherwise we're done. - return std::nullopt; - } - return std::nullopt; -} - -Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction( - TypeVariableType *tyvar, unsigned *numOptionalUnwraps) { - assert(!getFixedType(tyvar)); - auto result = findConstraintThroughOptionals( - tyvar, OptionalWrappingDirection::Promote, - [&](Constraint *match, TypeVariableType *currentRep) { - // Check to see if we have a bind overload disjunction that binds the - // type var we need. - if (match->getKind() != ConstraintKind::Disjunction || - match->getNestedConstraints().front()->getKind() != - ConstraintKind::BindOverload) - return false; - - auto lhsTy = match->getNestedConstraints().front()->getFirstType(); - auto *lhsTyVar = lhsTy->getAs(); - return lhsTyVar && currentRep == getRepresentative(lhsTyVar); - }); - if (!result) - return nullptr; - - if (numOptionalUnwraps) - *numOptionalUnwraps = result->second; - - return result->first; -} - // Performance hack: if there are two generic overloads, and one is // more specialized than the other, prefer the more-specialized one. static Constraint *