diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index 3155d99f033..b44de9df574 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -1111,6 +1111,43 @@ public: } }; +/// Describes the arguments to which a parameter binds. +/// FIXME: This is an awful data structure. We want the equivalent of a +/// TinyPtrVector for unsigned values. +using ParamBinding = SmallVector; + +/// The result of calling matchCallArguments(). +struct MatchCallArgumentResult { + /// The direction of trailing closure matching that was performed. + TrailingClosureMatching trailingClosureMatching; + + /// The parameter bindings determined by the match. + SmallVector parameterBindings; + + /// When present, the forward and backward scans each produced a result, + /// and the parameter bindings are different. The primary result will be + /// forwarding, and this represents the backward binding. + Optional> backwardParameterBindings; + + friend bool operator==(const MatchCallArgumentResult &lhs, + const MatchCallArgumentResult &rhs) { + if (lhs.trailingClosureMatching != rhs.trailingClosureMatching) + return false; + if (lhs.parameterBindings != rhs.parameterBindings) + return false; + return lhs.backwardParameterBindings == rhs.backwardParameterBindings; + } + + /// Generate a result that maps the provided number of arguments to the same + /// number of parameters via forward match. + static MatchCallArgumentResult forArity(unsigned argCount) { + SmallVector Bindings; + for (unsigned i : range(argCount)) + Bindings.push_back({i}); + return {TrailingClosureMatching::Forward, Bindings, None}; + } +}; + /// A complete solution to a constraint system. /// /// A solution to a constraint system consists of type variable bindings to @@ -1159,9 +1196,9 @@ public: llvm::SmallVector Fixes; /// For locators associated with call expressions, the trailing closure - /// matching rule that was applied. - llvm::SmallMapVector - trailingClosureMatchingChoices; + /// matching rule and parameter bindings that were applied. + llvm::SmallMapVector + argumentMatchingChoices; /// The set of disjunction choices used to arrive at this solution, /// which informs constraint application. @@ -1203,6 +1240,10 @@ public: /// A map from argument expressions to their applied property wrapper expressions. llvm::MapVector> appliedPropertyWrappers; + /// Record a new argument matching choice for given locator that maps a + /// single argument to a single parameter. + void recordSingleArgMatchingChoice(ConstraintLocator *locator); + /// Simplify the given type by substituting all occurrences of /// type variables for their fixed types. Type simplifyType(Type type) const; @@ -2210,9 +2251,9 @@ private: AppliedDisjunctions; /// For locators associated with call expressions, the trailing closure - /// matching rule that was applied. - std::vector> - trailingClosureMatchingChoices; + /// matching rule and parameter bindings that were applied. + std::vector> + argumentMatchingChoices; /// The set of implicit value conversions performed by the solver on /// a current path to reach a solution. @@ -2709,8 +2750,8 @@ public: /// The length of \c AppliedDisjunctions. unsigned numAppliedDisjunctions; - /// The length of \c trailingClosureMatchingChoices; - unsigned numTrailingClosureMatchingChoices; + /// The length of \c argumentMatchingChoices. + unsigned numArgumentMatchingChoices; /// The length of \c OpenedTypes. unsigned numOpenedTypes; @@ -3226,11 +3267,8 @@ public: void recordPotentialHole(Type type); - void recordTrailingClosureMatch( - ConstraintLocator *locator, - TrailingClosureMatching trailingClosureMatch) { - trailingClosureMatchingChoices.push_back({locator, trailingClosureMatch}); - } + void recordMatchCallArgumentResult(ConstraintLocator *locator, + MatchCallArgumentResult result); /// Walk a closure AST to determine its effects. /// @@ -5086,11 +5124,6 @@ static inline bool computeTupleShuffle(TupleType *fromTuple, sources); } -/// Describes the arguments to which a parameter binds. -/// FIXME: This is an awful data structure. We want the equivalent of a -/// TinyPtrVector for unsigned values. -using ParamBinding = SmallVector; - /// Class used as the base for listeners to the \c matchCallArguments process. /// /// By default, none of the callbacks do anything. @@ -5158,20 +5191,6 @@ public: virtual bool relabelArguments(ArrayRef newNames); }; -/// The result of calling matchCallArguments(). -struct MatchCallArgumentResult { - /// The direction of trailing closure matching that was performed. - TrailingClosureMatching trailingClosureMatching; - - /// The parameter bindings determined by the match. - SmallVector parameterBindings; - - /// When present, the forward and backward scans each produced a result, - /// and the parameter bindings are different. The primary result will be - /// forwarding, and this represents the backward binding. - Optional> backwardParameterBindings; -}; - /// Match the call arguments (as described by the given argument type) to /// the parameters (as described by the given parameter type). /// diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 7ef126cb63f..6c1bec87d3c 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -3099,6 +3099,7 @@ namespace { DeclNameLoc nameLoc, bool implicit, ConstraintLocator *ctorLocator, SelectedOverload overload) { + auto locator = cs.getConstraintLocator(expr); auto choice = overload.choice; assert(choice.getKind() != OverloadChoiceKind::DeclViaDynamic); auto *ctor = cast(choice.getDecl()); @@ -3106,9 +3107,8 @@ namespace { // If the subexpression is a metatype, build a direct reference to the // constructor. if (cs.getType(base)->is()) { - return buildMemberRef( - base, dotLoc, overload, nameLoc, cs.getConstraintLocator(expr), - ctorLocator, implicit, AccessSemantics::Ordinary); + return buildMemberRef(base, dotLoc, overload, nameLoc, locator, + ctorLocator, implicit, AccessSemantics::Ordinary); } // The subexpression must be either 'self' or 'super'. @@ -3158,8 +3158,7 @@ namespace { auto *call = new (cs.getASTContext()) DotSyntaxCallExpr(ctorRef, dotLoc, base); - return finishApply(call, cs.getType(expr), cs.getConstraintLocator(expr), - ctorLocator); + return finishApply(call, cs.getType(expr), locator, ctorLocator); } /// Give the deprecation warning for referring to a global function @@ -4841,19 +4840,19 @@ namespace { } auto kind = origComponent.getKind(); - auto locator = cs.getConstraintLocator( - E, LocatorPathElt::KeyPathComponent(i)); + auto componentLocator = + cs.getConstraintLocator(E, LocatorPathElt::KeyPathComponent(i)); - // Adjust the locator such that it includes any additional elements to - // point to the component's callee, e.g a SubscriptMember for a - // subscript component. - locator = cs.getCalleeLocator(locator); + // Get a locator such that it includes any additional elements to point + // to the component's callee, e.g a SubscriptMember for a subscript + // component. + auto calleeLoc = cs.getCalleeLocator(componentLocator); bool isDynamicMember = false; // If this is an unresolved link, make sure we resolved it. if (kind == KeyPathExpr::Component::Kind::UnresolvedProperty || kind == KeyPathExpr::Component::Kind::UnresolvedSubscript) { - auto foundDecl = solution.getOverloadChoiceIfAvailable(locator); + auto foundDecl = solution.getOverloadChoiceIfAvailable(calleeLoc); if (!foundDecl) { // If we couldn't resolve the component, leave it alone. resolvedComponents.push_back(origComponent); @@ -4876,9 +4875,9 @@ namespace { switch (kind) { case KeyPathExpr::Component::Kind::UnresolvedProperty: { - buildKeyPathPropertyComponent(solution.getOverloadChoice(locator), - origComponent.getLoc(), - locator, resolvedComponents); + buildKeyPathPropertyComponent(solution.getOverloadChoice(calleeLoc), + origComponent.getLoc(), calleeLoc, + resolvedComponents); break; } case KeyPathExpr::Component::Kind::UnresolvedSubscript: { @@ -4887,9 +4886,9 @@ namespace { subscriptLabels = origComponent.getSubscriptLabels(); buildKeyPathSubscriptComponent( - solution.getOverloadChoice(locator), - origComponent.getLoc(), origComponent.getIndexExpr(), - subscriptLabels, locator, resolvedComponents); + solution.getOverloadChoice(calleeLoc), origComponent.getLoc(), + origComponent.getIndexExpr(), subscriptLabels, componentLocator, + resolvedComponents); break; } case KeyPathExpr::Component::Kind::OptionalChain: { @@ -5138,9 +5137,10 @@ namespace { SmallVectorImpl &components) { auto subscript = cast(overload.choice.getDecl()); assert(!subscript->isGetterMutating()); + auto memberLoc = cs.getCalleeLocator(locator); // Compute substitutions to refer to the member. - auto ref = resolveConcreteDeclRef(subscript, locator); + auto ref = resolveConcreteDeclRef(subscript, memberLoc); // If this is a @dynamicMemberLookup reference to resolve a property // through the subscript(dynamicMember:) member, restore the @@ -5159,12 +5159,15 @@ namespace { if (overload.choice.getKind() == OverloadChoiceKind::KeyPathDynamicMemberLookup) { indexExpr = buildKeyPathDynamicMemberIndexExpr( - indexType->castTo(), componentLoc, locator); + indexType->castTo(), componentLoc, memberLoc); } else { auto fieldName = overload.choice.getName().getBaseIdentifier().str(); indexExpr = buildDynamicMemberLookupIndexExpr(fieldName, componentLoc, indexType); } + // Record the implicit subscript expr's parameter bindings and matching + // direction as `coerceCallArguments` requires them. + solution.recordSingleArgMatchingChoice(locator); } auto subscriptType = @@ -5172,10 +5175,11 @@ namespace { auto resolvedTy = subscriptType->getResult(); // Coerce the indices to the type the subscript expects. - auto *newIndexExpr = - coerceCallArguments(indexExpr, subscriptType, ref, - /*applyExpr*/ nullptr, labels, - locator, /*appliedPropertyWrappers*/ {}); + auto *newIndexExpr = coerceCallArguments( + indexExpr, subscriptType, ref, + /*applyExpr*/ nullptr, labels, + cs.getConstraintLocator(locator, ConstraintLocator::ApplyArgument), + /*appliedPropertyWrappers*/ {}); // We need to be able to hash the captured index values in order for // KeyPath itself to be hashable, so check that all of the subscript @@ -5761,6 +5765,8 @@ Expr *ExprRewriter::coerceCallArguments( ArrayRef argLabels, ConstraintLocatorBuilder locator, ArrayRef appliedPropertyWrappers) { + assert(locator.last() && locator.last()->is()); + auto &ctx = getConstraintSystem().getASTContext(); auto params = funcType->getParams(); unsigned appliedWrapperIndex = 0; @@ -5773,11 +5779,6 @@ Expr *ExprRewriter::coerceCallArguments( LocatorPathElt::ApplyArgToParam(argIdx, paramIdx, flags)); }; - bool matchCanFail = - llvm::any_of(params, [](const AnyFunctionType::Param ¶m) { - return param.getPlainType()->hasUnresolvedType(); - }); - // Determine whether this application has curried self. bool skipCurriedSelf = apply ? hasCurriedSelf(cs, callee, apply) : true; // Determine the parameter bindings. @@ -5812,34 +5813,14 @@ Expr *ExprRewriter::coerceCallArguments( // Apply labels to arguments. AnyFunctionType::relabelParams(args, argLabels); - MatchCallArgumentListener listener; auto unlabeledTrailingClosureIndex = arg->getUnlabeledTrailingClosureIndexOfPackedArgument(); - // Determine the trailing closure matching rule that was applied. This - // is only relevant for explicit calls and subscripts. - auto trailingClosureMatching = TrailingClosureMatching::Forward; - { - SmallVector path; - auto anchor = locator.getLocatorParts(path); - if (!path.empty() && path.back().is() && - !anchor.isExpr(ExprKind::UnresolvedDot)) { - auto locatorPtr = cs.getConstraintLocator(locator); - assert(solution.trailingClosureMatchingChoices.count(locatorPtr) == 1); - trailingClosureMatching = solution.trailingClosureMatchingChoices.find( - locatorPtr)->second; - } - } - - auto callArgumentMatch = constraints::matchCallArguments( - args, params, paramInfo, unlabeledTrailingClosureIndex, - /*allowFixes=*/false, listener, trailingClosureMatching); - - assert((matchCanFail || callArgumentMatch) && - "Call arguments did not match up?"); - (void)matchCanFail; - - auto parameterBindings = std::move(callArgumentMatch->parameterBindings); + // Determine the parameter bindings that were applied. + auto *locatorPtr = cs.getConstraintLocator(locator); + assert(solution.argumentMatchingChoices.count(locatorPtr) == 1); + auto parameterBindings = solution.argumentMatchingChoices.find(locatorPtr) + ->second.parameterBindings; // We should either have parentheses or a tuple. auto *argTuple = dyn_cast(arg); @@ -6849,12 +6830,11 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType, ConstraintLocator::ConstructorMember})); solution.overloadChoices.insert({memberLoc, overload}); - solution.trailingClosureMatchingChoices.insert( - {cs.getConstraintLocator(callLocator, - ConstraintLocator::ApplyArgument), - TrailingClosureMatching::Forward}); } + // Record the implicit call's parameter bindings and match direction. + solution.recordSingleArgMatchingChoice(callLocator); + finishApply(implicitInit, toType, callLocator, callLocator); return implicitInit; } diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 94dc23a1a5c..cecbcaaf713 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -1477,6 +1477,14 @@ namespace { } Type visitUnresolvedDotExpr(UnresolvedDotExpr *expr) { + // UnresolvedDot applies the base to remove a single curry level from a + // member reference without using an applicable function constraint so + // we record the call argument matching here so it can be found later when + // a solution is applied to the AST. + CS.recordMatchCallArgumentResult( + CS.getConstraintLocator(expr, ConstraintLocator::ApplyArgument), + MatchCallArgumentResult::forArity(1)); + // If this is Builtin.type_join*, just return any type and move // on since we're going to discard this, and creating any type // variables for the reference will cause problems. diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 0a8ca7ac702..899a3bd8783 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -1309,9 +1309,9 @@ ConstraintSystem::TypeMatchResult constraints::matchCallArguments( } selectedTrailingMatching = callArgumentMatch->trailingClosureMatching; - // Record the direction of matching used for this call. - cs.recordTrailingClosureMatch(cs.getConstraintLocator(locator), - selectedTrailingMatching); + // Record the matching direction and parameter bindings used for this call. + cs.recordMatchCallArgumentResult(cs.getConstraintLocator(locator), + *callArgumentMatch); // If there was a disjunction because both forward and backward were // possible, increase the score for forward matches to bias toward the @@ -9716,10 +9716,10 @@ ConstraintSystem::simplifyApplicableFnConstraint( // have an explicit inout argument. if (type1.getPointer() == desugar2) { if (!isOperator || !hasInOut()) { - recordTrailingClosureMatch( + recordMatchCallArgumentResult( getConstraintLocator( outerLocator.withPathElement(ConstraintLocator::ApplyArgument)), - TrailingClosureMatching::Forward); + MatchCallArgumentResult::forArity(func1->getNumParams())); return SolutionKind::Solved; } } @@ -10864,6 +10864,12 @@ void ConstraintSystem::recordPotentialHole(Type type) { }); } +void ConstraintSystem::recordMatchCallArgumentResult( + ConstraintLocator *locator, MatchCallArgumentResult result) { + assert(locator->isLastElement()); + argumentMatchingChoices.push_back({locator, result}); +} + ConstraintSystem::SolutionKind ConstraintSystem::simplifyFixConstraint( ConstraintFix *fix, Type type1, Type type2, ConstraintKind matchKind, TypeMatchOptions flags, ConstraintLocatorBuilder locator) { diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 81beeb07318..7e5ecbb6871 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -137,12 +137,10 @@ Solution ConstraintSystem::finalize() { solution.DisjunctionChoices.insert(choice); } - // Remember all of the trailing closure matching choices we made. - for (auto &trailingClosureMatch : trailingClosureMatchingChoices) { - auto inserted = solution.trailingClosureMatchingChoices.insert( - trailingClosureMatch); - assert((inserted.second || - inserted.first->second == trailingClosureMatch.second)); + // Remember all of the argument/parameter matching choices we made. + for (auto &argumentMatch : argumentMatchingChoices) { + auto inserted = solution.argumentMatchingChoices.insert(argumentMatch); + assert(inserted.second || inserted.first->second == argumentMatch.second); (void)inserted; } @@ -234,9 +232,9 @@ void ConstraintSystem::applySolution(const Solution &solution) { DisjunctionChoices.push_back(choice); } - // Remember all of the trailing closure matching choices we made. - for (auto &trailingClosureMatch : solution.trailingClosureMatchingChoices) { - trailingClosureMatchingChoices.push_back(trailingClosureMatch); + // Remember all of the argument/parameter matching choices we made. + for (auto &argumentMatch : solution.argumentMatchingChoices) { + argumentMatchingChoices.push_back(argumentMatch); } // Register the solution's opened types. @@ -483,7 +481,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs) numFixedRequirements = cs.FixedRequirements.size(); numDisjunctionChoices = cs.DisjunctionChoices.size(); numAppliedDisjunctions = cs.AppliedDisjunctions.size(); - numTrailingClosureMatchingChoices = cs.trailingClosureMatchingChoices.size(); + numArgumentMatchingChoices = cs.argumentMatchingChoices.size(); numOpenedTypes = cs.OpenedTypes.size(); numOpenedExistentialTypes = cs.OpenedExistentialTypes.size(); numDefaultedConstraints = cs.DefaultedConstraints.size(); @@ -542,9 +540,8 @@ ConstraintSystem::SolverScope::~SolverScope() { // Remove any applied disjunctions. truncate(cs.AppliedDisjunctions, numAppliedDisjunctions); - // Remove any trailing closure matching choices; - truncate( - cs.trailingClosureMatchingChoices, numTrailingClosureMatchingChoices); + // Remove any argument matching choices; + truncate(cs.argumentMatchingChoices, numArgumentMatchingChoices); // Remove any opened types. truncate(cs.OpenedTypes, numOpenedTypes); diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index 4b30132ffc4..5e07c39d2b1 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -3065,6 +3065,16 @@ Type ConstraintSystem::simplifyType(Type type) const { }); } +void Solution::recordSingleArgMatchingChoice(ConstraintLocator *locator) { + auto &cs = getConstraintSystem(); + assert(argumentMatchingChoices.find(locator) == + argumentMatchingChoices.end() && + "recording multiple bindings for same locator"); + argumentMatchingChoices.insert( + {cs.getConstraintLocator(locator, ConstraintLocator::ApplyArgument), + MatchCallArgumentResult::forArity(1)}); +} + Type Solution::simplifyType(Type type) const { if (!(type->hasTypeVariable() || type->hasPlaceholder())) return type; diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index ab7b90cdc56..215e36c61f2 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -1031,14 +1031,13 @@ void Solution::dump(raw_ostream &out) const { out << "\n"; out << "Trailing closure matching:\n"; - for (auto &trailingClosureMatching : trailingClosureMatchingChoices) { + for (auto &argumentMatching : argumentMatchingChoices) { out.indent(2); - trailingClosureMatching.first->dump(sm, out); - switch (trailingClosureMatching.second) { + argumentMatching.first->dump(sm, out); + switch (argumentMatching.second.trailingClosureMatching) { case TrailingClosureMatching::Forward: out << ": forward\n"; break; - case TrailingClosureMatching::Backward: out << ": backward\n"; break; diff --git a/unittests/Sema/ConstraintSimplificationTests.cpp b/unittests/Sema/ConstraintSimplificationTests.cpp index 9b93450884e..7c52996f407 100644 --- a/unittests/Sema/ConstraintSimplificationTests.cpp +++ b/unittests/Sema/ConstraintSimplificationTests.cpp @@ -23,7 +23,8 @@ TEST_F(SemaTest, TestTrailingClosureMatchRecordingForIdenticalFunctions) { auto intType = getStdlibType("Int"); auto floatType = getStdlibType("Float"); - auto func = FunctionType::get({FunctionType::Param(intType)}, floatType); + auto func = FunctionType::get( + {FunctionType::Param(intType), FunctionType::Param(intType)}, floatType); cs.addConstraint( ConstraintKind::ApplicableFunction, func, func, @@ -37,7 +38,9 @@ TEST_F(SemaTest, TestTrailingClosureMatchRecordingForIdenticalFunctions) { const auto &solution = solutions.front(); auto *locator = cs.getConstraintLocator({}, ConstraintLocator::ApplyArgument); - auto choice = solution.trailingClosureMatchingChoices.find(locator); - ASSERT_TRUE(choice != solution.trailingClosureMatchingChoices.end()); - ASSERT_EQ(choice->second, TrailingClosureMatching::Forward); + auto choice = solution.argumentMatchingChoices.find(locator); + ASSERT_TRUE(choice != solution.argumentMatchingChoices.end()); + MatchCallArgumentResult expected{ + TrailingClosureMatching::Forward, {{0}, {1}}, None}; + ASSERT_EQ(choice->second, expected); }