diff --git a/include/swift/AST/Expr.h b/include/swift/AST/Expr.h index 8c38e62649f..1bdbfffdbbb 100644 --- a/include/swift/AST/Expr.h +++ b/include/swift/AST/Expr.h @@ -205,6 +205,14 @@ class alignas(8) Expr { enum { NumTupleExprBits = NumExprBits + 3 }; static_assert(NumTupleExprBits <= 32, "fits in an unsigned"); + class UnresolvedDotExprBitfields { + friend class UnresolvedDotExpr; + unsigned : NumExprBits; + unsigned FunctionRefKind : 2; + }; + enum { NumUnresolvedDotExprExprBits = NumExprBits + 2 }; + static_assert(NumUnresolvedDotExprExprBits <= 32, "fits in an unsigned"); + class SubscriptExprBitfields { friend class SubscriptExpr; unsigned : NumExprBits; @@ -254,17 +262,17 @@ class alignas(8) Expr { class OverloadSetRefExprBitfields { friend class OverloadSetRefExpr; unsigned : NumExprBits; + unsigned FunctionRefKind : 2; }; - enum { NumOverloadSetRefExprBits = NumExprBits }; + enum { NumOverloadSetRefExprBits = NumExprBits + 2}; static_assert(NumOverloadSetRefExprBits <= 32, "fits in an unsigned"); class OverloadedDeclRefExprBitfields { friend class OverloadedDeclRefExpr; unsigned : NumOverloadSetRefExprBits; unsigned IsSpecialized : 1; - unsigned FunctionRefKind : 2; }; - enum { NumOverloadedDeclRefExprBits = NumOverloadSetRefExprBits + 3 }; + enum { NumOverloadedDeclRefExprBits = NumOverloadSetRefExprBits + 1 }; static_assert(NumOverloadedDeclRefExprBits <= 32, "fits in an unsigned"); class BooleanLiteralExprBitfields { @@ -422,6 +430,7 @@ protected: UnresolvedDeclRefExprBitfields UnresolvedDeclRefExprBits; TupleExprBitfields TupleExprBits; MemberRefExprBitfields MemberRefExprBits; + UnresolvedDotExprBitfields UnresolvedDotExprBits; SubscriptExprBitfields SubscriptExprBits; DynamicSubscriptExprBitfields DynamicSubscriptExprBits; UnresolvedMemberExprBitfields UnresolvedMemberExprBits; @@ -1400,9 +1409,12 @@ class OverloadSetRefExpr : public Expr { ArrayRef Decls; protected: - OverloadSetRefExpr(ExprKind Kind, ArrayRef decls, bool Implicit, - Type Ty) - : Expr(Kind, Implicit, Ty), Decls(decls) {} + OverloadSetRefExpr(ExprKind Kind, ArrayRef decls, + FunctionRefKind functionRefKind, bool Implicit, Type Ty) + : Expr(Kind, Implicit, Ty), Decls(decls) { + OverloadSetRefExprBits.FunctionRefKind = + static_cast(functionRefKind); + } public: ArrayRef getDecls() const { return Decls; } @@ -1416,6 +1428,17 @@ public: /// concrete base object (which is not a metatype). bool hasBaseObject() const; + /// Retrieve the kind of function reference. + FunctionRefKind getFunctionRefKind() const { + return static_cast( + OverloadSetRefExprBits.FunctionRefKind); + } + + /// Set the kind of function reference. + void setFunctionRefKind(FunctionRefKind refKind) { + OverloadSetRefExprBits.FunctionRefKind = static_cast(refKind); + } + static bool classof(const Expr *E) { return E->getKind() >= ExprKind::First_OverloadSetRefExpr && E->getKind() <= ExprKind::Last_OverloadSetRefExpr; @@ -1432,11 +1455,10 @@ public: bool isSpecialized, FunctionRefKind functionRefKind, bool Implicit, Type Ty = Type()) - : OverloadSetRefExpr(ExprKind::OverloadedDeclRef, Decls, Implicit, Ty), + : OverloadSetRefExpr(ExprKind::OverloadedDeclRef, Decls, functionRefKind, + Implicit, Ty), Loc(Loc) { OverloadedDeclRefExprBits.IsSpecialized = isSpecialized; - OverloadedDeclRefExprBits.FunctionRefKind = - static_cast(functionRefKind); } DeclNameLoc getNameLoc() const { return Loc; } @@ -1449,17 +1471,6 @@ public: return OverloadedDeclRefExprBits.IsSpecialized; } - /// Retrieve the kind of function reference. - FunctionRefKind getFunctionRefKind() const { - return static_cast( - OverloadedDeclRefExprBits.FunctionRefKind); - } - - /// Set the kind of function reference. - void setFunctionRefKind(FunctionRefKind refKind) { - OverloadedDeclRefExprBits.FunctionRefKind = static_cast(refKind); - } - static bool classof(const Expr *E) { return E->getKind() == ExprKind::OverloadedDeclRef; } @@ -2317,8 +2328,12 @@ class UnresolvedDotExpr : public Expr { public: UnresolvedDotExpr(Expr *subexpr, SourceLoc dotloc, DeclName name, DeclNameLoc nameloc, bool Implicit) - : Expr(ExprKind::UnresolvedDot, Implicit), SubExpr(subexpr), DotLoc(dotloc), - NameLoc(nameloc), Name(name) {} + : Expr(ExprKind::UnresolvedDot, Implicit), SubExpr(subexpr), DotLoc(dotloc), + NameLoc(nameloc), Name(name) { + UnresolvedDotExprBits.FunctionRefKind = + static_cast(NameLoc.isCompound() ? FunctionRefKind::Compound + : FunctionRefKind::Unapplied); + } SourceLoc getLoc() const { return NameLoc.getBaseNameLoc(); } @@ -2337,6 +2352,16 @@ public: DeclName getName() const { return Name; } DeclNameLoc getNameLoc() const { return NameLoc; } + /// Retrieve the kind of function reference. + FunctionRefKind getFunctionRefKind() const { + return static_cast(UnresolvedDotExprBits.FunctionRefKind); + } + + /// Set the kind of function reference. + void setFunctionRefKind(FunctionRefKind refKind) { + UnresolvedDotExprBits.FunctionRefKind = static_cast(refKind); + } + static bool classof(const Expr *E) { return E->getKind() == ExprKind::UnresolvedDot; } diff --git a/lib/AST/ASTDumper.cpp b/lib/AST/ASTDumper.cpp index 9099edcf26f..407a4c7e7a8 100644 --- a/lib/AST/ASTDumper.cpp +++ b/lib/AST/ASTDumper.cpp @@ -1860,7 +1860,8 @@ public: } void visitUnresolvedDotExpr(UnresolvedDotExpr *E) { printCommon(E, "unresolved_dot_expr") - << " field '" << E->getName() << "'"; + << " field '" << E->getName() << "'" + << " function_ref=" << getFunctionRefKindStr(E->getFunctionRefKind()); if (E->getBase()) { OS << '\n'; printRec(E->getBase()); diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 6ce577170fb..662847a3c82 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -6787,6 +6787,7 @@ Expr *TypeChecker::callWitness(Expr *base, DeclContext *dc, = cs.getTypeOfMemberReference(base->getType(), witness, /*isTypeReference=*/false, /*isDynamicResult=*/false, + FunctionRefKind::DoubleApply, dotLocator); // Form the call argument. diff --git a/lib/Sema/CSDiag.cpp b/lib/Sema/CSDiag.cpp index 75c1099af8e..4429d5934d4 100644 --- a/lib/Sema/CSDiag.cpp +++ b/lib/Sema/CSDiag.cpp @@ -2220,7 +2220,8 @@ bool FailureDiagnosis::diagnoseGeneralMemberFailure(Constraint *constraint) { MemberLookupResult result = CS->performMemberLookup(constraint->getKind(), constraint->getMember(), - baseTy, constraint->getLocator(), + baseTy, constraint->getFunctionRefKind(), + constraint->getLocator(), /*includeInaccessibleMembers*/true); switch (result.OverallResult) { @@ -4535,7 +4536,7 @@ bool FailureDiagnosis::visitSubscriptExpr(SubscriptExpr *SE) { MemberLookupResult result = CS->performMemberLookup(ConstraintKind::ValueMember, subscriptName, - baseType, locator, + baseType, FunctionRefKind::DoubleApply, locator, /*includeInaccessibleMembers*/true); @@ -5716,7 +5717,9 @@ bool FailureDiagnosis::visitUnresolvedMemberExpr(UnresolvedMemberExpr *E) { MemberLookupResult result = CS->performMemberLookup(memberConstraint->getKind(), memberConstraint->getMember(), - baseObjTy, memberConstraint->getLocator(), + baseObjTy, + memberConstraint->getFunctionRefKind(), + memberConstraint->getLocator(), /*includeInaccessibleMembers*/true); switch (result.OverallResult) { diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index d829b2cc29d..e79c027cc6b 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -34,7 +34,8 @@ static Expr *skipImplicitConversions(Expr *expr) { } /// \brief Find the declaration directly referenced by this expression. -static ValueDecl *findReferencedDecl(Expr *expr, DeclNameLoc &loc) { +static std::pair +findReferencedDecl(Expr *expr, DeclNameLoc &loc) { do { expr = expr->getSemanticsProvidingExpr(); @@ -45,10 +46,10 @@ static ValueDecl *findReferencedDecl(Expr *expr, DeclNameLoc &loc) { if (auto dre = dyn_cast(expr)) { loc = dre->getNameLoc(); - return dre->getDecl(); + return { dre->getDecl(), dre->getFunctionRefKind() }; } - return nullptr; + return { nullptr, FunctionRefKind::Unapplied }; } while (true); } @@ -1002,12 +1003,11 @@ namespace { // The base must have a member of the given name, such that accessing // that member through the base returns a value convertible to the type // of this expression. - // FIXME: use functionRefKind auto baseTy = base->getType(); auto tv = CS.createTypeVariable( CS.getConstraintLocator(expr, ConstraintLocator::Member), TVO_CanBindToLValue); - CS.addValueMemberConstraint(baseTy, name, tv, + CS.addValueMemberConstraint(baseTy, name, tv, functionRefKind, CS.getConstraintLocator(expr, ConstraintLocator::Member)); return tv; } @@ -1119,11 +1119,12 @@ namespace { // UnresolvedSubscriptExpr from SubscriptExpr. if (decl) { OverloadChoice choice(base->getType(), decl, /*isSpecialized=*/false, - FunctionRefKind::SingleApply); + FunctionRefKind::DoubleApply); CS.addBindOverloadConstraint(fnTy, choice, subscriptMemberLocator); } else { CS.addValueMemberConstraint(baseTy, Context.Id_subscript, - fnTy, subscriptMemberLocator); + fnTy, FunctionRefKind::DoubleApply, + subscriptMemberLocator); } // Add the constraint that the index expression's type be convertible @@ -1208,6 +1209,7 @@ namespace { segment->getType(), segmentTyV, Identifier(), + FunctionRefKind::Compound, locator)); DeclName segmentName(C, C.Id_init, { C.Id_stringInterpolationSegment }); @@ -1215,6 +1217,7 @@ namespace { tvMeta, methodTy, segmentName, + FunctionRefKind::DoubleApply, locator)); } @@ -1384,7 +1387,7 @@ namespace { choices.push_back(OverloadChoice(Type(), decls[i], expr->isSpecialized(), - /*FIXME:*/FunctionRefKind::DoubleApply)); + expr->getFunctionRefKind())); } // If there are no valid overloads, give up. @@ -1420,6 +1423,10 @@ namespace { auto baseLocator = CS.getConstraintLocator( expr, ConstraintLocator::MemberRefBase); + FunctionRefKind functionRefKind = + expr->getArgument() ? FunctionRefKind::DoubleApply + : FunctionRefKind::Compound; + auto memberLocator = CS.getConstraintLocator(expr, ConstraintLocator::UnresolvedMember); auto baseTy = CS.createTypeVariable(baseLocator, /*options=*/0); @@ -1434,7 +1441,8 @@ namespace { // member, i.e., an enum case or a static variable. auto baseMetaTy = MetatypeType::get(baseTy); CS.addUnresolvedValueMemberConstraint(baseMetaTy, expr->getName(), - memberTy, memberLocator); + memberTy, functionRefKind, + memberLocator); // If there is an argument, apply it. if (auto arg = expr->getArgument()) { @@ -1494,7 +1502,7 @@ namespace { /*options=*/0); auto methodTy = FunctionType::get(argsTy, resultTy); CS.addValueMemberConstraint(baseTy, expr->getName(), - methodTy, + methodTy, expr->getFunctionRefKind(), CS.getConstraintLocator(expr, ConstraintLocator::ConstructorMember)); // The result of the expression is the partial application of the @@ -1503,7 +1511,7 @@ namespace { } return addMemberRefConstraints(expr, expr->getBase(), expr->getName(), - /*FIXME:*/FunctionRefKind::DoubleApply); + expr->getFunctionRefKind()); } Type visitUnresolvedSpecializeExpr(UnresolvedSpecializeExpr *expr) { @@ -2495,7 +2503,9 @@ namespace { if (CS.shouldAttemptFixes()) { Constraint *coerceConstraint = Constraint::create(CS, ConstraintKind::ExplicitConversion, - fromType, toType, DeclName(), locator); + fromType, toType, DeclName(), + FunctionRefKind::Compound, + locator); Constraint *downcastConstraint = Constraint::createFixed(CS, ConstraintKind::CheckedCast, FixKind::CoerceToCheckedCast, fromType, @@ -2730,10 +2740,14 @@ namespace { // type-checked down to a call; turn it back into an overloaded // member reference expression. DeclNameLoc memberLoc; - if (auto member = findReferencedDecl(dotCall->getFn(), memberLoc)) { + auto memberAndFunctionRef = findReferencedDecl(dotCall->getFn(), + memberLoc); + if (memberAndFunctionRef.first) { auto base = skipImplicitConversions(dotCall->getArg()); return new (TC.Context) MemberRefExpr(base, - dotCall->getDotLoc(), member, memberLoc, + dotCall->getDotLoc(), + memberAndFunctionRef.first, + memberLoc, expr->isImplicit()); } } @@ -2744,10 +2758,13 @@ namespace { // actually matter; turn it back into an overloaded member reference // expression. DeclNameLoc memberLoc; - if (auto member = findReferencedDecl(dotIgnored->getRHS(), memberLoc)) { + auto memberAndFunctionRef = findReferencedDecl(dotIgnored->getRHS(), + memberLoc); + if (memberAndFunctionRef.first) { auto base = skipImplicitConversions(dotIgnored->getLHS()); return new (TC.Context) MemberRefExpr(base, - dotIgnored->getDotLoc(), member, + dotIgnored->getDotLoc(), + memberAndFunctionRef.first, memberLoc, expr->isImplicit()); } } @@ -3073,9 +3090,9 @@ Type swift::checkMemberType(DeclContext &DC, Type BaseTy, TypeVariableOptions::TVO_CanBindToLValue); CS.addConstraint(Constraint::createDisjunction(CS, { Constraint::create(CS, ConstraintKind::TypeMember, Ty, - TV, DeclName(Id), Loc), + TV, DeclName(Id), FunctionRefKind::Compound, Loc), Constraint::create(CS, ConstraintKind::ValueMember, Ty, - TV, DeclName(Id), Loc) + TV, DeclName(Id), FunctionRefKind::DoubleApply, Loc) }, Loc)); Ty = TV; } @@ -3172,7 +3189,8 @@ bool swift::isExtensionApplied(DeclContext &DC, Type BaseTy, return; } // Add constraints accordingly. - CS.addConstraint(Constraint::create(CS, Kind, First, Second, DeclName(), Loc)); + CS.addConstraint(Constraint::create(CS, Kind, First, Second, DeclName(), + FunctionRefKind::Compound, Loc)); }; // For every requirement, add a constraint. @@ -3221,6 +3239,7 @@ static bool canSatisfy(Type T1, Type T2, DeclContext &DC, ConstraintKind Kind, T2 = T2.transform(Trans); } CS.addConstraint(Constraint::create(CS, Kind, T1, T2, DeclName(), + FunctionRefKind::Compound, CS.getConstraintLocator(nullptr))); SmallVector Solutions; return AllowFreeVariables ? @@ -3252,7 +3271,8 @@ swift::resolveValueMember(DeclContext &DC, Type BaseTy, DeclName Name) { } ConstraintSystem CS(*TC, &DC, None); MemberLookupResult LookupResult = CS.performMemberLookup( - ConstraintKind::ValueMember, Name, BaseTy, nullptr, false); + ConstraintKind::ValueMember, Name, BaseTy, FunctionRefKind::DoubleApply, + nullptr, false); if (LookupResult.ViableCandidates.empty()) return Result; ConstraintLocator *Locator = CS.getConstraintLocator(nullptr); diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index b1fe19062f5..1474285e425 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -2122,10 +2122,9 @@ commit_to_conversions: } ConstraintSystem::SolutionKind -ConstraintSystem::simplifyConstructionConstraint(Type valueType, - FunctionType *fnType, - unsigned flags, - ConstraintLocator *locator) { +ConstraintSystem::simplifyConstructionConstraint( + Type valueType, FunctionType *fnType, unsigned flags, + FunctionRefKind functionRefKind, ConstraintLocator *locator) { // Desugar the value type. auto desugarValueType = valueType->getDesugaredType(); @@ -2231,7 +2230,7 @@ ConstraintSystem::simplifyConstructionConstraint(Type valueType, // variable T. T2 is the result type provided via the construction // constraint itself. addValueMemberConstraint(MetatypeType::get(valueType, TC.Context), name, - FunctionType::get(tv, resultType), + FunctionType::get(tv, resultType), functionRefKind, getConstraintLocator( fnLocator, ConstraintLocator::ConstructorMember)); @@ -2617,12 +2616,9 @@ getArgumentLabels(ConstraintSystem &cs, ConstraintLocatorBuilder locator) { /// referenced. MemberLookupResult ConstraintSystem:: performMemberLookup(ConstraintKind constraintKind, DeclName memberName, - Type baseTy, ConstraintLocator *memberLocator, + Type baseTy, FunctionRefKind functionRefKind, + ConstraintLocator *memberLocator, bool includeInaccessibleMembers) { - // FIXME: FunctionRefKind::DoubleApply is a hack that maintains all - // label info. - FunctionRefKind functionRefKind = FunctionRefKind::DoubleApply; - Type baseObjTy = baseTy->getRValueType(); // Dig out the instance type and figure out what members of the instance type @@ -3139,7 +3135,8 @@ ConstraintSystem::simplifyMemberConstraint(const Constraint &constraint) { MemberLookupResult result = performMemberLookup(constraint.getKind(), constraint.getMember(), - baseTy, constraint.getLocator(), + baseTy, constraint.getFunctionRefKind(), + constraint.getLocator(), /*includeInaccessibleMembers*/false); Type memberTy = constraint.getSecondType(); @@ -3192,6 +3189,7 @@ ConstraintSystem::simplifyMemberConstraint(const Constraint &constraint) { baseObjTy->getOptionalObjectType(), constraint.getSecondType(), constraint.getMember(), + constraint.getFunctionRefKind(), constraint.getLocator())); return SolutionKind::Solved; } @@ -3224,6 +3222,7 @@ ConstraintSystem::simplifyMemberConstraint(const Constraint &constraint) { addValueMemberConstraint(baseObjTy->getOptionalObjectType(), constraint.getMember(), constraint.getSecondType(), + constraint.getFunctionRefKind(), constraint.getLocator()); return SolutionKind::Solved; } @@ -3451,7 +3450,8 @@ retry: if (auto meta2 = dyn_cast(desugar2)) { // Construct the instance from the input arguments. return simplifyConstructionConstraint(meta2->getInstanceType(), func1, - flags, + flags, + FunctionRefKind::SingleApply, getConstraintLocator(outerLocator)); } @@ -3749,14 +3749,17 @@ ConstraintSystem::simplifyRestrictedConstraint(ConversionRestrictionKind restric auto int8Con = Constraint::create(*this, ConstraintKind::Bind, btv2, TC.getInt8Type(DC), DeclName(), + FunctionRefKind::Compound, getConstraintLocator(locator)); auto uint8Con = Constraint::create(*this, ConstraintKind::Bind, btv2, TC.getUInt8Type(DC), DeclName(), + FunctionRefKind::Compound, getConstraintLocator(locator)); auto voidCon = Constraint::create(*this, ConstraintKind::Bind, btv2, TC.Context.TheEmptyTupleType, DeclName(), + FunctionRefKind::Compound, getConstraintLocator(locator)); Constraint *disjunctionChoices[] = {int8Con, uint8Con, voidCon}; @@ -4113,6 +4116,7 @@ ConstraintSystem::simplifyRestrictedConstraint(ConversionRestrictionKind restric Constraint::create(*this, ConstraintKind::BridgedToObjectiveC, arg, Type(), DeclName(), + FunctionRefKind::Compound, getConstraintLocator( locator.withPathElement( LocatorPathElt::getGenericArgument( diff --git a/lib/Sema/Constraint.cpp b/lib/Sema/Constraint.cpp index 6e5935e9fed..59e6ebe1246 100644 --- a/lib/Sema/Constraint.cpp +++ b/lib/Sema/Constraint.cpp @@ -39,12 +39,16 @@ Constraint::Constraint(ConstraintKind kind, ArrayRef constraints, } Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, - DeclName Member, ConstraintLocator *locator, + DeclName Member, FunctionRefKind functionRefKind, + ConstraintLocator *locator, ArrayRef typeVars) : Kind(Kind), HasRestriction(false), HasFix(false), IsActive(false), RememberChoice(false), IsFavored(false), NumTypeVariables(typeVars.size()), Types { First, Second, Member }, Locator(locator) { + TheFunctionRefKind = static_cast(functionRefKind); + assert(getFunctionRefKind() == functionRefKind); + switch (Kind) { case ConstraintKind::Bind: case ConstraintKind::Equal: @@ -128,7 +132,7 @@ Constraint::Constraint(ConstraintKind kind, Constraint::Constraint(ConstraintKind kind, Fix fix, Type first, Type second, ConstraintLocator *locator, ArrayRef typeVars) - : Kind(kind), TheFix(fix.getKind()), FixData(fix.getData()), + : Kind(kind), FixData(fix.getData()), TheFix(fix.getKind()), HasRestriction(false), HasFix(true), IsActive(false), RememberChoice(false), IsFavored(false), NumTypeVariables(typeVars.size()), @@ -165,7 +169,7 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const { case ConstraintKind::ApplicableFunction: case ConstraintKind::OptionalObject: return create(cs, getKind(), getFirstType(), getSecondType(), - DeclName(), getLocator()); + DeclName(), FunctionRefKind::Compound, getLocator()); case ConstraintKind::BindOverload: return createBindOverload(cs, getFirstType(), getOverloadChoice(), @@ -175,17 +179,17 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const { case ConstraintKind::UnresolvedValueMember: case ConstraintKind::TypeMember: return create(cs, getKind(), getFirstType(), Type(), getMember(), - getLocator()); + getFunctionRefKind(), getLocator()); case ConstraintKind::Defaultable: return create(cs, getKind(), getFirstType(), getSecondType(), - getMember(), getLocator()); + getMember(), getFunctionRefKind(), getLocator()); case ConstraintKind::Archetype: case ConstraintKind::Class: case ConstraintKind::BridgedToObjectiveC: return create(cs, getKind(), getFirstType(), Type(), DeclName(), - getLocator()); + FunctionRefKind::Compound, getLocator()); case ConstraintKind::Disjunction: return createDisjunction(cs, getNestedConstraints(), getLocator()); @@ -514,6 +518,7 @@ static void uniqueTypeVariables(SmallVectorImpl &typeVars) { Constraint *Constraint::create(ConstraintSystem &cs, ConstraintKind kind, Type first, Type second, DeclName member, + FunctionRefKind functionRefKind, ConstraintLocator *locator) { // Collect type variables. SmallVector typeVars; @@ -526,7 +531,8 @@ Constraint *Constraint::create(ConstraintSystem &cs, ConstraintKind kind, // Create the constraint. unsigned size = totalSizeToAlloc(typeVars.size()); void *mem = cs.getAllocator().Allocate(size, alignof(Constraint)); - return new (mem) Constraint(kind, first, second, member, locator, typeVars); + return new (mem) Constraint(kind, first, second, member, functionRefKind, + locator, typeVars); } Constraint *Constraint::createBindOverload(ConstraintSystem &cs, Type type, diff --git a/lib/Sema/Constraint.h b/lib/Sema/Constraint.h index 948c5b60c5c..582e95dd6db 100644 --- a/lib/Sema/Constraint.h +++ b/lib/Sema/Constraint.h @@ -19,6 +19,7 @@ #define SWIFT_SEMA_CONSTRAINT_H #include "OverloadChoice.h" +#include "swift/AST/FunctionRefKind.h" #include "swift/AST/Identifier.h" #include "swift/AST/Type.h" #include "llvm/ADT/ArrayRef.h" @@ -295,12 +296,12 @@ class Constraint final : public llvm::ilist_node, /// The kind of restriction placed on this constraint. ConversionRestrictionKind Restriction : 8; - /// The kind of fix to be applied to the constraint before visiting it. - FixKind TheFix; - /// Data associated with the fix. uint16_t FixData; + /// The kind of fix to be applied to the constraint before visiting it. + FixKind TheFix; + /// Whether the \c Restriction field is valid. unsigned HasRestriction : 1; @@ -324,6 +325,9 @@ class Constraint final : public llvm::ilist_node, /// The type variables themselves are tail-allocated. unsigned NumTypeVariables : 11; + /// The kind of function reference, for member references. + unsigned TheFunctionRefKind : 2; + union { struct { /// \brief The first type. @@ -362,7 +366,9 @@ class Constraint final : public llvm::ilist_node, /// Construct a new constraint. Constraint(ConstraintKind kind, Type first, Type second, DeclName member, - ConstraintLocator *locator, ArrayRef typeVars); + FunctionRefKind functionRefKind, + ConstraintLocator *locator, + ArrayRef typeVars); /// Construct a new overload-binding constraint. Constraint(Type type, OverloadChoice choice, ConstraintLocator *locator, @@ -387,6 +393,7 @@ public: /// Create a new constraint. static Constraint *create(ConstraintSystem &cs, ConstraintKind Kind, Type First, Type Second, DeclName Member, + FunctionRefKind functionRefKind, ConstraintLocator *locator); /// Create an overload-binding constraint. @@ -524,6 +531,16 @@ public: || kind == ConstraintKind::TypeMember; } + /// Determine the kind of function reference we have for a member reference. + FunctionRefKind getFunctionRefKind() const { + if (Kind == ConstraintKind::ValueMember || + Kind == ConstraintKind::UnresolvedValueMember) + return static_cast(TheFunctionRefKind); + + // Conservative answer: drop all of the labels. + return FunctionRefKind::Compound; + } + /// Retrieve the set of constraints in a disjunction. ArrayRef getNestedConstraints() const { assert(Kind == ConstraintKind::Disjunction); diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index 6b687aba8b7..c852632228d 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -406,7 +406,8 @@ ConstraintSystem::getMemberType(TypeVariableType *baseTypeVar, auto memberTypeVar = createTypeVariable(loc, options); addConstraint(Constraint::create(*this, ConstraintKind::TypeMember, baseTypeVar, memberTypeVar, - assocType->getName(), loc)); + assocType->getName(), + FunctionRefKind::Compound, loc)); return memberTypeVar; }); } @@ -441,7 +442,9 @@ namespace { TVO_PrefersSubtypeBinding); CS.addConstraint(Constraint::create(CS, ConstraintKind::TypeMember, baseTypeVar, memberTypeVar, - member->getName(), locator)); + member->getName(), + FunctionRefKind::Compound, + locator)); return memberTypeVar; } @@ -472,7 +475,9 @@ namespace { // Bind the member's type variable as a type member of the base. CS.addConstraint(Constraint::create(CS, ConstraintKind::TypeMember, baseTypeVar, memberTypeVar, - member->getName(), locator)); + member->getName(), + FunctionRefKind::Compound, + locator)); if (!archetype) { // If the nested type is not an archetype (because it was constrained @@ -1176,12 +1181,10 @@ ConstraintSystem::getTypeOfMemberReference( Type baseTy, ValueDecl *value, bool isTypeReference, bool isDynamicResult, + FunctionRefKind functionRefKind, ConstraintLocatorBuilder locator, const DeclRefExpr *base, llvm::DenseMap *replacementsPtr) { - // FIXME: Should receive the function reference kind as a parameter. - FunctionRefKind functionRefKind = FunctionRefKind::DoubleApply; - // Figure out the instance type used for the base. TypeVariableType *baseTypeVar = nullptr; Type baseObjTy = getFixedTypeRecursive(baseTy, baseTypeVar, @@ -1253,11 +1256,12 @@ ConstraintSystem::getTypeOfMemberReference( auto isClassBoundExistential = false; llvm::DenseMap localReplacements; auto &replacements = replacementsPtr ? *replacementsPtr : localReplacements; + bool isCurriedInstanceReference = value->isInstanceMember() && !isInstance; + unsigned numRemovedArgumentLabels = + getNumRemovedArgumentLabels(TC.Context, value, isCurriedInstanceReference, + functionRefKind); + if (auto genericFn = value->getInterfaceType()->getAs()){ - bool isCurriedInstanceReference = value->isInstanceMember() && !isInstance; - unsigned numRemovedArgumentLabels = - getNumRemovedArgumentLabels(TC.Context, value, isCurriedInstanceReference, - functionRefKind); openedType = openFunctionType(genericFn, numRemovedArgumentLabels, locator, replacements, innerDC, outerDC, /*skipProtocolSelfConstraint=*/true); @@ -1302,6 +1306,9 @@ ConstraintSystem::getTypeOfMemberReference( selfTy = outerDC->getDeclaredTypeOfContext(); } + // Remove argument labels, if needed. + openedType = removeArgumentLabels(openedType, numRemovedArgumentLabels); + // If we have a type reference, look through the metatype. if (isTypeReference) openedType = openedType->castTo()->getInstanceType(); @@ -1509,6 +1516,7 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator, std::tie(openedFullType, refType) = getTypeOfMemberReference(choice.getBaseType(), choice.getDecl(), isTypeReference, isDynamicResult, + choice.getFunctionRefKind(), locator, base, nullptr); } else { std::tie(openedFullType, refType) diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h index fcca32a6d1f..f7a42080211 100644 --- a/lib/Sema/ConstraintSystem.h +++ b/lib/Sema/ConstraintSystem.h @@ -1356,7 +1356,7 @@ public: assert(first && "Missing first type"); assert(second && "Missing second type"); auto c = Constraint::create(*this, kind, first, second, DeclName(), - locator); + FunctionRefKind::Compound, locator); if (isFavored) c->setFavored(); addConstraint(c); } @@ -1370,25 +1370,29 @@ public: /// \brief Add a value member constraint to the constraint system. void addValueMemberConstraint(Type baseTy, DeclName name, Type memberTy, + FunctionRefKind functionRefKind, ConstraintLocator *locator) { assert(baseTy); assert(memberTy); assert(name); addConstraint(Constraint::create(*this, ConstraintKind::ValueMember, - baseTy, memberTy, name, locator)); + baseTy, memberTy, name, functionRefKind, + locator)); } /// \brief Add a value member constraint for an UnresolvedMemberRef /// to the constraint system. void addUnresolvedValueMemberConstraint(Type baseTy, DeclName name, Type memberTy, + FunctionRefKind functionRefKind, ConstraintLocator *locator) { assert(baseTy); assert(memberTy); assert(name); addConstraint(Constraint::create(*this, ConstraintKind::UnresolvedValueMember, - baseTy, memberTy, name, locator)); + baseTy, memberTy, name, functionRefKind, + locator)); } /// \brief Add an archetype constraint. @@ -1396,7 +1400,7 @@ public: assert(baseTy); addConstraint(Constraint::create(*this, ConstraintKind::Archetype, baseTy, Type(), DeclName(), - locator)); + FunctionRefKind::Compound, locator)); } /// \brief Remove an inactive constraint from the current constraint graph. @@ -1634,6 +1638,7 @@ public: Type baseTy, ValueDecl *decl, bool isTypeReference, bool isDynamicResult, + FunctionRefKind functionRefKind, ConstraintLocatorBuilder locator, const DeclRefExpr *base = nullptr, llvm::DenseMap @@ -1830,6 +1835,7 @@ public: /// referenced. MemberLookupResult performMemberLookup(ConstraintKind constraintKind, DeclName memberName, Type baseTy, + FunctionRefKind functionRefKind, ConstraintLocator *memberLocator, bool includeInaccessibleMembers); @@ -1865,6 +1871,7 @@ private: SolutionKind simplifyConstructionConstraint(Type valueType, FunctionType *fnType, unsigned flags, + FunctionRefKind functionRefKind, ConstraintLocator *locator); /// \brief Attempt to simplify the given conformance constraint. diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index ff4f38ea15f..2c7d59a5e42 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -2044,7 +2044,9 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) { cs.addConstraint(Constraint::create( cs, ConstraintKind::TypeMember, SequenceType, iteratorType, - tc.Context.Id_Iterator, iteratorLocator)); + tc.Context.Id_Iterator, + FunctionRefKind::Compound, + iteratorLocator)); // Determine the element type of the iterator. // FIXME: Should look up the type witness. @@ -2052,7 +2054,9 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) { cs.addConstraint(Constraint::create( cs, ConstraintKind::TypeMember, iteratorType, elementType, - tc.Context.Id_Element, elementLocator)); + tc.Context.Id_Element, + FunctionRefKind::Compound, + elementLocator)); } diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index ebecac76136..67ffd3ae9a2 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -944,6 +944,7 @@ matchWitness(TypeChecker &tc, = cs->getTypeOfMemberReference(selfTy, witness, /*isTypeReference=*/false, /*isDynamicResult=*/false, + FunctionRefKind::DoubleApply, witnessLocator, /*base=*/nullptr, /*opener=*/nullptr); @@ -970,6 +971,7 @@ matchWitness(TypeChecker &tc, = cs->getTypeOfMemberReference(selfTy, req, /*isTypeReference=*/false, /*isDynamicResult=*/false, + FunctionRefKind::DoubleApply, reqLocator, /*base=*/nullptr, &replacements); diff --git a/test/Sema/suppress-argument-labels-in-types.swift b/test/Sema/suppress-argument-labels-in-types.swift index 36ddad93004..f5cabf9a70b 100644 --- a/test/Sema/suppress-argument-labels-in-types.swift +++ b/test/Sema/suppress-argument-labels-in-types.swift @@ -1,4 +1,4 @@ -// RUN: %target-swift-frontend -parse -verify -suppress-argument-labels-in-types %s +// RUN: %target-swift-frontend -module-name TestModule -parse -verify -suppress-argument-labels-in-types %s // Test non-overloaded global function references. func f1(a: Int, b: Int) -> Int { } @@ -10,9 +10,7 @@ func testF1(a: Int, b: Int) { _ = f1(a:b:)(1, 2) // compound name suppresses argument labels - let i: Int = f1 // expected-error{{cannot convert value of type '(Int, Int) -> Int' to specified type 'Int'}} - - _ = i + let _: Int = f1 // expected-error{{cannot convert value of type '(Int, Int) -> Int' to specified type 'Int'}} } // Test multiple levels of currying. @@ -65,3 +63,116 @@ func testF4(a: Int, b: Int, c: Double, d: Double) { let _: (x: Int, y: Int) -> Int = f4 let _: (x: Double, y: Double) -> Double = f4 } + +// Test module-qualified function references. +func testModuleQualifiedRef(a: Int, b: Int, c: Double, d: Double) { + _ = TestModule.f1(a: a, b: a) // okay: direct call requires argument labels + _ = (TestModule.f1)(a: a, b: a) // okay: direct call requires argument labels + _ = ((TestModule.f1))(a: a, b: a) // okay: direct call requires argument labels + + _ = TestModule.f1(a:b:)(1, 2) // compound name suppresses argument labels + + let _: Int = TestModule.f1 // expected-error{{cannot convert value of type '(Int, Int) -> Int' to specified type 'Int'}} + + _ = TestModule.f4(a: a, b: a) // okay: direct call requires argument labels + _ = (TestModule.f4)(a: a, b: a) // okay: direct call requires argument labels + _ = ((TestModule.f4))(a: a, b: a) // okay: direct call requires argument labels + _ = TestModule.f4(c: c, d: d) // okay: direct call requires argument labels + _ = (TestModule.f4)(c: c, d: d) // okay: direct call requires argument labels + _ = ((TestModule.f4))(c: c, d: d) // okay: direct call requires argument labels + + _ = TestModule.f4(a:b:)(1, 2) // compound name suppresses argument labels + _ = TestModule.f4(c:d:)(1.5, 2.5) // compound name suppresses argument labels + + let _: (Int, Int) -> Int = TestModule.f4 + let _: (Double, Double) -> Double = TestModule.f4 + + // Note: these will become ill-formed when the rest of SE-0111 is + // implemented. For now, they check that the labels were removed by the type + // system. + let _: (x: Int, y: Int) -> Int = TestModule.f4 + let _: (x: Double, y: Double) -> Double = TestModule.f4 +} + +// Test member references. +struct S0 { + init(a: Int, b: Int) { } + + func f1(a: Int, b: Int) -> Int { } + func f2(a: Int, b: Int) -> (Int) -> (Int) -> Int { } + + func f4(a: Int, b: Int) -> Int { } + func f4(c: Double, d: Double) -> Double { } + + subscript (a a: Int, b b: Int) -> Int { + get { } + set { } + } +} + +func testS0Methods(s0: S0, a: Int, b: Int, c: Double, d: Double) { + _ = s0.f1(a: a, b: a) // okay: direct call requires argument labels + _ = (s0.f1)(a: a, b: a) // okay: direct call requires argument labels + _ = ((s0.f1))(a: a, b: a) // okay: direct call requires argument labels + + _ = s0.f1(a:b:)(a, b) // compound name suppresses argument labels + + let _: Int = s0.f1 // expected-error{{cannot convert value of type '(Int, Int) -> Int' to specified type 'Int'}} + + _ = s0.f2(a: a, b: b)(a) // okay: direct call requires argument labels + _ = s0.f2(a: a, b: b)(a)(b) // okay: direct call requires argument labels + + _ = s0.f4(a: a, b: a) // okay: direct call requires argument labels + _ = (s0.f4)(a: a, b: a) // okay: direct call requires argument labels + _ = ((s0.f4))(a: a, b: a) // okay: direct call requires argument labels + _ = s0.f4(c: c, d: d) // okay: direct call requires argument labels + _ = (s0.f4)(c: c, d: d) // okay: direct call requires argument labels + _ = ((s0.f4))(c: c, d: d) // okay: direct call requires argument labels + + _ = s0.f4(a:b:)(1, 2) // compound name suppresses argument labels + _ = s0.f4(c:d:)(1.5, 2.5) // compound name suppresses argument labels + + let _: (Int, Int) -> Int = s0.f4 + let _: (Double, Double) -> Double = s0.f4 + + // Note: these will become ill-formed when the rest of SE-0111 is + // implemented. For now, they check that the labels were removed by the type + // system. + let _: (x: Int, y: Int) -> Int = s0.f4 + let _: (x: Double, y: Double) -> Double = s0.f4 +} + +// Curried instance methods. +func testS0CurriedInstanceMethods(s0: S0, a: Int, b: Int) { + _ = S0.f1(s0)(a: a, b: a) // okay: direct call requires argument labels + _ = (S0.f1)(s0)(a: a, b: a) // okay: direct call requires argument labels + _ = ((S0.f1))(s0)(a: a, b: a) // okay: direct call requires argument labels + + _ = S0.f1(a:b:)(s0)(a, b) // compound name suppresses argument labels + + let _: Int = S0.f1 // expected-error{{cannot convert value of type '(S0) -> (Int, Int) -> Int' to specified type 'Int'}} + let f1OneLevel = S0.f1(s0) + let _: Int = f1OneLevel // expected-error{{cannot convert value of type '(Int, Int) -> Int' to specified type 'Int'}} +} + +// Initializers. +func testS0Initializers(s0: S0, a: Int, b: Int) { + let _ = S0(a: a, b: b) // okay: direct call requires argument labels + let _ = S0.init(a: a, b: b) // okay: direct call requires argument labels + + let _ = S0.init(a:b:)(a, b) // compound name suppresses argument labels + + // Curried references to the initializer drop argument labels. + let s0c1 = S0.init + let _: Int = s0c1 // expected-error{{cannot convert value of type '(Int, Int) -> S0' to specified type 'Int'}} + + let s0c2 = S0.init(a:b:) + let _: Int = s0c2 // expected-error{{cannot convert value of type '(Int, Int) -> S0' to specified type 'Int'}} +} + +func testS0Subscripts(s0: S0, a: Int, b: Int) { + let _ = s0[a: a, b: b] + + var s0Var = s0 + s0Var[a: a, b: b] = a +}