//===--- CSClosure.cpp - Closures -----------------------------------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2018 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements constraint generation and solution application for // closures. It provides part of the implementation of the ConstraintSystem // class. // //===----------------------------------------------------------------------===// #include "MiscDiagnostics.h" #include "TypeChecker.h" #include "swift/Sema/ConstraintSystem.h" using namespace swift; using namespace swift::constraints; namespace { // Produce an implicit empty tuple expression. Expr *getVoidExpr(ASTContext &ctx) { auto *voidExpr = TupleExpr::createEmpty(ctx, /*LParenLoc=*/SourceLoc(), /*RParenLoc=*/SourceLoc(), /*Implicit=*/true); voidExpr->setType(ctx.TheEmptyTupleType); return voidExpr; } /// Find any type variable references inside of an AST node. class TypeVariableRefFinder : public ASTWalker { ConstraintSystem &CS; ASTNode Parent; llvm::SmallPtrSetImpl &ReferencedVars; public: TypeVariableRefFinder( ConstraintSystem &cs, ASTNode parent, llvm::SmallPtrSetImpl &referencedVars) : CS(cs), Parent(parent), ReferencedVars(referencedVars) {} std::pair walkToExprPre(Expr *expr) override { if (auto *DRE = dyn_cast(expr)) { if (auto type = CS.getTypeIfAvailable(DRE->getDecl())) inferVariables(type); } return {true, expr}; } std::pair walkToStmtPre(Stmt *stmt) override { // Return statements have to reference outside result type // since all of them are joined by it if it's not specified // explicitly. if (isa(stmt)) { if (auto *closure = getAsExpr(Parent)) { inferVariables(CS.getClosureType(closure)->getResult()); } } return {true, stmt}; } private: void inferVariables(Type type) { type = type->getWithoutSpecifierType(); // Record the type variable itself because it has to // be in scope even when already bound. if (auto *typeVar = type->getAs()) { ReferencedVars.insert(typeVar); // It is possible that contextual type of a parameter/result // has been assigned to e.g. an anonymous or named argument // early, to facilitate closure type checking. Such a // type can have type variables inside e.g. // // func test(_: (UnsafePointer) -> Void) {} // // test { ptr in // ... // } // // Type variable representing `ptr` in the body of // this closure would be bound to `UnsafePointer<$T>` // in this case, where `$T` is a type variable for a // generic parameter `T`. type = CS.getFixedTypeRecursive(typeVar, /*wantRValue=*/false); if (type->isEqual(typeVar)) return; } if (type->hasTypeVariable()) { SmallPtrSet typeVars; type->getTypeVariables(typeVars); ReferencedVars.insert(typeVars.begin(), typeVars.end()); } } }; /// Find any references to not yet resolved outer closure parameters /// used in the body of the inner closure. This is required because /// isolated conjunctions, just like single-expression closures, have /// to be connected to type variables they are going to use, otherwise /// they'll get placed in a separate solver component and would never /// produce a solution. class UnresolvedClosureParameterCollector : public ASTWalker { ConstraintSystem &CS; llvm::SmallSetVector Vars; public: UnresolvedClosureParameterCollector(ConstraintSystem &cs) : CS(cs) {} std::pair walkToExprPre(Expr *expr) override { if (auto *DRE = dyn_cast(expr)) { auto *decl = DRE->getDecl(); if (isa(decl)) { if (auto type = CS.getTypeIfAvailable(decl)) { if (auto *typeVar = type->getAs()) { Vars.insert(typeVar); } else if (type->hasTypeVariable()) { // Parameter or result type could be only partially // resolved e.g. `{ (x: X) -> Void in ... }` where // `X` is a generic type. SmallPtrSet typeVars; type->getTypeVariables(typeVars); Vars.insert(typeVars.begin(), typeVars.end()); } } } } return {true, expr}; } ArrayRef getVariables() const { return Vars.getArrayRef(); } }; // MARK: Constraint generation /// Check whether it makes sense to convert this element into a constrant. static bool isViableElement(ASTNode element) { if (auto *decl = element.dyn_cast()) { // - Ignore variable declarations, they are handled by pattern bindings; // - Ignore #if, the chosen children should appear in the // surrounding context; // - Skip #warning and #error, they are handled during solution // application. if (isa(decl) || isa(decl) || isa(decl)) return false; } if (auto *stmt = element.dyn_cast()) { // Empty brace statements are now viable because they do not require // inference. if (auto *braceStmt = dyn_cast(stmt)) { return braceStmt->getNumElements() > 0; } } return true; } using ElementInfo = std::tuple; static void createConjunction(ConstraintSystem &cs, ArrayRef elements, ConstraintLocator *locator) { bool isIsolated = false; SmallVector constraints; SmallVector referencedVars; if (locator->directlyAt()) { auto *closure = castToExpr(locator->getAnchor()); // Conjunction associated with the body of the closure has to // reference a type variable representing closure type, // otherwise it would get disconnected from its contextual type. referencedVars.push_back(cs.getType(closure)->castTo()); // Body of the closure is always isolated from its context, only // its individual elements are allowed access to type information // from the ouside e.g. parameters/result type. isIsolated = true; } UnresolvedClosureParameterCollector paramCollector(cs); for (const auto &entry : elements) { ASTNode element = std::get<0>(entry); ContextualTypeInfo context = std::get<1>(entry); bool isDiscarded = std::get<2>(entry); ConstraintLocator *elementLoc = std::get<3>(entry); if (!isViableElement(element)) continue; // If this conjunction going to represent a body of a closure, // let's collect references to not yet resolved outer // closure parameters. if (isIsolated) element.walk(paramCollector); constraints.push_back(Constraint::createClosureBodyElement( cs, element, context, elementLoc, isDiscarded)); } // It's possible that there are no viable elements in the body, // because e.g. whole body is an `#if` statement or it only has // declarations that are checked during solution application. // In such cases, let's avoid creating a conjunction. if (constraints.empty()) return; for (auto *externalVar : paramCollector.getVariables()) referencedVars.push_back(externalVar); cs.addUnsolvedConstraint(Constraint::createConjunction( cs, constraints, isIsolated, locator, referencedVars)); } ElementInfo makeElement(ASTNode node, ConstraintLocator *locator, ContextualTypeInfo context = ContextualTypeInfo(), bool isDiscarded = false) { return std::make_tuple(node, context, isDiscarded, locator); } static ProtocolDecl *getSequenceProtocol(ASTContext &ctx, SourceLoc loc, bool inAsyncContext) { return TypeChecker::getProtocol(ctx, loc, inAsyncContext ? KnownProtocolKind::AsyncSequence : KnownProtocolKind::Sequence); } /// Statement visitor that generates constraints for a given closure body. class ClosureConstraintGenerator : public StmtVisitor { friend StmtVisitor; ConstraintSystem &cs; ClosureExpr *closure; ConstraintLocator *locator; public: /// Whether an error was encountered while generating constraints. bool hadError = false; ClosureConstraintGenerator(ConstraintSystem &cs, ClosureExpr *closure, ConstraintLocator *locator) : cs(cs), closure(closure), locator(locator) {} void visitPattern(Pattern *pattern, ContextualTypeInfo context) { auto parentElement = locator->getLastElementAs(); if (!parentElement) { hadError = true; return; } if (auto *stmt = parentElement->getElement().dyn_cast()) { if (isa(stmt)) { visitForEachPattern(pattern, cast(stmt)); return; } if (isa(stmt)) { visitCaseItemPattern(pattern, context); return; } } llvm_unreachable("Unsupported pattern"); } void visitCaseItem(CaseLabelItem *caseItem, ContextualTypeInfo context) { assert(context.purpose == CTP_CaseStmt); // Resolve the pattern. auto *pattern = caseItem->getPattern(); if (!caseItem->isPatternResolved()) { pattern = TypeChecker::resolvePattern(pattern, closure, /*isStmtCondition=*/false); if (!pattern) { hadError = true; return; } } // Let's generate constraints for pattern + where clause. // The assumption is that this shouldn't be too complex // to handle, but if it turns out to be false, this could // always be converted into a conjunction. // Generate constraints for pattern. visitPattern(pattern, context); auto *guardExpr = caseItem->getGuardExpr(); // Generate constraints for `where` clause (if any). if (guardExpr) { guardExpr = cs.generateConstraints(guardExpr, closure); if (!guardExpr) { hadError = true; return; } } // Save information about case item so it could be referenced during // solution application. cs.setCaseLabelItemInfo(caseItem, {pattern, guardExpr}); } private: /// This method handles both pattern and the sequence expression /// associated with `for-in` loop because types in this situation /// flow in both directions: /// /// - From pattern to sequence, informing its element type e.g. /// `for i: Int8 in 0 ..< 8` /// /// - From sequence to pattern, when pattern has no type information. void visitForEachPattern(Pattern *pattern, ForEachStmt *forEachStmt) { auto &ctx = cs.getASTContext(); bool isAsync = forEachStmt->getAwaitLoc().isValid(); // Verify pattern. { auto contextualPattern = ContextualPattern::forRawPattern(pattern, closure); Type patternType = TypeChecker::typeCheckPattern(contextualPattern); if (patternType->hasError()) { hadError = true; return; } } auto *sequenceProto = getSequenceProtocol(ctx, forEachStmt->getForLoc(), isAsync); if (!sequenceProto) { hadError = true; return; } auto *contextualLocator = cs.getConstraintLocator( locator, LocatorPathElt::ContextualType(CTP_ForEachStmt)); // Generate constraints to initialize the pattern. auto initType = cs.generateConstraints(pattern, contextualLocator, /*shouldBindPatternOneWay=*/true, /*patternBinding=*/nullptr, /*patternIndex=*/0); if (!initType) { hadError = true; return; } // Let's generate constraints for sequence associated with `for-in` // statement. We can't do that separately because pattern can inform // a type of the sequence e.g. `for in i: Int8 in 0 ..< 8 { ... }` auto *sequenceExpr = forEachStmt->getSequence(); auto *sequenceLocator = cs.getConstraintLocator(sequenceExpr); { SolutionApplicationTarget target( sequenceExpr, closure, CTP_ForEachSequence, sequenceProto->getDeclaredInterfaceType(), /*isDiscarded=*/false); if (cs.generateConstraints(target, FreeTypeVariableBinding::Disallow)) { hadError = true; return; } cs.setSolutionApplicationTarget(sequenceExpr, target); } Type sequenceType = cs.createTypeVariable(sequenceLocator, TVO_CanBindToNoEscape); // This "workaround" warrants an explanation for posterity. // // The reason why we can't simplify use \c getType(sequenceExpr) here // is due to how dependent member types are handled by \c simplifyTypeImpl // - if the base type didn't change (and it wouldn't because it's a fully // resolved concrete type) after simplification attempt the // whole dependent member type would be just re-created without attempting // to resolve it, so we have to use an intermediary here so that // \c elementType and \c iteratorType can be resolved correctly. cs.addConstraint(ConstraintKind::Conversion, cs.getType(sequenceExpr), sequenceType, sequenceLocator); auto elementAssocType = sequenceProto->getAssociatedType(ctx.Id_Element); Type elementType = DependentMemberType::get(sequenceType, elementAssocType); auto iteratorAssocType = sequenceProto->getAssociatedType( isAsync ? ctx.Id_AsyncIterator : ctx.Id_Iterator); Type iteratorType = DependentMemberType::get(sequenceType, iteratorAssocType); cs.addConstraint( ConstraintKind::Conversion, elementType, initType, cs.getConstraintLocator(contextualLocator, ConstraintLocator::SequenceElementType)); // Reference the makeIterator witness. FuncDecl *makeIterator = isAsync ? ctx.getAsyncSequenceMakeAsyncIterator() : ctx.getSequenceMakeIterator(); Type makeIteratorType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape); cs.addValueWitnessConstraint(LValueType::get(sequenceType), makeIterator, makeIteratorType, closure, FunctionRefKind::Compound, contextualLocator); // After successful constraint generation, let's record // solution application target with all relevant information. { auto target = SolutionApplicationTarget::forForEachStmt( forEachStmt, sequenceProto, closure, /*bindTypeVarsOneWay=*/true, /*contextualPurpose=*/CTP_ForEachSequence); auto &targetInfo = target.getForEachStmtInfo(); targetInfo.sequenceType = sequenceType; targetInfo.elementType = elementType; targetInfo.iteratorType = iteratorType; targetInfo.initType = initType; target.setPattern(pattern); cs.setSolutionApplicationTarget(forEachStmt, target); } } void visitCaseItemPattern(Pattern *pattern, ContextualTypeInfo context) { Type patternType = cs.generateConstraints(pattern, locator, /*bindPatternVarsOneWay=*/true, /*patternBinding=*/nullptr, /*patternIndex=*/0); if (!patternType) { hadError = true; return; } // Convert the contextual type to the pattern, which establishes the // bindings. cs.addConstraint(ConstraintKind::Conversion, context.getType(), patternType, locator); // For any pattern variable that has a parent variable (i.e., another // pattern variable with the same name in the same case), require that // the types be equivalent. pattern->forEachNode([&](Pattern *pattern) { auto namedPattern = dyn_cast(pattern); if (!namedPattern) return; auto var = namedPattern->getDecl(); if (auto parentVar = var->getParentVarDecl()) { cs.addConstraint( ConstraintKind::Equal, cs.getType(parentVar), cs.getType(var), cs.getConstraintLocator( locator, LocatorPathElt::PatternMatch(namedPattern))); } }); } void visitDecl(Decl *decl) { if (isSupportedMultiStatementClosure()) { if (auto patternBinding = dyn_cast(decl)) { SolutionApplicationTarget target(patternBinding); if (cs.generateConstraints(target, FreeTypeVariableBinding::Disallow)) hadError = true; return; } } // Just ignore #if; the chosen children should appear in the // surrounding context. This isn't good for source tools but it // at least works. if (isa(decl)) return; // Skip #warning/#error; we'll handle them when applying the closure. if (isa(decl)) return; // Ignore variable declarations, because they're always handled within // their enclosing pattern bindings. if (isa(decl)) return; // Other declarations will be handled at application time. } void visitBreakStmt(BreakStmt *breakStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: Break"); } void visitContinueStmt(ContinueStmt *continueStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: Continue"); } void visitDeferStmt(DeferStmt *deferStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: Defer"); } void visitFallthroughStmt(FallthroughStmt *fallthroughStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: Fallthrough"); } void visitIfStmt(IfStmt *ifStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: If"); SmallVector elements; // Condition { auto *condLoc = cs.getConstraintLocator(locator, ConstraintLocator::Condition); elements.push_back(makeElement(ifStmt->getCondPointer(), condLoc)); } // Then Branch { auto *thenLoc = cs.getConstraintLocator( locator, LocatorPathElt::TernaryBranch(/*then=*/true)); elements.push_back(makeElement(ifStmt->getThenStmt(), thenLoc)); } // Else Branch (if any). if (auto *elseStmt = ifStmt->getElseStmt()) { auto *elseLoc = cs.getConstraintLocator( locator, LocatorPathElt::TernaryBranch(/*then=*/false)); elements.push_back(makeElement(ifStmt->getElseStmt(), elseLoc)); } createConjunction(cs, elements, locator); } void visitGuardStmt(GuardStmt *guardStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: Guard"); createConjunction(cs, {makeElement(guardStmt->getCondPointer(), cs.getConstraintLocator( locator, ConstraintLocator::Condition)), makeElement(guardStmt->getBody(), locator)}, locator); } void visitWhileStmt(WhileStmt *whileStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: While"); createConjunction(cs, {makeElement(whileStmt->getCondPointer(), cs.getConstraintLocator( locator, ConstraintLocator::Condition)), makeElement(whileStmt->getBody(), locator)}, locator); } void visitDoStmt(DoStmt *doStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: Do"); visitBraceStmt(doStmt->getBody()); } void visitRepeatWhileStmt(RepeatWhileStmt *repeatWhileStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: RepeatWhile"); createConjunction(cs, {makeElement(repeatWhileStmt->getCond(), cs.getConstraintLocator( locator, ConstraintLocator::Condition), getContextForCondition()), makeElement(repeatWhileStmt->getBody(), locator)}, locator); } void visitPoundAssertStmt(PoundAssertStmt *poundAssertStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: PoundAssert"); createConjunction(cs, {makeElement(poundAssertStmt->getCondition(), cs.getConstraintLocator( locator, ConstraintLocator::Condition), getContextForCondition())}, locator); } void visitThrowStmt(ThrowStmt *throwStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: Throw"); Type errType = cs.getASTContext().getErrorDecl()->getDeclaredInterfaceType(); if (!errType) { hadError = true; return; } auto *errorExpr = throwStmt->getSubExpr(); createConjunction( cs, {makeElement( errorExpr, cs.getConstraintLocator( locator, LocatorPathElt::ClosureBodyElement(errorExpr)), {errType, CTP_ThrowStmt})}, locator); } void visitForEachStmt(ForEachStmt *forEachStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: ForEach"); auto *stmtLoc = cs.getConstraintLocator(locator); SmallVector elements; // For-each pattern. // // Note that we don't record a sequence here, it would // be handled together with pattern because pattern can // inform a type of sequence element e.g. `for i: Int8 in 0 ..< 8` { Pattern *pattern = TypeChecker::resolvePattern(forEachStmt->getPattern(), closure, /*isStmtCondition=*/false); if (!pattern) { hadError = true; return; } elements.push_back(makeElement(pattern, stmtLoc)); } // `where` clause if any. if (auto *whereClause = forEachStmt->getWhere()) { elements.push_back( makeElement(whereClause, stmtLoc, getContextForCondition())); } // Body of the `for-in` loop. elements.push_back(makeElement(forEachStmt->getBody(), stmtLoc)); createConjunction(cs, elements, locator); } void visitSwitchStmt(SwitchStmt *switchStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: Switch"); auto *switchLoc = cs.getConstraintLocator( locator, LocatorPathElt::ClosureBodyElement(switchStmt)); SmallVector elements; { auto *subjectExpr = switchStmt->getSubjectExpr(); { elements.push_back(makeElement(subjectExpr, switchLoc)); SolutionApplicationTarget target(subjectExpr, closure, CTP_Unused, Type(), /*isDiscarded=*/false); cs.setSolutionApplicationTarget(switchStmt, target); } for (auto rawCase : switchStmt->getRawCases()) elements.push_back(makeElement(rawCase, switchLoc)); } createConjunction(cs, elements, switchLoc); } void visitDoCatchStmt(DoCatchStmt *doStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: DoCatch"); auto *doLoc = cs.getConstraintLocator( locator, LocatorPathElt::ClosureBodyElement(doStmt)); SmallVector elements; // First, let's record a body of `do` statement. elements.push_back(makeElement(doStmt->getBody(), doLoc)); // After that has been type-checked, let's switch to // individual `catch` statements. for (auto *catchStmt : doStmt->getCatches()) elements.push_back(makeElement(catchStmt, doLoc)); createConjunction(cs, elements, doLoc); } void visitCaseStmt(CaseStmt *caseStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: Case"); Type contextualTy; { auto parent = locator->castLastElementTo() .getElement(); if (parent.isStmt(StmtKind::Switch)) { auto *switchStmt = cast(parent.get()); contextualTy = cs.getType(switchStmt->getSubjectExpr()); } else if (parent.isStmt(StmtKind::DoCatch)) { contextualTy = cs.getASTContext().getExceptionType(); } else { hadError = true; return; } } bindSwitchCasePatternVars(closure, caseStmt); auto *caseLoc = cs.getConstraintLocator( locator, LocatorPathElt::ClosureBodyElement(caseStmt)); SmallVector elements; for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) { elements.push_back( makeElement(&caseLabelItem, caseLoc, {contextualTy, CTP_CaseStmt})); } elements.push_back(makeElement(caseStmt->getBody(), caseLoc)); createConjunction(cs, elements, caseLoc); } void visitBraceStmt(BraceStmt *braceStmt) { if (isSupportedMultiStatementClosure()) { auto &ctx = cs.getASTContext(); if (isChildOf(StmtKind::Case)) { auto *caseStmt = cast( locator->castLastElementTo() .asStmt()); for (auto caseBodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray()) { auto parentVar = caseBodyVar->getParentVarDecl(); assert(parentVar && "Case body variables always have parents"); cs.setType(caseBodyVar, cs.getType(parentVar)); } } SmallVector elements; for (auto element : braceStmt->getElements()) { bool isDiscarded = element.is() && (!ctx.LangOpts.Playground && !ctx.LangOpts.DebuggerSupport); elements.push_back(makeElement( element, cs.getConstraintLocator( locator, LocatorPathElt::ClosureBodyElement(element)), /*contextualInfo=*/{}, isDiscarded)); } createConjunction(cs, elements, locator); return; } for (auto node : braceStmt->getElements()) { if (auto expr = node.dyn_cast()) { auto generatedExpr = cs.generateConstraints( expr, closure, /*isInputExpression=*/false); if (!generatedExpr) { hadError = true; } } else if (auto stmt = node.dyn_cast()) { visit(stmt); } else { visitDecl(node.get()); } } } void visitReturnStmt(ReturnStmt *returnStmt) { auto contextualTy = cs.getClosureType(closure)->getResult(); // Single-expression closures are effectively a `return` statement, // so let's give them a special locator as to indicate that. if (closure->hasSingleExpressionBody()) { auto *expr = returnStmt->getResult(); assert(expr && "single expression closure without expression?"); expr = cs.generateConstraints(expr, closure, /*isInputExpression=*/false); if (!expr) { hadError = true; return; } cs.addConstraint( ConstraintKind::Conversion, cs.getType(expr), contextualTy, cs.getConstraintLocator( closure, LocatorPathElt::ClosureBody( /*hasReturn=*/!returnStmt->isImplicit()))); return; } Expr *resultExpr; if (returnStmt->hasResult()) { resultExpr = returnStmt->getResult(); assert(resultExpr && "non-empty result without expression?"); } else { // If this is simplify `return`, let's create an empty tuple // which is also useful if contextual turns out to be e.g. `Void?`. resultExpr = getVoidExpr(closure->getASTContext()); } SolutionApplicationTarget target(resultExpr, closure, CTP_ReturnStmt, contextualTy, /*isDiscarded=*/false); if (cs.generateConstraints(target, FreeTypeVariableBinding::Disallow)) { hadError = true; return; } cs.setContextualType(target.getAsExpr(), TypeLoc::withoutLoc(contextualTy), CTP_ReturnStmt); cs.setSolutionApplicationTarget(returnStmt, target); } bool isSupportedMultiStatementClosure() const { return !closure->hasSingleExpressionBody() && cs.participatesInInference(closure); } #define UNSUPPORTED_STMT(STMT) void visit##STMT##Stmt(STMT##Stmt *) { \ llvm_unreachable("Unsupported statement kind " #STMT); \ } UNSUPPORTED_STMT(Yield) UNSUPPORTED_STMT(Fail) #undef UNSUPPORTED_STMT private: ContextualTypeInfo getContextForCondition() const { auto boolDecl = cs.getASTContext().getBoolDecl(); assert(boolDecl && "Bool is missing"); return {boolDecl->getDeclaredInterfaceType(), CTP_Condition}; } bool isChildOf(StmtKind kind) { if (locator->getPath().empty()) return false; auto parentElt = locator->getLastElementAs(); return parentElt ? parentElt->getElement().isStmt(kind) : false; } }; } bool ConstraintSystem::generateConstraints(ClosureExpr *closure) { auto &ctx = closure->getASTContext(); if (participatesInInference(closure)) { ClosureConstraintGenerator generator(*this, closure, getConstraintLocator(closure)); generator.visit(closure->getBody()); if (closure->hasSingleExpressionBody()) return generator.hadError; } // If this closure has an empty body and no explicit result type // let's bind result type to `Void` since that's the only type empty body // can produce. Otherwise, if (multi-statement) closure doesn't have // an explicit result (no `return` statements) let's default it to `Void`. if (!hasExplicitResult(closure)) { auto constraintKind = (closure->hasEmptyBody() && !closure->hasExplicitResultType()) ? ConstraintKind::Bind : ConstraintKind::Defaultable; addConstraint( constraintKind, getClosureType(closure)->getResult(), ctx.TheEmptyTupleType, getConstraintLocator(closure, ConstraintLocator::ClosureResult)); } return false; } bool ConstraintSystem::isInResultBuilderContext(ClosureExpr *closure) const { if (!closure->hasSingleExpressionBody()) { auto *DC = closure->getParent(); do { if (auto *parentClosure = dyn_cast(DC)) { if (resultBuilderTransformed.count(parentClosure)) return true; } } while ((DC = DC->getParent())); } return false; } bool isConditionOfStmt(ConstraintLocatorBuilder locator) { auto last = locator.last(); if (!(last && last->is())) return false; SmallVector path; (void)locator.getLocatorParts(path); path.pop_back(); if (path.empty()) return false; if (auto closureElt = path.back().getAs()) return closureElt->getElement().dyn_cast(); return false; } ConstraintSystem::SolutionKind ConstraintSystem::simplifyClosureBodyElementConstraint( ASTNode element, ContextualTypeInfo context, bool isDiscarded, TypeMatchOptions flags, ConstraintLocatorBuilder locator) { auto *closure = castToExpr(locator.getAnchor()); ClosureConstraintGenerator generator(*this, closure, getConstraintLocator(locator)); if (auto *expr = element.dyn_cast()) { SolutionApplicationTarget target(expr, closure, context.purpose, context.getType(), isDiscarded); if (generateConstraints(target, FreeTypeVariableBinding::Disallow)) return SolutionKind::Error; setSolutionApplicationTarget(expr, target); return SolutionKind::Solved; } else if (auto *stmt = element.dyn_cast()) { generator.visit(stmt); } else if (auto *cond = element.dyn_cast()) { if (generateConstraints(*cond, closure)) return SolutionKind::Error; } else if (auto *pattern = element.dyn_cast()) { generator.visitPattern(pattern, context); } else if (auto *caseItem = element.dyn_cast()) { generator.visitCaseItem(caseItem, context); } else { generator.visit(element.get()); } return generator.hadError ? SolutionKind::Error : SolutionKind::Solved; } // MARK: Solution application namespace { /// Statement visitor that applies constraints for a given closure body. class ClosureConstraintApplication : public StmtVisitor { friend StmtVisitor; Solution &solution; ClosureExpr *closure; Type resultType; RewriteTargetFn rewriteTarget; bool isSingleExpression; public: /// Whether an error was encountered while generating constraints. bool hadError = false; ClosureConstraintApplication( Solution &solution, ClosureExpr *closure, Type resultType, RewriteTargetFn rewriteTarget) : solution(solution), closure(closure), resultType(resultType), rewriteTarget(rewriteTarget), isSingleExpression(closure->hasSingleExpressionBody()) { } private: /// Rewrite an expression without any particularly special context. Expr *rewriteExpr(Expr *expr) { auto result = rewriteTarget( SolutionApplicationTarget(expr, closure, CTP_Unused, Type(), /*isDiscarded=*/false)); if (result) return result->getAsExpr(); return nullptr; } void visitDecl(Decl *decl) { if (isa(decl)) return; // Variable declaration would be handled by a pattern binding. if (isa(decl)) return; // Generate constraints for pattern binding declarations. if (auto patternBinding = dyn_cast(decl)) { SolutionApplicationTarget target(patternBinding); if (!rewriteTarget(target)) { hadError = true; return; } // Allow `typeCheckDecl` to be called after solution is applied // to a pattern binding. That would materialize required // information e.g. accessors and do access/availability checks. } TypeChecker::typeCheckDecl(decl); } ASTNode visitBreakStmt(BreakStmt *breakStmt) { if (auto target = findBreakOrContinueStmtTarget( closure->getASTContext(), closure->getParentSourceFile(), breakStmt->getLoc(), breakStmt->getTargetName(), breakStmt->getTargetLoc(), /*isContinue=*/false, closure)) { breakStmt->setTarget(target); } return breakStmt; } ASTNode visitContinueStmt(ContinueStmt *continueStmt) { if (auto target = findBreakOrContinueStmtTarget( closure->getASTContext(), closure->getParentSourceFile(), continueStmt->getLoc(), continueStmt->getTargetName(), continueStmt->getTargetLoc(), /*isContinue=*/true, closure)) { continueStmt->setTarget(target); } return continueStmt; } ASTNode visitFallthroughStmt(FallthroughStmt *fallthroughStmt) { if (checkFallthroughStmt(closure, fallthroughStmt)) hadError = true; return fallthroughStmt; } ASTNode visitDeferStmt(DeferStmt *deferStmt) { TypeChecker::typeCheckDecl(deferStmt->getTempDecl()); Expr *theCall = deferStmt->getCallExpr(); TypeChecker::typeCheckExpression(theCall, closure); deferStmt->setCallExpr(theCall); return deferStmt; } ASTNode visitIfStmt(IfStmt *ifStmt) { // Rewrite the condition. if (auto condition = rewriteTarget( SolutionApplicationTarget(ifStmt->getCond(), closure))) ifStmt->setCond(*condition->getAsStmtCondition()); else hadError = true; ifStmt->setThenStmt(visit(ifStmt->getThenStmt()).get()); if (auto elseStmt = ifStmt->getElseStmt()) { ifStmt->setElseStmt(visit(elseStmt).get()); } return ifStmt; } ASTNode visitGuardStmt(GuardStmt *guardStmt) { if (auto condition = rewriteTarget( SolutionApplicationTarget(guardStmt->getCond(), closure))) guardStmt->setCond(*condition->getAsStmtCondition()); else hadError = true; auto *body = visit(guardStmt->getBody()).get(); guardStmt->setBody(cast(body)); return guardStmt; } ASTNode visitWhileStmt(WhileStmt *whileStmt) { if (auto condition = rewriteTarget( SolutionApplicationTarget(whileStmt->getCond(), closure))) whileStmt->setCond(*condition->getAsStmtCondition()); else hadError = true; auto *body = visit(whileStmt->getBody()).get(); whileStmt->setBody(cast(body)); return whileStmt; } ASTNode visitDoStmt(DoStmt *doStmt) { auto body = visit(doStmt->getBody()).get(); doStmt->setBody(cast(body)); return doStmt; } ASTNode visitRepeatWhileStmt(RepeatWhileStmt *repeatWhileStmt) { auto body = visit(repeatWhileStmt->getBody()).get(); repeatWhileStmt->setBody(cast(body)); // Rewrite the condition. auto &cs = solution.getConstraintSystem(); auto target = *cs.getSolutionApplicationTarget(repeatWhileStmt->getCond()); if (auto condition = rewriteTarget(target)) repeatWhileStmt->setCond(condition->getAsExpr()); else hadError = true; return repeatWhileStmt; } ASTNode visitPoundAssertStmt(PoundAssertStmt *poundAssertStmt) { // FIXME: This should be done through \c solution instead of // constraint system. auto &cs = solution.getConstraintSystem(); // Rewrite the condition. auto target = *cs.getSolutionApplicationTarget(poundAssertStmt->getCondition()); if (auto result = rewriteTarget(target)) poundAssertStmt->setCondition(result->getAsExpr()); else hadError = true; return poundAssertStmt; } ASTNode visitThrowStmt(ThrowStmt *throwStmt) { auto &cs = solution.getConstraintSystem(); // Rewrite the error. auto target = *cs.getSolutionApplicationTarget(throwStmt->getSubExpr()); if (auto result = rewriteTarget(target)) throwStmt->setSubExpr(result->getAsExpr()); else hadError = true; return throwStmt; } ASTNode visitForEachStmt(ForEachStmt *forEachStmt) { ConstraintSystem &cs = solution.getConstraintSystem(); auto forEachTarget = rewriteTarget(*cs.getSolutionApplicationTarget(forEachStmt)); if (!forEachTarget) hadError = true; auto body = visit(forEachStmt->getBody()).get(); forEachStmt->setBody(cast(body)); // Check to see if the sequence expr is throwing (in async context), // if so require the stmt to have a `try`. hadError |= diagnoseUnhandledThrowsInAsyncContext(closure, forEachStmt); return forEachStmt; } ASTNode visitSwitchStmt(SwitchStmt *switchStmt) { ConstraintSystem &cs = solution.getConstraintSystem(); // Rewrite the switch subject. auto subjectTarget = rewriteTarget(*cs.getSolutionApplicationTarget(switchStmt)); if (subjectTarget) { switchStmt->setSubjectExpr(subjectTarget->getAsExpr()); } else { hadError = true; } // Visit the raw cases. bool limitExhaustivityChecks = false; for (auto rawCase : switchStmt->getRawCases()) { if (auto decl = rawCase.dyn_cast()) { visitDecl(decl); continue; } auto caseStmt = cast(rawCase.get()); visitCaseStmt(caseStmt); // Check restrictions on '@unknown'. if (caseStmt->hasUnknownAttr()) { checkUnknownAttrRestrictions(cs.getASTContext(), caseStmt, limitExhaustivityChecks); } } TypeChecker::checkSwitchExhaustiveness(switchStmt, closure, limitExhaustivityChecks); return switchStmt; } ASTNode visitDoCatchStmt(DoCatchStmt *doStmt) { // Translate the body. auto newBody = visit(doStmt->getBody()); doStmt->setBody(newBody.get()); // Visit the catch blocks. for (auto catchStmt : doStmt->getCatches()) visitCaseStmt(catchStmt); return doStmt; } ASTNode visitCaseStmt(CaseStmt *caseStmt) { // Translate the patterns and guard expressions for each case label item. for (auto &caseItem : caseStmt->getMutableCaseLabelItems()) { SolutionApplicationTarget caseTarget(&caseItem, closure); if (!rewriteTarget(caseTarget)) { hadError = true; } } for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray()) { assert(expected->hasName()); auto prev = expected->getParentVarDecl(); auto type = solution.resolveInterfaceType(solution.getType(prev)); expected->setInterfaceType(type); } // Translate the body. auto *newBody = visit(caseStmt->getBody()).get(); caseStmt->setBody(cast(newBody)); return caseStmt; } ASTNode visitBraceStmt(BraceStmt *braceStmt) { auto &cs = solution.getConstraintSystem(); // Diagnose defer statement being last one in block. if (!braceStmt->empty()) { if (auto stmt = braceStmt->getLastElement().dyn_cast()) { if (auto deferStmt = dyn_cast(stmt)) { auto &diags = closure->getASTContext().Diags; diags .diagnose(deferStmt->getStartLoc(), diag::defer_stmt_at_block_end) .fixItReplace(deferStmt->getStartLoc(), "do"); } } } for (auto &node : braceStmt->getElements()) { if (auto expr = node.dyn_cast()) { // Rewrite the expression. auto target = *cs.getSolutionApplicationTarget(expr); if (auto rewrittenTarget = rewriteTarget(target)) { node = rewrittenTarget->getAsExpr(); if (target.isDiscardedExpr()) TypeChecker::checkIgnoredExpr(castToExpr(node)); } else { hadError = true; } } else if (auto stmt = node.dyn_cast()) { node = visit(stmt); } else { visitDecl(node.get()); } } // Source compatibility workaround. // // func test(_: () -> T?) { // ... // } // // A multi-statement closure passed to `test` that has an optional // `Void` result type inferred from the body allows: // - empty `return`(s); // - to skip `return nil` or `return ()` at the end. // // Implicit `return ()` has to be inserted as the last element // of the body if there is none. This wasn't needed before SE-0326 // because result type was (incorrectly) inferred as `Void` due to // the body being skipped. if (!closure->hasSingleExpressionBody() && closure->getBody() == braceStmt) { if (resultType->getOptionalObjectType() && resultType->lookThroughAllOptionalTypes()->isVoid() && !braceStmt->getLastElement().isStmt(StmtKind::Return)) { return addImplicitVoidReturn(braceStmt); } } return braceStmt; } ASTNode addImplicitVoidReturn(BraceStmt *braceStmt) { auto &ctx = closure->getASTContext(); auto &cs = solution.getConstraintSystem(); auto *resultExpr = getVoidExpr(ctx); cs.cacheExprTypes(resultExpr); auto *returnStmt = new (ctx) ReturnStmt(SourceLoc(), resultExpr, /*implicit=*/true); // For a target for newly created result and apply a solution // to it, to make sure that optional injection happens required // number of times. { SolutionApplicationTarget target(resultExpr, closure, CTP_ReturnStmt, resultType, /*isDiscarded=*/false); cs.setSolutionApplicationTarget(returnStmt, target); visitReturnStmt(returnStmt); } // Re-create brace statement with an additional `return` at the end. SmallVector elements; elements.append(braceStmt->getElements().begin(), braceStmt->getElements().end()); elements.push_back(returnStmt); return BraceStmt::create(ctx, braceStmt->getLBraceLoc(), elements, braceStmt->getRBraceLoc()); } ASTNode visitReturnStmt(ReturnStmt *returnStmt) { if (!returnStmt->hasResult()) { // If contextual is not optional, there is nothing to do here. if (resultType->isVoid()) return returnStmt; // It's possible to infer e.g. `Void?` for cases where // `return` doesn't have an expression. If contextual // type is `Void` wrapped into N optional types, let's // add an implicit `()` expression and let it be injected // into optional required number of times. assert(resultType->getOptionalObjectType() && resultType->lookThroughAllOptionalTypes()->isVoid()); auto &cs = solution.getConstraintSystem(); auto target = *cs.getSolutionApplicationTarget(returnStmt); returnStmt->setResult(target.getAsExpr()); } auto *resultExpr = returnStmt->getResult(); enum { convertToResult, coerceToVoid, coerceFromNever, } mode; auto resultExprType = solution.simplifyType(solution.getType(resultExpr))->getRValueType(); // A closure with a non-void return expression can coerce to a closure // that returns Void. if (resultType->isVoid() && !resultExprType->isVoid()) { mode = coerceToVoid; // A single-expression closure with a Never expression type // coerces to any other function type. } else if (isSingleExpression && resultExprType->isUninhabited()) { mode = coerceFromNever; // Normal rule is to coerce to the return expression to the closure type. } else { mode = convertToResult; } SolutionApplicationTarget resultTarget( resultExpr, closure, mode == convertToResult ? CTP_ReturnStmt : CTP_Unused, mode == convertToResult ? resultType : Type(), /*isDiscarded=*/false); if (auto newResultTarget = rewriteTarget(resultTarget)) resultExpr = newResultTarget->getAsExpr(); switch (mode) { case convertToResult: // Record the coerced expression. returnStmt->setResult(resultExpr); return returnStmt; case coerceToVoid: { // Evaluate the expression, then produce a return statement that // returns nothing. TypeChecker::checkIgnoredExpr(resultExpr); auto &ctx = solution.getConstraintSystem().getASTContext(); auto newReturnStmt = new (ctx) ReturnStmt( returnStmt->getStartLoc(), nullptr, /*implicit=*/true); ASTNode elements[2] = { resultExpr, newReturnStmt }; return BraceStmt::create(ctx, returnStmt->getStartLoc(), elements, returnStmt->getEndLoc(), /*implicit*/ true); } case coerceFromNever: // Replace the return statement with its expression, so that the // expression is evaluated directly. This only works because coercion // from never is limited to single-expression closures. return resultExpr; } return returnStmt; } #define UNSUPPORTED_STMT(STMT) ASTNode visit##STMT##Stmt(STMT##Stmt *) { \ llvm_unreachable("Unsupported statement kind " #STMT); \ } UNSUPPORTED_STMT(Yield) UNSUPPORTED_STMT(Fail) #undef UNSUPPORTED_STMT }; } SolutionApplicationToFunctionResult ConstraintSystem::applySolution( Solution &solution, AnyFunctionRef fn, DeclContext *¤tDC, RewriteTargetFn rewriteTarget) { auto &cs = solution.getConstraintSystem(); auto closure = dyn_cast_or_null(fn.getAbstractClosureExpr()); FunctionType *closureFnType = nullptr; if (closure) { // Update the closure's type. auto closureType = solution.simplifyType(cs.getType(closure)); cs.setType(closure, closureType); // Coerce the parameter types. closureFnType = closureType->castTo(); auto *params = closure->getParameters(); TypeChecker::coerceParameterListToType(params, closureFnType); // Find any isolated parameters in this closure and mark them as isolated. for (auto param : solution.isolatedParams) { if (param->getDeclContext() == closure) param->setIsolated(true); } // Coerce the result type, if it was written explicitly. if (closure->hasExplicitResultType()) { closure->setExplicitResultType(closureFnType->getResult()); } } // Enter the context of the function before performing any additional // transformations. llvm::SaveAndRestore savedDC(currentDC, fn.getAsDeclContext()); // Apply the result builder transform, if there is one. if (auto transform = solution.getAppliedBuilderTransform(fn)) { // Apply the result builder to the closure. We want to be in the // context of the closure for subsequent transforms. auto newBody = applyResultBuilderTransform( solution, *transform, fn.getBody(), fn.getAsDeclContext(), [&](SolutionApplicationTarget target) { auto resultTarget = rewriteTarget(target); if (resultTarget) { if (auto expr = resultTarget->getAsExpr()) solution.setExprTypes(expr); } return resultTarget; }); if (!newBody) return SolutionApplicationToFunctionResult::Failure; fn.setTypecheckedBody(newBody, /*isSingleExpression=*/false); if (closure) { solution.setExprTypes(closure); } return SolutionApplicationToFunctionResult::Success; } assert(closure && "Can only get here with a closure at the moment"); // If this closure is checked as part of the enclosing expression, handle // that now. // // Multi-statement closures are handled separately because they need to // wait until all of the `ExtInfo` flags are propagated from the context // e.g. parameter could be no-escape if closure is applied to a call. if (closure->hasSingleExpressionBody()) { bool hadError = applySolutionToBody(solution, closure, currentDC, rewriteTarget); return hadError ? SolutionApplicationToFunctionResult::Failure : SolutionApplicationToFunctionResult::Success; } // Otherwise, we need to delay type checking of the closure until later. solution.setExprTypes(closure); closure->setBodyState(ClosureExpr::BodyState::ReadyForTypeChecking); return SolutionApplicationToFunctionResult::Delay; } bool ConstraintSystem::applySolutionToBody(Solution &solution, ClosureExpr *closure, DeclContext *¤tDC, RewriteTargetFn rewriteTarget) { auto &cs = solution.getConstraintSystem(); // Enter the context of the function before performing any additional // transformations. llvm::SaveAndRestore savedDC(currentDC, closure); auto closureType = cs.getType(closure)->castTo(); ClosureConstraintApplication application( solution, closure, closureType->getResult(), rewriteTarget); auto body = application.visit(closure->getBody()); if (!body || application.hadError) return true; closure->setBody(cast(body.get()), closure->hasSingleExpressionBody()); closure->setBodyState(ClosureExpr::BodyState::TypeCheckedWithSignature); return false; } void ConjunctionElement::findReferencedVariables( ConstraintSystem &cs, SmallPtrSetImpl &typeVars) const { auto referencedVars = Element->getTypeVariables(); typeVars.insert(referencedVars.begin(), referencedVars.end()); if (Element->getKind() != ConstraintKind::ClosureBodyElement) return; ASTNode element = Element->getClosureElement(); auto *locator = Element->getLocator(); TypeVariableRefFinder refFinder(cs, locator->getAnchor(), typeVars); if (element.is() || element.is() || element.is() || element.isStmt(StmtKind::Return)) element.walk(refFinder); }