//===--- CSSyntacticElement.cpp - Syntactic Element Constraints -----------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2022 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements SyntacticElement constraint generation and solution
// application, which is used to type-check the bodies of closures (and, when
// invoked through `generateConstraints(AnyFunctionRef, BraceStmt *)`,
// function bodies). It provides part of the implementation of the
// ConstraintSystem class.
//
//===----------------------------------------------------------------------===//

#include "MiscDiagnostics.h"
#include "TypeChecker.h"
#include "TypeCheckAvailability.h"
#include "swift/Sema/ConstraintSystem.h"
#include "swift/Sema/IDETypeChecking.h"

using namespace swift;
using namespace swift::constraints;

namespace {

/// Produce an implicit empty tuple expression `()` whose type is already set
/// to the empty tuple type.
///
/// \param ctx The AST context used to allocate the expression.
/// \param contextLoc Used as both the left and the right paren location, so a
/// diagnostic that points at the synthesized expression still has a valid
/// source location.
Expr *getVoidExpr(ASTContext &ctx, SourceLoc contextLoc = SourceLoc()) {
  auto *voidExpr = TupleExpr::createEmpty(ctx,
                                          /*LParenLoc=*/contextLoc,
                                          /*RParenLoc=*/contextLoc,
                                          /*Implicit=*/true);
  voidExpr->setType(ctx.TheEmptyTupleType);
  return voidExpr;
}

/// Find any type variable references inside of an AST node.
///
/// Every type variable found is inserted into the caller-provided
/// \c ReferencedVars set. While walking, this class may also record types on
/// the declarations it visits (via \c CS.setType):
///   - references to external declarations whose types still contain
///     unresolved type variables get those variables replaced with error
///     types;
///   - property-wrapper backing (`_x`) and projection (`$x`) variables get
///     types derived from the outermost wrapper attribute;
///   - implicit placeholder variables get freshly opened types.
class TypeVariableRefFinder : public ASTWalker {
  /// A stack of all closures the walker encountered so far.
  SmallVector ClosureDCs;

  ConstraintSystem &CS;
  /// The node being analyzed; when it is a closure it seeds \c ClosureDCs.
  ASTNode Parent;

  /// Output: every type variable referenced from the walked node.
  llvm::SmallPtrSetImpl &ReferencedVars;

public:
  TypeVariableRefFinder(
      ConstraintSystem &cs, ASTNode parent,
      llvm::SmallPtrSetImpl &referencedVars)
      : CS(cs), Parent(parent), ReferencedVars(referencedVars) {
    if (auto *closure = getAsExpr(Parent))
      ClosureDCs.push_back(closure);
  }

  /// Track entered closures and inspect declaration references, collecting
  /// the type variables their types mention.
  PreWalkResult walkToExprPre(Expr *expr) override {
    if (auto *closure = dyn_cast(expr)) {
      ClosureDCs.push_back(closure);
    }

    if (auto *DRE = dyn_cast(expr)) {
      auto *decl = DRE->getDecl();

      if (auto type = CS.getTypeIfAvailable(DRE->getDecl())) {
        auto &ctx = CS.getASTContext();
        // If this is not one of the closure parameters which
        // is inferable from the body, let's replace type
        // variables with errors to avoid bringing external
        // information to the element component.
        if (type->hasTypeVariable() &&
            !(isa(decl) || decl->getName() == ctx.Id_builderSelf)) {
          // If there are type variables left in the simplified version,
          // it means that this is an invalid external declaration
          // relative to this element's context.
          if (CS.simplifyType(type)->hasTypeVariable()) {
            // Replace every type variable in the declaration's type with
            // an error type.
            auto transformedTy = type.transform([&](Type type) {
              if (auto *typeVar = type->getAs()) {
                return ErrorType::get(CS.getASTContext());
              }
              return type;
            });

            CS.setType(decl, transformedTy);
            return Action::Continue(expr);
          }
        }

        inferVariables(type);
        return Action::Continue(expr);
      }

      auto var = dyn_cast(decl);
      if (!var)
        return Action::Continue(expr);

      if (auto *wrappedVar = var->getOriginalWrappedProperty()) {
        // If there is no type it means that the body of the
        // closure hasn't been resolved yet, so we can
        // just skip it and wait for \c applyPropertyWrapperToParameter
        // to assign types.
        if (wrappedVar->hasImplicitPropertyWrapper())
          return Action::Continue(expr);

        auto outermostWrapperAttr =
            wrappedVar->getOutermostAttachedPropertyWrapper();

        // If the attribute doesn't have a type it could only mean
        // that the declaration was incorrect.
        if (!CS.hasType(outermostWrapperAttr->getTypeExpr()))
          return Action::Continue(expr);

        auto wrapperType =
            CS.simplifyType(CS.getType(outermostWrapperAttr->getTypeExpr()));

        if (var->getName().hasDollarPrefix()) {
          // $ is the projected value var
          CS.setType(var, computeProjectedValueType(wrappedVar, wrapperType));
        } else {
          // _ is the wrapper var
          CS.setType(var, wrapperType);
        }

        return Action::Continue(expr);
      }

      // If there is no type recorded yet, let's check whether
      // it is a placeholder variable implicitly generated by the
      // compiler.
      if (auto *PB = var->getParentPatternBinding()) {
        if (auto placeholderTy = isPlaceholderVar(PB)) {
          auto openedTy = CS.replaceInferableTypesWithTypeVars(
              placeholderTy, CS.getConstraintLocator(expr));
          inferVariables(openedTy);
          CS.setType(var, openedTy);
        }
      }
    }

    return Action::Continue(expr);
  }

  /// Pop the closure pushed by \c walkToExprPre so that
  /// \c currentClosureDC() keeps tracking the innermost enclosing closure.
  PostWalkResult walkToExprPost(Expr *expr) override {
    if (auto *closure = dyn_cast(expr)) {
      ClosureDCs.pop_back();
    }

    return Action::Continue(expr);
  }

  PreWalkResult walkToStmtPre(Stmt *stmt) override {
    // Return statements have to reference outside result type
    // since all of them are joined by it if it's not specified
    // explicitly.
    if (isa(stmt)) {
      if (auto *closure = getAsExpr(Parent)) {
        // Return is only viable if it belongs to a parent closure.
        if (currentClosureDC() == closure)
          inferVariables(CS.getClosureType(closure)->getResult());
      }
    }

    return Action::Continue(stmt);
  }

private:
  /// The innermost closure currently being walked, or null if none.
  DeclContext *currentClosureDC() const {
    return ClosureDCs.empty() ? nullptr : ClosureDCs.back();
  }

  /// Add the type variables mentioned by \p type (after
  /// \c getWithoutSpecifierType()) to \c ReferencedVars. When \p type is
  /// itself a type variable, it is recorded and then replaced by its fixed
  /// type (via \c getFixedTypeRecursive), so variables nested in that binding
  /// are recorded too.
  void inferVariables(Type type) {
    type = type->getWithoutSpecifierType();

    // Record the type variable itself because it has to
    // be in scope even when already bound.
    if (auto *typeVar = type->getAs()) {
      ReferencedVars.insert(typeVar);

      // It is possible that contextual type of a parameter/result
      // has been assigned to e.g. an anonymous or named argument
      // early, to facilitate closure type checking. Such a
      // type can have type variables inside e.g.
      //
      // func test<T>(_: (UnsafePointer<T>) -> Void) {}
      //
      // test { ptr in
      //  ...
      // }
      //
      // Type variable representing `ptr` in the body of
      // this closure would be bound to `UnsafePointer<$T>`
      // in this case, where `$T` is a type variable for a
      // generic parameter `T`.
      type = CS.getFixedTypeRecursive(typeVar, /*wantRValue=*/false);

      // The variable resolved to itself (presumably unbound), so there is
      // nothing nested left to collect.
      if (type->isEqual(typeVar))
        return;
    }

    // Collect every type variable nested anywhere inside `type`.
    if (type->hasTypeVariable()) {
      SmallPtrSet typeVars;
      type->getTypeVariables(typeVars);
      ReferencedVars.insert(typeVars.begin(), typeVars.end());
    }
  }
};

/// Find any references to not yet resolved outer closure parameters
/// used in the body of the inner closure. This is required because
/// isolated conjunctions, just like single-expression closures, have
/// to be connected to type variables they are going to use, otherwise
/// they'll get placed in a separate solver component and would never
/// produce a solution.
class UnresolvedClosureParameterCollector : public ASTWalker {
  ConstraintSystem &CS;

  /// Collected type variables, without duplicates, in discovery order.
  llvm::SmallSetVector Vars;

public:
  UnresolvedClosureParameterCollector(ConstraintSystem &cs) : CS(cs) {}

  /// For each declaration reference that passes the \c isa check, record its
  /// type if that type is a type variable, or every type variable nested
  /// inside it otherwise.
  PreWalkResult walkToExprPre(Expr *expr) override {
    if (auto *DRE = dyn_cast(expr)) {
      auto *decl = DRE->getDecl();
      if (isa(decl)) {
        if (auto type = CS.getTypeIfAvailable(decl)) {
          if (auto *typeVar = type->getAs()) {
            Vars.insert(typeVar);
          } else if (type->hasTypeVariable()) {
            // Parameter or result type could be only partially
            // resolved e.g. `{ (x: X) -> Void in ... }` where
            // `X` is a generic type.
            SmallPtrSet typeVars;
            type->getTypeVariables(typeVars);
            Vars.insert(typeVars.begin(), typeVars.end());
          }
        }
      }
    }
    return Action::Continue(expr);
  }

  /// All type variables collected so far.
  ArrayRef getVariables() const {
    return Vars.getArrayRef();
  }
};

// MARK: Constraint generation

/// Check whether it makes sense to convert this element into a constraint.
///
/// Returns false for declarations that are handled elsewhere and for empty
/// brace statements; returns true for every other element.
static bool isViableElement(ASTNode element) {
  if (auto *decl = element.dyn_cast()) {
    // - Ignore variable declarations, they are handled by pattern bindings;
    // - Ignore #if, the chosen children should appear in the
    //   surrounding context;
    // - Skip #warning and #error, they are handled during solution
    //   application.
    if (isa(decl) || isa(decl) || isa(decl))
      return false;
  }

  if (auto *stmt = element.dyn_cast()) {
    // Empty brace statements are not viable: there is nothing inside them
    // that requires inference.
    if (auto *braceStmt = dyn_cast(stmt)) {
      return braceStmt->getNumElements() > 0;
    }
  }

  return true;
}

/// A single conjunction element, in the order produced by \c makeElement:
/// the AST node, its contextual type information, whether its value is
/// discarded, and the locator to use for its constraint.
using ElementInfo = std::tuple;

/// Create a conjunction constraint out of the viable \p elements and add it
/// to the constraint system as an unsolved constraint.
///
/// When \p locator points directly at a closure, the conjunction represents
/// the closure body. In that case it is marked isolated and references:
///   - the closure's type variable;
///   - the type variables of an applied result builder type;
///   - any not-yet-resolved outer closure parameters used by the elements.
/// These references keep it connected to its context.
///
/// No constraint is added when none of the elements are viable.
static void createConjunction(ConstraintSystem &cs,
                              ArrayRef elements,
                              ConstraintLocator *locator) {
  bool isIsolated = false;

  SmallVector constraints;
  SmallVector referencedVars;

  if (locator->directlyAt()) {
    auto *closure = castToExpr(locator->getAnchor());
    // Conjunction associated with the body of the closure has to
    // reference a type variable representing closure type,
    // otherwise it would get disconnected from its contextual type.
    referencedVars.push_back(cs.getType(closure)->castTo());
    // Result builder could be generic but attribute allows its use
    // in "unbound" form (i.e. `@Builder` where `Builder` is defined
    // as `struct Builder<T>`). Generic parameters of such a result
    // builder type are inferable from context, namely from `build*`
    // calls injected by the transform, and are not always resolved at
    // the time conjunction is created.
    //
    // Conjunction needs to reference all the type variables associated
    // with result builder just like parameters and result type of
    // the closure in order to stay connected to its context.
    if (auto builder = cs.getAppliedResultBuilderTransform(closure)) {
      SmallPtrSet builderVars;
      builder->builderType->getTypeVariables(builderVars);
      referencedVars.append(builderVars.begin(), builderVars.end());
    }

    // Body of the closure is always isolated from its context, only
    // its individual elements are allowed access to type information
    // from the outside e.g. parameters/result type.
    isIsolated = true;
  }

  UnresolvedClosureParameterCollector paramCollector(cs);

  for (const auto &entry : elements) {
    ASTNode element = std::get<0>(entry);
    ContextualTypeInfo context = std::get<1>(entry);
    bool isDiscarded = std::get<2>(entry);
    ConstraintLocator *elementLoc = std::get<3>(entry);

    if (!isViableElement(element))
      continue;

    // If this conjunction is going to represent a body of a closure,
    // let's collect references to not yet resolved outer
    // closure parameters.
    if (isIsolated)
      element.walk(paramCollector);

    constraints.push_back(Constraint::createSyntacticElement(
        cs, element, context, elementLoc, isDiscarded));
  }

  // It's possible that there are no viable elements in the body,
  // because e.g. whole body is an `#if` statement or it only has
  // declarations that are checked during solution application.
  // In such cases, let's avoid creating a conjunction.
  if (constraints.empty())
    return;

  for (auto *externalVar : paramCollector.getVariables())
    referencedVars.push_back(externalVar);

  cs.addUnsolvedConstraint(Constraint::createConjunction(
      cs, constraints, isIsolated, locator, referencedVars));
}

/// Bundle the pieces of a conjunction element into an \c ElementInfo.
ElementInfo makeElement(ASTNode node, ConstraintLocator *locator,
                        ContextualTypeInfo context = ContextualTypeInfo(),
                        bool isDiscarded = false) {
  return std::make_tuple(node, context, isDiscarded, locator);
}

/// The context whose body is being type-checked: either an abstract function
/// declaration or an abstract closure expression.
struct SyntacticElementContext
    : public llvm::PointerUnion {
  // Inherit the constructors from PointerUnion.
  using PointerUnion::PointerUnion;

  /// Build a context from a function reference. The function declaration is
  /// used when there is one; otherwise the closure expression is used.
  static SyntacticElementContext forFunctionRef(AnyFunctionRef ref) {
    if (auto *decl = ref.getAbstractFunctionDecl()) {
      return {decl};
    }

    return {ref.getAbstractClosureExpr()};
  }

  static SyntacticElementContext forClosure(ClosureExpr *closure) {
    return {closure};
  }

  static SyntacticElementContext forFunction(AbstractFunctionDecl *func) {
    return {func};
  }

  /// The context as a \c DeclContext; unreachable if it holds neither kind.
  DeclContext *getAsDeclContext() const {
    if (auto *fn = this->dyn_cast()) {
      return fn;
    } else if (auto *closure = this->dyn_cast()) {
      return closure;
    } else {
      llvm_unreachable("unsupported kind");
    }
  }

  NullablePtr getAsAbstractClosureExpr() const {
    return this->dyn_cast();
  }

  NullablePtr getAsAbstractFunctionDecl() const {
    return this->dyn_cast();
  }

  /// The context as a function reference, or \c None otherwise.
  Optional getAsAnyFunctionRef() const {
    if (auto *fn = this->dyn_cast()) {
      return {fn};
    } else if (auto *closure = this->dyn_cast()) {
      return {closure};
    } else {
      return None;
    }
  }

  /// The body of the function or closure; unreachable for any other kind.
  BraceStmt *getBody() const {
    if (auto *fn = this->dyn_cast()) {
      return fn->getBody();
    } else if (auto *closure = this->dyn_cast()) {
      return closure->getBody();
    } else {
      llvm_unreachable("unsupported kind");
    }
  }

  /// Whether this context is a closure with a single-expression body.
  /// Always false when a result builder transform has been applied.
  bool isSingleExpressionClosure(ConstraintSystem &cs) {
    if (auto ref = getAsAnyFunctionRef()) {
      if (cs.getAppliedResultBuilderTransform(*ref))
        return false;

      if (auto *closure = ref->getAbstractClosureExpr())
        return closure->hasSingleExpressionBody();
    }
    return false;
  }
};

/// Statement visitor that generates constraints for a given closure (or
/// function) body.
///
/// Compound statements are usually not constrained eagerly. Instead, their
/// components are packaged into conjunctions (see \c createConjunction).
/// Constraints for each component are generated later, when the solver
/// simplifies the corresponding syntactic element constraint (see
/// \c ConstraintSystem::simplifySyntacticElementConstraint).
class SyntacticElementConstraintGenerator
    : public StmtVisitor {
  friend StmtVisitor;

  ConstraintSystem &cs;
  SyntacticElementContext context;
  ConstraintLocator *locator;

public:
  /// Whether an error was encountered while generating constraints.
  bool hadError = false;

  SyntacticElementConstraintGenerator(ConstraintSystem &cs,
                                      SyntacticElementContext context,
                                      ConstraintLocator *locator)
      : cs(cs), context(context), locator(locator) {}

  /// Generate constraints for a pattern owned by the statement that the
  /// element at the end of the current locator refers to.
  ///
  /// That statement is either a `for-in` statement or a `case`. \c hadError
  /// is set when the locator does not end in such an element.
  void visitPattern(Pattern *pattern, ContextualTypeInfo context) {
    auto parentElement =
        locator->getLastElementAs();

    if (!parentElement) {
      hadError = true;
      return;
    }

    if (auto *stmt = parentElement->getElement().dyn_cast()) {
      if (isa(stmt)) {
        visitForEachPattern(pattern, cast(stmt));
        return;
      }

      if (isa(stmt)) {
        visitCaseItemPattern(pattern, context);
        return;
      }
    }

    llvm_unreachable("Unsupported pattern");
  }

  /// Resolve (if necessary) and generate constraints for a single `case`
  /// label item: its pattern and its optional `where` guard.
  ///
  /// The resolved pattern and generated guard are recorded via
  /// \c setCaseLabelItemInfo for use during solution application.
  void visitCaseItem(CaseLabelItem *caseItem, ContextualTypeInfo contextInfo) {
    assert(contextInfo.purpose == CTP_CaseStmt);

    // Resolve the pattern.
    auto *pattern = caseItem->getPattern();
    if (!caseItem->isPatternResolved()) {
      pattern = TypeChecker::resolvePattern(pattern, context.getAsDeclContext(),
                                            /*isStmtCondition=*/false);
      if (!pattern) {
        hadError = true;
        return;
      }

      caseItem->setPattern(pattern, /*resolved=*/true);
    }

    // Let's generate constraints for pattern + where clause.
    // The assumption is that this shouldn't be too complex
    // to handle, but if it turns out to be false, this could
    // always be converted into a conjunction.

    // Generate constraints for pattern.
    visitPattern(pattern, contextInfo);

    auto *guardExpr = caseItem->getGuardExpr();

    // Generate constraints for `where` clause (if any).
    if (guardExpr) {
      guardExpr = cs.generateConstraints(guardExpr, context.getAsDeclContext());
      if (!guardExpr) {
        hadError = true;
        return;
      }
    }

    // Save information about case item so it could be referenced during
    // solution application.
    cs.setCaseLabelItemInfo(caseItem, {pattern, guardExpr});
  }

private:
  /// This method handles both pattern and the sequence expression
  /// associated with `for-in` loop because types in this situation
  /// flow in both directions:
  ///
  /// - From pattern to sequence, informing its element type e.g.
  ///   `for i: Int8 in 0 ..< 8`
  ///
  /// - From sequence to pattern, when pattern has no type information.
  ///
  /// \p pattern is not referenced here; only \p forEachStmt is used to build
  /// the solution application target.
  void visitForEachPattern(Pattern *pattern, ForEachStmt *forEachStmt) {
    auto target = SolutionApplicationTarget::forForEachStmt(
        forEachStmt, context.getAsDeclContext(),
        /*bindTypeVarsOneWay=*/false);

    if (cs.generateConstraints(target, FreeTypeVariableBinding::Disallow)) {
      hadError = true;
      return;
    }

    // After successful constraint generation, let's record
    // solution application target with all relevant information.
    cs.setSolutionApplicationTarget(forEachStmt, target);
  }

  /// Generate constraints for a `case` label pattern. The pattern is then
  /// tied to the contextual type through a conversion constraint. That
  /// contextual type is the switch subject type, or the error type for
  /// `catch` clauses.
  void visitCaseItemPattern(Pattern *pattern, ContextualTypeInfo context) {
    Type patternType = cs.generateConstraints(
        pattern, locator, /*bindPatternVarsOneWay=*/false,
        /*patternBinding=*/nullptr, /*patternIndex=*/0);

    if (!patternType) {
      hadError = true;
      return;
    }

    // Convert the contextual type to the pattern, which establishes the
    // bindings.
    cs.addConstraint(ConstraintKind::Conversion, context.getType(),
                     patternType, locator);

    // For any pattern variable that has a parent variable (i.e., another
    // pattern variable with the same name in the same case), require that
    // the types be equivalent.
    pattern->forEachNode([&](Pattern *pattern) {
      auto namedPattern = dyn_cast(pattern);
      if (!namedPattern)
        return;

      auto var = namedPattern->getDecl();
      if (auto parentVar = var->getParentVarDecl()) {
        cs.addConstraint(
            ConstraintKind::Equal, cs.getType(parentVar), cs.getType(var),
            cs.getConstraintLocator(
                locator, LocatorPathElt::PatternMatch(namedPattern)));
      }
    });
  }

  /// Resolve each entry of \p patternBinding whose initializer has not been
  /// checked yet. For each such entry, one element is appended to
  /// \p patterns, with a locator that ends in the entry's index.
  ///
  /// Constraints for these elements are generated when they are visited via
  /// \c visitDecl / \c visitPatternBindingElement. If a pattern fails to
  /// resolve, \c hadError is set and processing stops. Elements already
  /// appended for earlier entries stay in \p patterns.
  void visitPatternBinding(PatternBindingDecl *patternBinding,
                           SmallVectorImpl &patterns) {
    auto *baseLoc = cs.getConstraintLocator(
        locator, LocatorPathElt::SyntacticElement(patternBinding));

    for (unsigned index : range(patternBinding->getNumPatternEntries())) {
      if (patternBinding->isInitializerChecked(index))
        continue;

      auto *pattern = TypeChecker::resolvePattern(
          patternBinding->getPattern(index), patternBinding->getDeclContext(),
          /*isStmtCondition=*/true);

      if (!pattern) {
        hadError = true;
        return;
      }

      // Reset binding to point to the resolved pattern. This is required
      // before calling `forPatternBindingDecl`.
      patternBinding->setPattern(index, pattern,
                                 patternBinding->getInitContext(index));

      patterns.push_back(makeElement(
          patternBinding,
          cs.getConstraintLocator(
              baseLoc, LocatorPathElt::PatternBindingElement(index))));
    }
  }

  /// Build the solution application target for entry \p index of
  /// \p patternBinding, whose pattern has already been type-checked as
  /// \p patternType.
  ///
  /// The initializer used is the entry's own initializer. If there is none,
  /// a synthesized default initializer is used when the entry is
  /// default-initializable and its pattern has storage.
  ///
  /// Returns \c None if pre-checking a property-wrapped initialization fails.
  Optional
  getTargetForPattern(PatternBindingDecl *patternBinding, unsigned index,
                      Type patternType) {
    auto hasPropertyWrapper = [&](Pattern *pattern) -> bool {
      if (auto *singleVar = pattern->getSingleVar())
        return singleVar->hasAttachedPropertyWrapper();
      return false;
    };

    auto *pattern = patternBinding->getPattern(index);
    auto *init = patternBinding->getInit(index);

    if (!init && patternBinding->isDefaultInitializable(index) &&
        pattern->hasStorage()) {
      init = TypeChecker::buildDefaultInitializer(patternType);
    }

    // A property wrapper initializer (either user-defined
    // or a synthesized one) has to be pre-checked before use.
    //
    // This is not a problem in top-level code because pattern
    // bindings go through `typeCheckExpression` which does
    // pre-check automatically and result builders do not allow
    // declaring local wrapped variables (yet).
    if (hasPropertyWrapper(pattern)) {
      // NOTE(review): `init` can still be null on this path (no initializer
      // and not default-initializable); confirm that forInitialization and
      // preCheckTarget accept that.
      auto target = SolutionApplicationTarget::forInitialization(
          init, patternBinding->getDeclContext(), patternType, patternBinding,
          index,
          /*bindPatternVarsOneWay=*/false);

      if (ConstraintSystem::preCheckTarget(
              target, /*replaceInvalidRefsWithErrors=*/true,
              /*LeaveCLosureBodyUnchecked=*/false))
        return None;

      return target;
    }

    if (init) {
      return SolutionApplicationTarget::forInitialization(
          init, patternBinding->getDeclContext(), patternType, patternBinding,
          index,
          /*bindPatternVarsOneWay=*/false);
    }

    return SolutionApplicationTarget::forUninitializedVar(patternBinding, index,
                                                          patternType);
  }

  /// Handle the pattern binding entry whose index ends the current locator:
  ///   1. type-check its pattern;
  ///   2. record its solution application target;
  ///   3. generate constraints for it, unless it is a placeholder variable.
  void visitPatternBindingElement(PatternBindingDecl *patternBinding) {
    assert(locator->isLastElement());

    auto index =
        locator->castLastElementTo()
            .getIndex();

    if (patternBinding->isInitializerChecked(index))
      return;

    auto contextualPattern =
        ContextualPattern::forPatternBindingDecl(patternBinding, index);
    Type patternType = TypeChecker::typeCheckPattern(contextualPattern);

    // Fail early if pattern couldn't be type-checked.
    if (!patternType || patternType->hasError()) {
      hadError = true;
      return;
    }

    auto target = getTargetForPattern(patternBinding, index, patternType);
    if (!target) {
      hadError = true;
      return;
    }

    // Keep track of this binding entry.
    cs.setSolutionApplicationTarget({patternBinding, index}, *target);

    // For placeholder variables only the target above is recorded; no
    // constraints are generated here.
    if (isPlaceholderVar(patternBinding))
      return;

    if (cs.generateConstraints(*target, FreeTypeVariableBinding::Disallow)) {
      hadError = true;
      return;
    }
  }

  /// Handle a declaration element.
  ///
  /// Outside single-expression closures, a pattern binding must be reached
  /// through a pattern-binding-element locator. Every other declaration is
  /// left for solution application.
  void visitDecl(Decl *decl) {
    if (!context.isSingleExpressionClosure(cs)) {
      if (auto patternBinding = dyn_cast(decl)) {
        if (locator->isLastElement())
          visitPatternBindingElement(patternBinding);
        else
          llvm_unreachable("cannot visit pattern binding directly");
        return;
      }
    }

    // Just ignore #if; the chosen children should appear in the
    // surrounding context.  This isn't good for source tools but it
    // at least works.
    if (isa(decl))
      return;

    // Skip #warning/#error; we'll handle them when applying the closure.
    if (isa(decl))
      return;

    // Ignore variable declarations, because they're always handled within
    // their enclosing pattern bindings.
    if (isa(decl))
      return;

    // Other declarations will be handled at application time.
  }

  // `break`, `continue`, `defer` and `fallthrough` contribute no constraints;
  // they are handled during solution application.
  void visitBreakStmt(BreakStmt *breakStmt) {}

  void visitContinueStmt(ContinueStmt *continueStmt) {}

  void visitDeferStmt(DeferStmt *deferStmt) {}

  void visitFallthroughStmt(FallthroughStmt *fallthroughStmt) {}

  /// Append one element per condition of \p S to \p elements. Each element
  /// is anchored at a `Condition` locator derived from \p locator.
  void visitStmtCondition(LabeledConditionalStmt *S,
                          SmallVectorImpl &elements,
                          ConstraintLocator *locator) {
    auto *condLocator =
        cs.getConstraintLocator(locator, ConstraintLocator::Condition);
    for (auto &condition : S->getCond())
      elements.push_back(makeElement(&condition, condLocator));
  }

  /// `if`: a conjunction of the conditions, the then-branch and, if present,
  /// the else-branch.
  void visitIfStmt(IfStmt *ifStmt) {
    SmallVector elements;

    // Condition
    visitStmtCondition(ifStmt, elements, locator);

    // Then Branch
    {
      auto *thenLoc = cs.getConstraintLocator(
          locator, LocatorPathElt::TernaryBranch(/*then=*/true));
      elements.push_back(makeElement(ifStmt->getThenStmt(), thenLoc));
    }

    // Else Branch (if any).
    if (auto *elseStmt = ifStmt->getElseStmt()) {
      auto *elseLoc = cs.getConstraintLocator(
          locator, LocatorPathElt::TernaryBranch(/*then=*/false));
      elements.push_back(makeElement(ifStmt->getElseStmt(), elseLoc));
    }

    createConjunction(cs, elements, locator);
  }

  /// `guard`: a conjunction of the conditions followed by the body.
  void visitGuardStmt(GuardStmt *guardStmt) {
    SmallVector elements;

    visitStmtCondition(guardStmt, elements, locator);
    elements.push_back(makeElement(guardStmt->getBody(), locator));

    createConjunction(cs, elements, locator);
  }

  /// `while`: a conjunction of the conditions followed by the body.
  void visitWhileStmt(WhileStmt *whileStmt) {
    SmallVector elements;

    visitStmtCondition(whileStmt, elements, locator);
    elements.push_back(makeElement(whileStmt->getBody(), locator));

    createConjunction(cs, elements, locator);
  }

  /// `do` (without `catch`): its body is handled like a plain brace
  /// statement.
  void visitDoStmt(DoStmt *doStmt) {
    visitBraceStmt(doStmt->getBody());
  }

  /// `repeat-while`: a conjunction of the `Bool` condition and the body.
  void visitRepeatWhileStmt(RepeatWhileStmt *repeatWhileStmt) {
    createConjunction(cs,
                      {makeElement(repeatWhileStmt->getCond(),
                                   cs.getConstraintLocator(
                                       locator, ConstraintLocator::Condition),
                                   getContextForCondition()),
                       makeElement(repeatWhileStmt->getBody(), locator)},
                      locator);
  }

  /// `#assert`: a conjunction holding only the `Bool` condition.
  void visitPoundAssertStmt(PoundAssertStmt *poundAssertStmt) {
    createConjunction(cs,
                      {makeElement(poundAssertStmt->getCondition(),
                                   cs.getConstraintLocator(
                                       locator, ConstraintLocator::Condition),
                                   getContextForCondition())},
                      locator);
  }

  /// `throw`: the thrown expression is checked against the error existential
  /// type. Sets \c hadError if the `Error` declaration is unavailable.
  void visitThrowStmt(ThrowStmt *throwStmt) {
    if (!cs.getASTContext().getErrorDecl()) {
      hadError = true;
      return;
    }

    auto errType = cs.getASTContext().getErrorExistentialType();
    auto *errorExpr = throwStmt->getSubExpr();

    createConjunction(
        cs,
        {makeElement(
            errorExpr,
            cs.getConstraintLocator(
                locator, LocatorPathElt::SyntacticElement(errorExpr)),
            {errType, CTP_ThrowStmt})},
        locator);
  }

  /// `for-in`: a conjunction of the pattern (which also covers the sequence
  /// and `where` clause, see \c visitForEachPattern) and the loop body.
  void visitForEachStmt(ForEachStmt *forEachStmt) {
    auto *stmtLoc = cs.getConstraintLocator(locator);

    SmallVector elements;

    // For-each pattern.
    //
    // Note that we don't record a sequence or where clause here,
    // they would be handled together with pattern because pattern can
    // inform a type of sequence element e.g. `for i: Int8 in 0 ..< 8`
    elements.push_back(makeElement(forEachStmt->getPattern(), stmtLoc));

    // Body of the `for-in` loop.
    elements.push_back(makeElement(forEachStmt->getBody(), stmtLoc));

    createConjunction(cs, elements, locator);
  }

  /// `switch`: a conjunction of the subject expression followed by each raw
  /// case.
  ///
  /// The subject's solution application target is recorded for the switch
  /// statement up front, with no contextual type (\c CTP_Unused).
  void visitSwitchStmt(SwitchStmt *switchStmt) {
    auto *switchLoc = cs.getConstraintLocator(
        locator, LocatorPathElt::SyntacticElement(switchStmt));

    SmallVector elements;
    {
      auto *subjectExpr = switchStmt->getSubjectExpr();
      {
        elements.push_back(makeElement(subjectExpr, switchLoc));

        SolutionApplicationTarget target(subjectExpr,
                                         context.getAsDeclContext(),
                                         CTP_Unused, Type(),
                                         /*isDiscarded=*/false);

        cs.setSolutionApplicationTarget(switchStmt, target);
      }

      for (auto rawCase : switchStmt->getRawCases())
        elements.push_back(makeElement(rawCase, switchLoc));
    }

    createConjunction(cs, elements, switchLoc);
  }

  /// `do-catch`: a conjunction of the `do` body followed by each `catch`
  /// clause.
  void visitDoCatchStmt(DoCatchStmt *doStmt) {
    auto *doLoc = cs.getConstraintLocator(
        locator, LocatorPathElt::SyntacticElement(doStmt));

    SmallVector elements;

    // First, let's record a body of `do` statement.
    elements.push_back(makeElement(doStmt->getBody(), doLoc));

    // After that has been type-checked, let's switch to
    // individual `catch` statements.
    for (auto *catchStmt : doStmt->getCatches())
      elements.push_back(makeElement(catchStmt, doLoc));

    createConjunction(cs, elements, doLoc);
  }

  /// `case` / `catch` clause: a conjunction of the label items followed by
  /// the body.
  ///
  /// Label patterns are checked against the switch subject's type, or
  /// against the error existential type for `catch` clauses. Any other
  /// parent statement sets \c hadError.
  void visitCaseStmt(CaseStmt *caseStmt) {
    Type contextualTy;

    {
      auto parent =
          locator->castLastElementTo()
              .getElement();

      if (parent.isStmt(StmtKind::Switch)) {
        auto *switchStmt = cast(parent.get());
        contextualTy = cs.getType(switchStmt->getSubjectExpr());
      } else if (parent.isStmt(StmtKind::DoCatch)) {
        contextualTy = cs.getASTContext().getErrorExistentialType();
      } else {
        hadError = true;
        return;
      }
    }

    auto *caseLoc = cs.getConstraintLocator(
        locator, LocatorPathElt::SyntacticElement(caseStmt));

    SmallVector elements;
    for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
      elements.push_back(
          makeElement(&caseLabelItem, caseLoc, {contextualTy, CTP_CaseStmt}));
    }

    elements.push_back(makeElement(caseStmt->getBody(), caseLoc));

    createConjunction(cs, elements, caseLoc);
  }

  /// Brace statement handling.
  ///
  /// In single-expression closures, constraints are generated eagerly for
  /// every element. Otherwise:
  ///   - if the brace is the body of a `case`, the inferred types of the
  ///     case body variables are recorded first;
  ///   - then a conjunction is created over the elements. Each pattern
  ///     binding is expanded into one element per entry.
  void visitBraceStmt(BraceStmt *braceStmt) {
    if (context.isSingleExpressionClosure(cs)) {
      for (auto node : braceStmt->getElements()) {
        if (auto expr = node.dyn_cast()) {
          auto generatedExpr = cs.generateConstraints(
              expr, context.getAsDeclContext(), /*isInputExpression=*/false);
          if (!generatedExpr) {
            hadError = true;
          }
        } else if (auto stmt = node.dyn_cast()) {
          visit(stmt);
        } else {
          visitDecl(node.get());
        }
      }
      return;
    }

    auto &ctx = cs.getASTContext();

    if (isChildOf(StmtKind::Case)) {
      auto *caseStmt = cast(
          locator->castLastElementTo()
              .asStmt());

      if (recordInferredSwitchCasePatternVars(caseStmt)) {
        hadError = true;
      }
    }

    SmallVector elements;
    for (auto element : braceStmt->getElements()) {
      // Whether the element's value is unused; never the case in
      // playground or debugger-support modes.
      bool isDiscarded =
          element.is() &&
          (!ctx.LangOpts.Playground && !ctx.LangOpts.DebuggerSupport);

      if (auto *decl = element.dyn_cast()) {
        if (auto *PDB = dyn_cast(decl)) {
          visitPatternBinding(PDB, elements);
          continue;
        }
      }

      elements.push_back(
          makeElement(element,
                      cs.getConstraintLocator(
                          locator, LocatorPathElt::SyntacticElement(element)),
                      /*context=*/{}, isDiscarded));
    }

    createConjunction(cs, elements, locator);
  }

  /// `return`: the result (or an implicit `()` for a bare `return`) is
  /// checked against the contextual result type from
  /// \c getContextualResultInfo.
  void visitReturnStmt(ReturnStmt *returnStmt) {
    // Single-expression closures are effectively a `return` statement,
    // so let's give them a special locator to indicate that.
    // Return statements might not have a result if we have a closure whose
    // implicit returned value is coerced to Void.
    if (context.isSingleExpressionClosure(cs) && returnStmt->hasResult()) {
      auto *expr = returnStmt->getResult();
      assert(expr && "single expression closure without expression?");

      expr = cs.generateConstraints(expr, context.getAsDeclContext(),
                                    /*isInputExpression=*/false);
      if (!expr) {
        hadError = true;
        return;
      }

      auto contextualResultInfo = getContextualResultInfo();
      cs.addConstraint(ConstraintKind::Conversion, cs.getType(expr),
                       contextualResultInfo.getType(),
                       cs.getConstraintLocator(
                           context.getAsAbstractClosureExpr().get(),
                           LocatorPathElt::ClosureBody(
                               /*hasReturn=*/!returnStmt->isImplicit())));
      return;
    }

    Expr *resultExpr;

    if (returnStmt->hasResult()) {
      resultExpr = returnStmt->getResult();
      assert(resultExpr && "non-empty result without expression?");
    } else {
      // If this is a plain `return` without a result, let's create an empty
      // tuple, which is also useful if the contextual type turns out to be
      // e.g. `Void?`.
      // Also, attach return stmt source location so if there is a contextual
      // mismatch we can produce a diagnostic in a valid source location.
      resultExpr = getVoidExpr(cs.getASTContext(), returnStmt->getEndLoc());
    }

    auto contextualResultInfo = getContextualResultInfo();
    SolutionApplicationTarget target(resultExpr, context.getAsDeclContext(),
                                     contextualResultInfo.purpose,
                                     contextualResultInfo.getType(),
                                     /*isDiscarded=*/false);

    if (cs.generateConstraints(target, FreeTypeVariableBinding::Disallow)) {
      hadError = true;
      return;
    }

    cs.setContextualType(target.getAsExpr(),
                         TypeLoc::withoutLoc(contextualResultInfo.getType()),
                         contextualResultInfo.purpose);
    cs.setSolutionApplicationTarget(returnStmt, target);
  }

  /// The contextual type and purpose for values returned from the body:
  ///   - no function reference: a null type with \c CTP_Unused;
  ///   - an applied result builder: its body result type, \c CTP_ReturnStmt;
  ///   - a closure: the closure type's result, \c CTP_ClosureResult;
  ///   - otherwise: the function's body result type, \c CTP_ReturnStmt.
  ContextualTypeInfo getContextualResultInfo() const {
    auto funcRef = context.getAsAnyFunctionRef();
    if (!funcRef)
      return {Type(), CTP_Unused};

    if (auto transform = cs.getAppliedResultBuilderTransform(*funcRef))
      return {transform->bodyResultType, CTP_ReturnStmt};

    if (auto *closure =
            getAsExpr(funcRef->getAbstractClosureExpr()))
      return {cs.getClosureType(closure)->getResult(), CTP_ClosureResult};

    return {funcRef->getBodyResultType(), CTP_ReturnStmt};
  }

  // `yield` and `fail` are not expected here; visiting them is unreachable.
#define UNSUPPORTED_STMT(STMT) void visit##STMT##Stmt(STMT##Stmt *) { \
      llvm_unreachable("Unsupported statement kind " #STMT);          \
  }
  UNSUPPORTED_STMT(Yield)
  UNSUPPORTED_STMT(Fail)
#undef UNSUPPORTED_STMT

private:
  /// Contextual type info for statement conditions: `Bool` with
  /// \c CTP_Condition.
  ContextualTypeInfo getContextForCondition() const {
    auto boolDecl = cs.getASTContext().getBoolDecl();
    assert(boolDecl && "Bool is missing");
    return {boolDecl->getDeclaredInterfaceType(), CTP_Condition};
  }

  /// Whether the current locator ends in an element that is a statement of
  /// the given \p kind.
  bool isChildOf(StmtKind kind) {
    if (locator->getPath().empty())
      return false;

    auto parentElt =
        locator->getLastElementAs();
    return parentElt ? parentElt->getElement().isStmt(kind) : false;
  }

  /// Assign types to the body variables of \p caseStmt, taking them from the
  /// same-named variables bound by its label patterns. Those types are
  /// asserted to be free of type variables at this point.
  ///
  /// When several patterns bind the same name, the first pattern's type is
  /// used. If their types differ, a \c RenameConflictingPatternVariables fix
  /// is recorded.
  ///
  /// \returns true on failure: either fixes are not allowed, or recording
  /// the fix failed.
  bool recordInferredSwitchCasePatternVars(CaseStmt *caseStmt) {
    llvm::SmallDenseMap, 4> patternVars;

    auto recordVar = [&](VarDecl *var) {
      if (!var->hasName())
        return;
      patternVars[var->getName()].push_back(var);
    };

    for (auto &caseItem : caseStmt->getMutableCaseLabelItems()) {
      assert(caseItem.isPatternResolved());

      auto *pattern = caseItem.getPattern();
      pattern->forEachVariable([&](VarDecl *var) { recordVar(var); });
    }

    for (auto bodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
      if (!bodyVar->hasName())
        continue;

      const auto &variants = patternVars[bodyVar->getName()];

      auto getType = [&](VarDecl *var) {
        auto type = cs.simplifyType(cs.getType(var));
        assert(!type->hasTypeVariable());
        return type;
      };

      switch (variants.size()) {
      case 0:
        break;

      case 1:
        // If there is only one choice here, let's use it directly.
        cs.setType(bodyVar, getType(variants.front()));
        break;

      default: {
        // If there are multiple choices it could only mean multiple
        // patterns e.g. `.a(let x), .b(let x), ...:`. Let's join them
        // (the first pattern's type is used as the joined type).
        Type joinType = getType(variants.front());

        SmallVector conflicts;
        for (auto *var : llvm::drop_begin(variants)) {
          auto varType = getType(var);
          // Type mismatch between different patterns.
          if (!joinType->isEqual(varType))
            conflicts.push_back(var);
        }

        if (!conflicts.empty()) {
          if (!cs.shouldAttemptFixes())
            return true;

          // Record a fix for the conflicting pattern variables, anchored on
          // the case body variable; failing to record it is an error.
          auto *locator = cs.getConstraintLocator(bodyVar);
          if (cs.recordFix(RenameConflictingPatternVariables::create(
                  cs, joinType, conflicts, locator)))
            return true;
        }

        cs.setType(bodyVar, joinType);
      }
      }
    }

    return false;
  }
};
} // end anonymous namespace

/// Generate constraints for the body of \p closure.
///
/// If the closure participates in inference, its body is visited. For a
/// single-expression closure, the generator's error state is returned right
/// away.
///
/// In every other case, when the closure has no explicit result, its result
/// type is related to `Void`: it is bound to `Void` if the body is empty and
/// no result type is written, and defaulted to `Void` otherwise.
///
/// NOTE(review): for closures without a single-expression body,
/// \c generator.hadError is not propagated. One example is a pattern that
/// fails to resolve in \c visitPatternBinding. Confirm this is intended.
bool ConstraintSystem::generateConstraints(ClosureExpr *closure) {
  auto &ctx = closure->getASTContext();

  if (participatesInInference(closure)) {
    SyntacticElementConstraintGenerator generator(
        *this, SyntacticElementContext::forClosure(closure),
        getConstraintLocator(closure));

    generator.visit(closure->getBody());

    if (closure->hasSingleExpressionBody())
      return generator.hadError;
  }

  // If this closure has an empty body and no explicit result type
  // let's bind result type to `Void` since that's the only type empty body
  // can produce. Otherwise, if (multi-statement) closure doesn't have
  // an explicit result (no `return` statements) let's default it to `Void`.
  if (!hasExplicitResult(closure)) {
    auto constraintKind =
        (closure->hasEmptyBody() && !closure->hasExplicitResultType())
            ? ConstraintKind::Bind
            : ConstraintKind::Defaultable;

    addConstraint(
        constraintKind, getClosureType(closure)->getResult(),
        ctx.TheEmptyTupleType,
        getConstraintLocator(closure, ConstraintLocator::ClosureResult));
  }

  return false;
}

/// Generate constraints for \p body of the function or closure \p fn.
/// The locator is anchored at the function declaration when there is one,
/// and at the closure expression otherwise.
///
/// \returns true if an error was encountered.
bool ConstraintSystem::generateConstraints(AnyFunctionRef fn, BraceStmt *body) {
  NullablePtr locator;

  if (auto *func = fn.getAbstractFunctionDecl()) {
    locator = getConstraintLocator(func);
  } else {
    locator = getConstraintLocator(fn.getAbstractClosureExpr());
  }

  SyntacticElementConstraintGenerator generator(
      *this, SyntacticElementContext::forFunctionRef(fn), locator.get());

  generator.visit(body);

  return generator.hadError;
}

/// Whether \p closure is nested in a body that has had a result builder
/// transform applied.
///
/// The check walks the chain of enclosing declaration contexts and looks up
/// functions and closures in \c resultBuilderTransformed. Single-expression
/// closures always yield false.
bool ConstraintSystem::isInResultBuilderContext(ClosureExpr *closure) const {
  if (!closure->hasSingleExpressionBody()) {
    auto *DC = closure->getParent();
    do {
      // Result builder is applied to a function/getter body.
      if (auto *AFD = dyn_cast(DC)) {
        if (resultBuilderTransformed.count(AFD))
          return true;
      }

      if (auto *parentClosure = dyn_cast(DC)) {
        if (resultBuilderTransformed.count(parentClosure))
          return true;
      }
    } while ((DC = DC->getParent()));
  }
  return false;
}

/// Whether \p locator refers to the condition of a statement. Two things
/// must hold:
///   - its last path element passes the `last->is()` check (presumably a
///     condition element);
///   - the path element before it is a syntactic element that holds a
///     statement.
bool isConditionOfStmt(ConstraintLocatorBuilder locator) {
  auto last = locator.last();
  if (!(last && last->is()))
    return false;

  SmallVector path;
  (void)locator.getLocatorParts(path);

  path.pop_back();

  if (path.empty())
    return false;

  if (auto closureElt = path.back().getAs())
    return closureElt->getElement().dyn_cast();

  return false;
}

/// Generate constraints for a single syntactic element when the solver
/// simplifies its constraint.
///
/// The locator's anchor must be a closure or a function declaration;
/// otherwise this is an error. Elements are handled as follows:
///   - expressions: a \c SolutionApplicationTarget is built and recorded
///     directly;
///   - statement conditions: constraints are generated directly;
///   - statements, patterns, case label items and declarations: dispatched
///     to \c SyntacticElementConstraintGenerator.
///
/// \p flags is not used by this function.
ConstraintSystem::SolutionKind
ConstraintSystem::simplifySyntacticElementConstraint(
    ASTNode element, ContextualTypeInfo contextInfo, bool isDiscarded,
    TypeMatchOptions flags, ConstraintLocatorBuilder locator) {
  auto anchor = locator.getAnchor();

  Optional context;
  if (auto *closure = getAsExpr(anchor)) {
    context = SyntacticElementContext::forClosure(closure);
  } else if (auto *fn = getAsDecl(anchor)) {
    context = SyntacticElementContext::forFunction(fn);
  } else {
    return SolutionKind::Error;
  }

  SyntacticElementConstraintGenerator generator(*this, *context,
                                                getConstraintLocator(locator));

  if (auto *expr = element.dyn_cast()) {
    SolutionApplicationTarget target(expr, context->getAsDeclContext(),
                                     contextInfo.purpose, contextInfo.getType(),
                                     isDiscarded);

    if (generateConstraints(target, FreeTypeVariableBinding::Disallow))
      return SolutionKind::Error;

    setSolutionApplicationTarget(expr, target);
    return SolutionKind::Solved;
  } else if (auto *stmt = element.dyn_cast()) {
    generator.visit(stmt);
  } else if (auto *cond = element.dyn_cast()) {
    if (generateConstraints({*cond}, context->getAsDeclContext()))
      return SolutionKind::Error;
  } else if (auto *pattern = element.dyn_cast()) {
    generator.visitPattern(pattern, contextInfo);
  } else if (auto *caseItem = element.dyn_cast()) {
    generator.visitCaseItem(caseItem, contextInfo);
  } else {
    generator.visit(element.get());
  }

  return generator.hadError ? SolutionKind::Error : SolutionKind::Solved;
}

// MARK: Solution application

namespace {

/// Statement visitor that applies constraints for a given closure body.
class SyntacticElementSolutionApplication
    : public StmtVisitor {
  friend StmtVisitor;
  friend class ResultBuilderRewriter;

protected:
  Solution &solution;
  SyntacticElementContext context;
  Type resultType;
  RewriteTargetFn rewriteTarget;

  /// All `func`s declared in the body of the closure.
  SmallVector LocalFuncs;

public:
  /// Whether an error was encountered while applying the solution.
bool hadError = false; SyntacticElementSolutionApplication(Solution &solution, SyntacticElementContext context, Type resultType, RewriteTargetFn rewriteTarget) : solution(solution), context(context), resultType(resultType), rewriteTarget(rewriteTarget) {} virtual ~SyntacticElementSolutionApplication() {} private: ASTNode visit(Stmt *S) { auto rewritten = ASTVisitor::visit(S); if (!rewritten) return {}; if (auto *stmt = getAsStmt(rewritten)) performStmtDiagnostics(stmt, context.getAsDeclContext()); return rewritten; } void visitDecl(Decl *decl) { if (isa(decl)) return; // Generate constraints for pattern binding declarations. if (auto patternBinding = dyn_cast(decl)) { SolutionApplicationTarget target(patternBinding); // If this is a placeholder varaible with an initializer, let's set // the inferred type, and ask `typeCheckDecl` to type-check initializer. if (isPlaceholderVar(patternBinding) && patternBinding->getInit(0)) { auto *pattern = patternBinding->getPattern(0); pattern->setType( solution.getResolvedType(patternBinding->getSingleVar())); TypeChecker::typeCheckDecl(decl); return; } if (!rewriteTarget(target)) { hadError = true; return; } // Allow `typeCheckDecl` to be called after solution is applied // to a pattern binding. That would materialize required // information e.g. accessors and do access/availability checks. } // Local functions cannot be type-checked in-order because they can // capture variables declared after them. Let's save them to be // processed after the solution has been applied to the body. 
if (auto *func = dyn_cast(decl)) { LocalFuncs.push_back(func); return; } TypeChecker::typeCheckDecl(decl); } ASTNode visitBreakStmt(BreakStmt *breakStmt) { auto *DC = context.getAsDeclContext(); if (auto target = findBreakOrContinueStmtTarget( DC->getASTContext(), DC->getParentSourceFile(), breakStmt->getLoc(), breakStmt->getTargetName(), breakStmt->getTargetLoc(), /*isContinue=*/false, context.getAsDeclContext())) { breakStmt->setTarget(target); } return breakStmt; } ASTNode visitContinueStmt(ContinueStmt *continueStmt) { auto *DC = context.getAsDeclContext(); if (auto target = findBreakOrContinueStmtTarget( DC->getASTContext(), DC->getParentSourceFile(), continueStmt->getLoc(), continueStmt->getTargetName(), continueStmt->getTargetLoc(), /*isContinue=*/true, context.getAsDeclContext())) { continueStmt->setTarget(target); } return continueStmt; } ASTNode visitFallthroughStmt(FallthroughStmt *fallthroughStmt) { if (checkFallthroughStmt(context.getAsDeclContext(), fallthroughStmt)) hadError = true; return fallthroughStmt; } ASTNode visitDeferStmt(DeferStmt *deferStmt) { TypeChecker::typeCheckDecl(deferStmt->getTempDecl()); Expr *theCall = deferStmt->getCallExpr(); TypeChecker::typeCheckExpression(theCall, context.getAsDeclContext()); deferStmt->setCallExpr(theCall); return deferStmt; } ASTNode visitIfStmt(IfStmt *ifStmt) { // Rewrite the condition. 
if (auto condition = rewriteTarget(SolutionApplicationTarget( ifStmt->getCond(), context.getAsDeclContext()))) ifStmt->setCond(*condition->getAsStmtCondition()); else hadError = true; ifStmt->setThenStmt(visit(ifStmt->getThenStmt()).get()); if (auto elseStmt = ifStmt->getElseStmt()) { ifStmt->setElseStmt(visit(elseStmt).get()); } return ifStmt; } ASTNode visitGuardStmt(GuardStmt *guardStmt) { if (auto condition = rewriteTarget(SolutionApplicationTarget( guardStmt->getCond(), context.getAsDeclContext()))) guardStmt->setCond(*condition->getAsStmtCondition()); else hadError = true; auto *body = visit(guardStmt->getBody()).get(); guardStmt->setBody(cast(body)); return guardStmt; } ASTNode visitWhileStmt(WhileStmt *whileStmt) { if (auto condition = rewriteTarget(SolutionApplicationTarget( whileStmt->getCond(), context.getAsDeclContext()))) whileStmt->setCond(*condition->getAsStmtCondition()); else hadError = true; auto *body = visit(whileStmt->getBody()).get(); whileStmt->setBody(cast(body)); return whileStmt; } virtual ASTNode visitDoStmt(DoStmt *doStmt) { auto body = visit(doStmt->getBody()).get(); doStmt->setBody(cast(body)); return doStmt; } ASTNode visitRepeatWhileStmt(RepeatWhileStmt *repeatWhileStmt) { auto body = visit(repeatWhileStmt->getBody()).get(); repeatWhileStmt->setBody(cast(body)); // Rewrite the condition. auto &cs = solution.getConstraintSystem(); auto target = *cs.getSolutionApplicationTarget(repeatWhileStmt->getCond()); if (auto condition = rewriteTarget(target)) repeatWhileStmt->setCond(condition->getAsExpr()); else hadError = true; return repeatWhileStmt; } ASTNode visitPoundAssertStmt(PoundAssertStmt *poundAssertStmt) { // FIXME: This should be done through \c solution instead of // constraint system. auto &cs = solution.getConstraintSystem(); // Rewrite the condition. 
auto target = *cs.getSolutionApplicationTarget(poundAssertStmt->getCondition()); if (auto result = rewriteTarget(target)) poundAssertStmt->setCondition(result->getAsExpr()); else hadError = true; return poundAssertStmt; } ASTNode visitThrowStmt(ThrowStmt *throwStmt) { auto &cs = solution.getConstraintSystem(); // Rewrite the error. auto target = *cs.getSolutionApplicationTarget(throwStmt->getSubExpr()); if (auto result = rewriteTarget(target)) throwStmt->setSubExpr(result->getAsExpr()); else hadError = true; return throwStmt; } ASTNode visitForEachStmt(ForEachStmt *forEachStmt) { ConstraintSystem &cs = solution.getConstraintSystem(); auto forEachTarget = rewriteTarget(*cs.getSolutionApplicationTarget(forEachStmt)); if (!forEachTarget) hadError = true; auto body = visit(forEachStmt->getBody()).get(); forEachStmt->setBody(cast(body)); // Check to see if the sequence expr is throwing (in async context), // if so require the stmt to have a `try`. hadError |= diagnoseUnhandledThrowsInAsyncContext( context.getAsDeclContext(), forEachStmt); return forEachStmt; } ASTNode visitSwitchStmt(SwitchStmt *switchStmt) { ConstraintSystem &cs = solution.getConstraintSystem(); // Rewrite the switch subject. auto subjectTarget = rewriteTarget(*cs.getSolutionApplicationTarget(switchStmt)); if (subjectTarget) { switchStmt->setSubjectExpr(subjectTarget->getAsExpr()); } else { hadError = true; } // Visit the raw cases. bool limitExhaustivityChecks = false; for (auto rawCase : switchStmt->getRawCases()) { if (auto decl = rawCase.dyn_cast()) { visitDecl(decl); continue; } auto caseStmt = cast(rawCase.get()); // Body of the `case` statement can contain a `fallthrough` // statement that requires both source and destination // `case` preambles to be type-checked, so bodies of `case` // statements should be visited after preambles. visitCaseStmtPreamble(caseStmt); } for (auto *caseStmt : switchStmt->getCases()) { visitCaseStmtBody(caseStmt); // Check restrictions on '@unknown'. 
if (caseStmt->hasUnknownAttr()) { checkUnknownAttrRestrictions(cs.getASTContext(), caseStmt, limitExhaustivityChecks); } } TypeChecker::checkSwitchExhaustiveness( switchStmt, context.getAsDeclContext(), limitExhaustivityChecks); return switchStmt; } ASTNode visitDoCatchStmt(DoCatchStmt *doStmt) { // Translate the body. auto newBody = visit(doStmt->getBody()); doStmt->setBody(newBody.get()); // Visit the catch blocks. for (auto catchStmt : doStmt->getCatches()) visitCaseStmt(catchStmt); return doStmt; } void visitCaseStmtPreamble(CaseStmt *caseStmt) { // Translate the patterns and guard expressions for each case label item. for (auto &caseItem : caseStmt->getMutableCaseLabelItems()) { SolutionApplicationTarget caseTarget(&caseItem, context.getAsDeclContext()); if (!rewriteTarget(caseTarget)) { hadError = true; } } bindSwitchCasePatternVars(context.getAsDeclContext(), caseStmt); for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray()) { assert(expected->hasName()); auto prev = expected->getParentVarDecl(); auto type = solution.resolveInterfaceType( solution.getType(prev)->mapTypeOutOfContext()); expected->setInterfaceType(type); } } void visitCaseStmtBody(CaseStmt *caseStmt) { auto *newBody = visit(caseStmt->getBody()).get(); caseStmt->setBody(cast(newBody)); } ASTNode visitCaseStmt(CaseStmt *caseStmt) { visitCaseStmtPreamble(caseStmt); visitCaseStmtBody(caseStmt); return caseStmt; } ASTNode visitBraceElement(ASTNode node) { auto &cs = solution.getConstraintSystem(); if (auto *expr = node.dyn_cast()) { // Rewrite the expression. 
auto target = *cs.getSolutionApplicationTarget(expr); if (auto rewrittenTarget = rewriteTarget(target)) { node = rewrittenTarget->getAsExpr(); if (target.isDiscardedExpr()) TypeChecker::checkIgnoredExpr(castToExpr(node)); } else { hadError = true; } } else if (auto stmt = node.dyn_cast()) { node = visit(stmt); } else { visitDecl(node.get()); } return node; } ASTNode visitBraceStmt(BraceStmt *braceStmt) { auto &cs = solution.getConstraintSystem(); // Diagnose defer statement being last one in block. if (!braceStmt->empty()) { if (auto stmt = braceStmt->getLastElement().dyn_cast()) { if (auto deferStmt = dyn_cast(stmt)) { auto &diags = cs.getASTContext().Diags; diags .diagnose(deferStmt->getStartLoc(), diag::defer_stmt_at_block_end) .fixItReplace(deferStmt->getStartLoc(), "do"); } } } for (auto &node : braceStmt->getElements()) node = visitBraceElement(node); // Source compatibility workaround. // // func test(_: () -> T?) { // ... // } // // A multi-statement closure passed to `test` that has an optional // `Void` result type inferred from the body allows: // - empty `return`(s); // - to skip `return nil` or `return ()` at the end. // // Implicit `return ()` has to be inserted as the last element // of the body if there is none. This wasn't needed before SE-0326 // because result type was (incorrectly) inferred as `Void` due to // the body being skipped. 
auto closure = context.getAsAbstractClosureExpr(); if (closure && !closure.get()->hasSingleExpressionBody() && closure.get()->getBody() == braceStmt) { if (resultType->getOptionalObjectType() && resultType->lookThroughAllOptionalTypes()->isVoid() && !braceStmt->getLastElement().isStmt(StmtKind::Return)) { return addImplicitVoidReturn(braceStmt); } } return braceStmt; } ASTNode addImplicitVoidReturn(BraceStmt *braceStmt) { auto &cs = solution.getConstraintSystem(); auto &ctx = cs.getASTContext(); auto *resultExpr = getVoidExpr(ctx); cs.cacheExprTypes(resultExpr); auto *returnStmt = new (ctx) ReturnStmt(SourceLoc(), resultExpr, /*implicit=*/true); // For a target for newly created result and apply a solution // to it, to make sure that optional injection happens required // number of times. { SolutionApplicationTarget target(resultExpr, context.getAsDeclContext(), CTP_ReturnStmt, resultType, /*isDiscarded=*/false); cs.setSolutionApplicationTarget(returnStmt, target); visitReturnStmt(returnStmt); } // Re-create brace statement with an additional `return` at the end. SmallVector elements; elements.append(braceStmt->getElements().begin(), braceStmt->getElements().end()); elements.push_back(returnStmt); return BraceStmt::create(ctx, braceStmt->getLBraceLoc(), elements, braceStmt->getRBraceLoc()); } ASTNode visitReturnStmt(ReturnStmt *returnStmt) { auto &cs = solution.getConstraintSystem(); if (!returnStmt->hasResult()) { // If contextual is not optional, there is nothing to do here. if (resultType->isVoid()) return returnStmt; // It's possible to infer e.g. `Void?` for cases where // `return` doesn't have an expression. If contextual // type is `Void` wrapped into N optional types, let's // add an implicit `()` expression and let it be injected // into optional required number of times. 
assert(resultType->getOptionalObjectType() && resultType->lookThroughAllOptionalTypes()->isVoid()); auto target = *cs.getSolutionApplicationTarget(returnStmt); returnStmt->setResult(target.getAsExpr()); } auto *resultExpr = returnStmt->getResult(); enum { convertToResult, coerceToVoid, coerceFromNever, } mode; auto resultExprType = solution.simplifyType(solution.getType(resultExpr))->getRValueType(); // A closure with a non-void return expression can coerce to a closure // that returns Void. if (resultType->isVoid() && !resultExprType->isVoid()) { mode = coerceToVoid; // A single-expression closure with a Never expression type // coerces to any other function type. } else if (context.isSingleExpressionClosure(cs) && resultExprType->isUninhabited()) { mode = coerceFromNever; // Normal rule is to coerce to the return expression to the closure type. } else { mode = convertToResult; } Optional resultTarget; if (auto target = cs.getSolutionApplicationTarget(returnStmt)) { resultTarget = *target; } else { // Single-expression closures have to handle returns in a special // way so the target has to be created for them during solution // application based on the resolved type. assert(context.isSingleExpressionClosure(cs)); resultTarget = SolutionApplicationTarget( resultExpr, context.getAsDeclContext(), mode == convertToResult ? CTP_ClosureResult : CTP_Unused, mode == convertToResult ? resultType : Type(), /*isDiscarded=*/false); } if (auto newResultTarget = rewriteTarget(*resultTarget)) { resultExpr = newResultTarget->getAsExpr(); } switch (mode) { case convertToResult: // Record the coerced expression. returnStmt->setResult(resultExpr); return returnStmt; case coerceToVoid: { // Evaluate the expression, then produce a return statement that // returns nothing. 
TypeChecker::checkIgnoredExpr(resultExpr); auto &ctx = solution.getConstraintSystem().getASTContext(); auto newReturnStmt = new (ctx) ReturnStmt( returnStmt->getStartLoc(), nullptr, /*implicit=*/true); ASTNode elements[2] = { resultExpr, newReturnStmt }; return BraceStmt::create(ctx, returnStmt->getStartLoc(), elements, returnStmt->getEndLoc(), /*implicit*/ true); } case coerceFromNever: // Replace the return statement with its expression, so that the // expression is evaluated directly. This only works because coercion // from never is limited to single-expression closures. return resultExpr; } return returnStmt; } #define UNSUPPORTED_STMT(STMT) ASTNode visit##STMT##Stmt(STMT##Stmt *) { \ llvm_unreachable("Unsupported statement kind " #STMT); \ } UNSUPPORTED_STMT(Yield) UNSUPPORTED_STMT(Fail) #undef UNSUPPORTED_STMT public: /// Apply solution to the closure and return updated body. ASTNode apply() { auto body = visit(context.getBody()); // Since local functions can capture variables that are declared // after them, let's type-check them after all of the pattern // bindings have been resolved by applying solution to the body. 
for (auto *func : LocalFuncs) TypeChecker::typeCheckDecl(func); return body; } }; class ResultBuilderRewriter : public SyntacticElementSolutionApplication { const AppliedBuilderTransform &Transform; public: ResultBuilderRewriter(Solution &solution, AnyFunctionRef context, const AppliedBuilderTransform &transform, RewriteTargetFn rewriteTarget) : SyntacticElementSolutionApplication( solution, SyntacticElementContext::forFunctionRef(context), transform.bodyResultType, rewriteTarget), Transform(transform) {} bool apply() { auto body = visit(context.getBody()); if (!body || hadError) return true; auto funcRef = context.getAsAnyFunctionRef(); assert(funcRef); funcRef->setTypecheckedBody(castToStmt(body), /*hasSingleExpression=*/false); if (auto *closure = getAsExpr(funcRef->getAbstractClosureExpr())) solution.setExprTypes(closure); return false; } private: ASTNode visitDoStmt(DoStmt *doStmt) override { if (auto transformed = transformDo(doStmt)) return visit(transformed.get()); auto newBody = visit(doStmt->getBody()); if (!newBody) return nullptr; doStmt->setBody(castToStmt(newBody)); return doStmt; } NullablePtr transformDo(DoStmt *doStmt) { if (!doStmt->isImplicit()) return nullptr; // Implicit `do` wraps a statement and it's `type_join` expression. auto *body = doStmt->getBody(); // If there are more than two elements, this `do` doesn't need to // get be transformed. 
if (body->getNumElements() != 2) return nullptr; auto *stmt = castToStmt(body->getFirstElement()); auto *join = castToExpr(body->getLastElement()); switch (stmt->getKind()) { case StmtKind::If: return transformIf(castToStmt(stmt), join, /*index=*/0); case StmtKind::Switch: return transformSwitch(castToStmt(stmt), join); default: llvm_unreachable("only 'if' and 'switch' statements are transformed"); } } NullablePtr transformSwitch(SwitchStmt *switchStmt, TypeJoinExpr *join) { unsigned caseIndex = 0; for (auto *caseStmt : switchStmt->getCases()) { auto newBody = transformBody(caseStmt->getBody(), join, caseIndex++); if (!newBody) return nullptr; caseStmt->setBody(newBody.get()); } return switchStmt; } NullablePtr transformIf(IfStmt *ifStmt, TypeJoinExpr *join, unsigned index) { // FIXME: Turn this into a condition once warning is an error. (void)diagnoseMissingBuildWithAvailability(ifStmt); auto *joinVar = join->getVar(); // First, let's add assignment to the end of `then` branch { auto *thenBody = castToStmt(ifStmt->getThenStmt()); auto newBody = transformBody(thenBody, join, index); if (!newBody) return nullptr; ifStmt->setThenStmt(newBody.get()); } if (auto *elseStmt = ifStmt->getElseStmt()) { if (auto *innerIfStmt = getAsStmt(elseStmt)) { auto transformedIf = transformIf(innerIfStmt, join, index + 1); if (!transformedIf) return nullptr; ifStmt->setElseStmt(transformedIf.get()); } else { auto newBody = transformBody(castToStmt(elseStmt), join, index + 1); if (!newBody) return nullptr; ifStmt->setElseStmt(newBody.get()); } } else { auto &ctx = getASTContext(); SmallVector elseBranch; elseBranch.push_back( createAssignment(joinVar, join->getElement(index + 1))); ifStmt->setElseStmt(BraceStmt::create(ctx, ifStmt->getEndLoc(), elseBranch, ifStmt->getEndLoc(), /*implicit=*/true)); } return ifStmt; } NullablePtr transformBody(BraceStmt *body, TypeJoinExpr *join, unsigned index) { for (auto &element : body->getElements()) { if (auto *doStmt = getAsStmt(element)) { if 
(auto transformed = transformDo(doStmt)) element = transformed.get(); } } return addBuilderAssignment(body, join->getVar(), join->getElement(index)); } // Add `$__bulderN = build{Optional, Either}(...)` at the end of a block body. BraceStmt *addBuilderAssignment(BraceStmt *body, DeclRefExpr *joinVar, Expr *builderCall) { SmallVector newBody; llvm::copy(body->getElements(), std::back_inserter(newBody)); newBody.push_back(createAssignment(joinVar, builderCall)); return BraceStmt::create(getASTContext(), body->getLBraceLoc(), newBody, body->getRBraceLoc(), body->isImplicit()); } AssignExpr *createAssignment(DeclRefExpr *destRef, Expr *source) { auto &ctx = getASTContext(); auto &CS = solution.getConstraintSystem(); auto *assignment = new (ctx) AssignExpr(destRef, /*EqualLoc=*/SourceLoc(), source, /*Implicit=*/true); { // Assignment expression is always `Void`. CS.setType(assignment, ctx.TheEmptyTupleType); CS.setSolutionApplicationTarget( {assignment}, {assignment, context.getAsDeclContext(), CTP_Unused, /*contextualType=*/Type(), /*isDiscarded=*/false}); } return assignment; } ASTContext &getASTContext() const { return context.getAsDeclContext()->getASTContext(); } private: /// Look for a #available condition. If there is one, we need to check /// that the resulting type of the "then" doesn't refer to any types that /// are unavailable in the enclosing context. /// /// Note that this is for staging in support for buildLimitedAvailability(); /// the diagnostic is currently a warning, so that existing code that /// compiles today will continue to compile. Once result builder types /// have had the chance to adopt buildLimitedAvailability(), we'll upgrade /// this warning to an error. 
LLVM_NODISCARD bool diagnoseMissingBuildWithAvailability(IfStmt *ifStmt) { auto findAvailabilityCondition = [](StmtCondition stmtCond) -> const StmtConditionElement * { for (const auto &cond : stmtCond) { switch (cond.getKind()) { case StmtConditionElement::CK_Boolean: case StmtConditionElement::CK_PatternBinding: case StmtConditionElement::CK_HasSymbol: continue; case StmtConditionElement::CK_Availability: return &cond; break; } } return nullptr; }; auto availabilityCond = findAvailabilityCondition(ifStmt->getCond()); if (!availabilityCond) return false; SourceLoc loc = availabilityCond->getStartLoc(); Type bodyType; if (availabilityCond->getAvailability()->isUnavailability()) { BraceStmt *elseBody = nullptr; // For #unavailable, we need to check the "else". if (auto *innerIf = getAsStmt(ifStmt->getElseStmt())) { elseBody = castToStmt(innerIf->getThenStmt()); } else { elseBody = castToStmt(ifStmt->getElseStmt()); } Type elseBodyType = solution.simplifyType(solution.getType(elseBody->getLastElement())); bodyType = elseBodyType; } else { auto *thenBody = castToStmt(ifStmt->getThenStmt()); Type thenBodyType = solution.simplifyType(solution.getType(thenBody->getLastElement())); bodyType = thenBodyType; } auto builderType = solution.simplifyType(Transform.builderType); return bodyType.findIf([&](Type type) { auto nominal = type->getAnyNominal(); if (!nominal) return false; ExportContext where = ExportContext::forFunctionBody(context.getAsDeclContext(), loc); if (auto reason = TypeChecker::checkDeclarationAvailability(nominal, where)) { auto &ctx = getASTContext(); ctx.Diags.diagnose(loc, diag::result_builder_missing_limited_availability, builderType); // Add a note to the result builder with a stub for // buildLimitedAvailability(). 
if (auto builder = builderType->getAnyNominal()) { SourceLoc buildInsertionLoc; std::string stubIndent; Type componentType; std::tie(buildInsertionLoc, stubIndent, componentType) = determineResultBuilderBuildFixItInfo(builder); if (buildInsertionLoc.isValid()) { std::string fixItString; { llvm::raw_string_ostream out(fixItString); printResultBuilderBuildFunction( builder, componentType, ResultBuilderBuildFunction::BuildLimitedAvailability, stubIndent, out); builder ->diagnose( diag::result_builder_missing_build_limited_availability, builderType) .fixItInsert(buildInsertionLoc, fixItString); } } } return true; } return false; }); } }; } // namespace SolutionApplicationToFunctionResult ConstraintSystem::applySolution( Solution &solution, AnyFunctionRef fn, DeclContext *¤tDC, RewriteTargetFn rewriteTarget) { auto &cs = solution.getConstraintSystem(); auto *closure = getAsExpr(fn.getAbstractClosureExpr()); FunctionType *closureFnType = nullptr; if (closure) { // Update the closure's type. auto closureType = solution.simplifyType(cs.getType(closure)); cs.setType(closure, closureType); // Coerce the parameter types. closureFnType = closureType->castTo(); auto *params = closure->getParameters(); TypeChecker::coerceParameterListToType(params, closureFnType); // Find any isolated parameters in this closure and mark them as isolated. for (auto param : solution.isolatedParams) { if (param->getDeclContext() == closure) param->setIsolated(true); } if (llvm::is_contained(solution.preconcurrencyClosures, closure)) closure->setIsolatedByPreconcurrency(); // Coerce the result type, if it was written explicitly. if (closure->hasExplicitResultType()) { closure->setExplicitResultType(closureFnType->getResult()); } } // Enter the context of the function before performing any additional // transformations. llvm::SaveAndRestore savedDC(currentDC, fn.getAsDeclContext()); // Apply the result builder transform, if there is one. 
if (auto transform = solution.getAppliedBuilderTransform(fn)) { NullablePtr newBody; if (Context.LangOpts.hasFeature(Feature::ResultBuilderASTTransform)) { BraceStmt *transformedBody = const_cast(transform->transformedBody.get()); fn.setParsedBody(transformedBody, /*singleExpression=*/false); ResultBuilderRewriter rewriter(solution, fn, *transform, rewriteTarget); return rewriter.apply() ? SolutionApplicationToFunctionResult::Failure : SolutionApplicationToFunctionResult::Success; } // Apply the result builder to the closure. We want to be in the // context of the closure for subsequent transforms. newBody = applyResultBuilderTransform( solution, *transform, fn.getBody(), fn.getAsDeclContext(), [&](SolutionApplicationTarget target) { auto resultTarget = rewriteTarget(target); if (resultTarget) { if (auto expr = resultTarget->getAsExpr()) solution.setExprTypes(expr); } return resultTarget; }); if (!newBody) return SolutionApplicationToFunctionResult::Failure; fn.setTypecheckedBody(newBody.get(), /*isSingleExpression=*/false); if (closure) { solution.setExprTypes(closure); } return SolutionApplicationToFunctionResult::Success; } assert(closure && "Can only get here with a closure at the moment"); // If this closure is checked as part of the enclosing expression, handle // that now. // // Multi-statement closures are handled separately because they need to // wait until all of the `ExtInfo` flags are propagated from the context // e.g. parameter could be no-escape if closure is applied to a call. if (closure->hasSingleExpressionBody()) { bool hadError = applySolutionToBody(solution, closure, currentDC, rewriteTarget); return hadError ? SolutionApplicationToFunctionResult::Failure : SolutionApplicationToFunctionResult::Success; } // Otherwise, we need to delay type checking of the closure until later. 
solution.setExprTypes(closure); closure->setBodyState(ClosureExpr::BodyState::ReadyForTypeChecking); return SolutionApplicationToFunctionResult::Delay; } bool ConstraintSystem::applySolutionToBody(Solution &solution, AnyFunctionRef fn, DeclContext *¤tDC, RewriteTargetFn rewriteTarget) { // Enter the context of the function before performing any additional // transformations. llvm::SaveAndRestore savedDC(currentDC, fn.getAsDeclContext()); Type resultTy; if (auto transform = solution.getAppliedBuilderTransform(fn)) { resultTy = solution.simplifyType(transform->bodyResultType); } else if (auto *closure = getAsExpr(fn.getAbstractClosureExpr())) { resultTy = solution.getResolvedType(closure)->castTo()->getResult(); } else { resultTy = fn.getBodyResultType(); } SyntacticElementSolutionApplication application( solution, SyntacticElementContext::forFunctionRef(fn), resultTy, rewriteTarget); auto body = application.apply(); if (!body || application.hadError) return true; fn.setTypecheckedBody(castToStmt(body), fn.hasSingleExpressionBody()); return false; } void ConjunctionElement::findReferencedVariables( ConstraintSystem &cs, SmallPtrSetImpl &typeVars) const { auto referencedVars = Element->getTypeVariables(); typeVars.insert(referencedVars.begin(), referencedVars.end()); if (Element->getKind() != ConstraintKind::SyntacticElement) return; ASTNode element = Element->getSyntacticElement(); auto *locator = Element->getLocator(); TypeVariableRefFinder refFinder(cs, locator->getAnchor(), typeVars); if (auto *patternBinding = dyn_cast_or_null(element.dyn_cast())) { // Let's not walk into placeholder variable initializers, since they // are type-checked separately right now. 
if (isPlaceholderVar(patternBinding)) return; if (auto patternBindingElt = locator ->getLastElementAs()) { if (auto *init = patternBinding->getInit(patternBindingElt->getIndex())) init->walk(refFinder); return; } } if (element.is() || element.is() || element.is() || element.isStmt(StmtKind::Return)) element.walk(refFinder); } Type constraints::isPlaceholderVar(PatternBindingDecl *PB) { auto *var = PB->getSingleVar(); if (!var) return Type(); if (!var->getName().hasDollarPrefix()) return Type(); auto *pattern = PB->getPattern(0); if (auto *typedPattern = dyn_cast(pattern)) { auto type = typedPattern->getType(); if (type && type->hasPlaceholder()) return type; } return Type(); }