//===--- Stmt.cpp - Swift Language Statement ASTs -------------------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements the Stmt class and subclasses. // //===----------------------------------------------------------------------===// #include "swift/AST/Stmt.h" #include "swift/AST/ASTContext.h" #include "swift/AST/ASTWalker.h" #include "swift/AST/AvailabilitySpec.h" #include "swift/AST/Decl.h" #include "swift/AST/Expr.h" #include "swift/AST/Pattern.h" #include "swift/AST/TypeCheckRequests.h" #include "swift/Basic/Assertions.h" #include "swift/Basic/Statistic.h" #include "llvm/ADT/PointerUnion.h" using namespace swift; #define STMT(Id, _) \ static_assert(IsTriviallyDestructible::value, \ "Stmts are BumpPtrAllocated; the destructor is never called"); #include "swift/AST/StmtNodes.def" //===----------------------------------------------------------------------===// // Stmt methods. //===----------------------------------------------------------------------===// StringRef Stmt::getKindName(StmtKind K) { switch (K) { #define STMT(Id, Parent) case StmtKind::Id: return #Id; #include "swift/AST/StmtNodes.def" } llvm_unreachable("bad StmtKind"); } StringRef Stmt::getDescriptiveKindName(StmtKind K) { switch (K) { case StmtKind::Brace: return "brace"; case StmtKind::Return: return "return"; case StmtKind::Yield: return "yield"; case StmtKind::Then: return "then"; case StmtKind::Defer: return "defer"; case StmtKind::If: return "if"; case StmtKind::Guard: return "guard"; case StmtKind::While: return "while"; case StmtKind::Do: return "do"; case StmtKind::DoCatch: return "do-catch"; case StmtKind::RepeatWhile: return "repeat-while"; case StmtKind::ForEach: return "for-in"; case StmtKind::Switch: return "switch"; case StmtKind::Case: return "case"; case StmtKind::Break: return "break"; case StmtKind::Continue: return "continue"; case StmtKind::Fallthrough: return "fallthrough"; case StmtKind::Fail: return "return"; case StmtKind::Throw: return "throw"; case StmtKind::Discard: return "discard"; case StmtKind::PoundAssert: return "#assert"; } llvm_unreachable("Unhandled case in switch!"); } // Helper functions to check statically whether a method has been // overridden from its implementation in Stmt. The sort of thing you // need when you're avoiding v-tables. namespace { template constexpr bool isOverriddenFromStmt(ReturnType (Class::*)() const) { return true; } template constexpr bool isOverriddenFromStmt(ReturnType (Stmt::*)() const) { return false; } template struct Dispatch; /// Dispatch down to a concrete override. template <> struct Dispatch { template static SourceLoc getStartLoc(const T *S) { return S->getStartLoc(); } template static SourceLoc getEndLoc(const T *S) { return S->getEndLoc(); } template static SourceRange getSourceRange(const T *S) { return S->getSourceRange(); } }; /// Default implementations for when a method isn't overridden. template <> struct Dispatch { template static SourceLoc getStartLoc(const T *S) { return S->getSourceRange().Start; } template static SourceLoc getEndLoc(const T *S) { return S->getSourceRange().End; } template static SourceRange getSourceRange(const T *S) { return { S->getStartLoc(), S->getEndLoc() }; } }; } // end anonymous namespace template static SourceRange getSourceRangeImpl(const T *S) { static_assert(isOverriddenFromStmt(&T::getSourceRange) || (isOverriddenFromStmt(&T::getStartLoc) && isOverriddenFromStmt(&T::getEndLoc)), "Stmt subclass must implement either getSourceRange() " "or getStartLoc()/getEndLoc()"); return Dispatch::getSourceRange(S); } SourceRange Stmt::getSourceRange() const { switch (getKind()) { #define STMT(ID, PARENT) \ case StmtKind::ID: return getSourceRangeImpl(cast(this)); #include "swift/AST/StmtNodes.def" } llvm_unreachable("statement type not handled!"); } template static SourceLoc getStartLocImpl(const T *S) { return Dispatch::getStartLoc(S); } SourceLoc Stmt::getStartLoc() const { switch (getKind()) { #define STMT(ID, PARENT) \ case StmtKind::ID: return getStartLocImpl(cast(this)); #include "swift/AST/StmtNodes.def" } llvm_unreachable("statement type not handled!"); } template static SourceLoc getEndLocImpl(const T *S) { return Dispatch::getEndLoc(S); } SourceLoc Stmt::getEndLoc() const { switch (getKind()) { #define STMT(ID, PARENT) \ case StmtKind::ID: return getEndLocImpl(cast(this)); #include "swift/AST/StmtNodes.def" } llvm_unreachable("statement type not handled!"); } BraceStmt::BraceStmt(SourceLoc lbloc, ArrayRef elts, SourceLoc rbloc, std::optional implicit) : Stmt(StmtKind::Brace, getDefaultImplicitFlag(implicit, lbloc)), LBLoc(lbloc), RBLoc(rbloc) { Bits.BraceStmt.NumElements = elts.size(); std::uninitialized_copy(elts.begin(), elts.end(), getTrailingObjects()); #ifndef NDEBUG for (auto elt : elts) if (auto *decl = elt.dyn_cast()) assert(!isa(decl) && "accessors should not be added here"); #endif } BraceStmt *BraceStmt::create(ASTContext &ctx, SourceLoc lbloc, ArrayRef elts, SourceLoc rbloc, std::optional implicit) { assert(std::none_of(elts.begin(), elts.end(), [](ASTNode node) -> bool { return node.isNull(); }) && "null element in BraceStmt"); void *Buffer = ctx.Allocate(totalSizeToAlloc(elts.size()), alignof(BraceStmt)); return ::new(Buffer) BraceStmt(lbloc, elts, rbloc, implicit); } SourceLoc BraceStmt::getStartLoc() const { if (LBLoc) { return LBLoc; } return getContentStartLoc(); } SourceLoc BraceStmt::getEndLoc() const { if (RBLoc) { return RBLoc; } return getContentEndLoc(); } SourceLoc BraceStmt::getContentStartLoc() const { for (auto elt : getElements()) { if (auto loc = elt.getStartLoc()) { return loc; } } return SourceLoc(); } SourceLoc BraceStmt::getContentEndLoc() const { for (auto elt : llvm::reverse(getElements())) { if (auto loc = elt.getEndLoc()) { return loc; } } return SourceLoc(); } ASTNode BraceStmt::findAsyncNode() { // TODO: Statements don't track their ASTContext/evaluator, so I am not making // this a request. It probably should be a request at some point. // // While we're at it, it would be very nice if this could be a const // operation, but the AST-walking is not a const operation. // A walker that looks for 'async' and 'await' expressions // that aren't nested within closures or nested declarations. class FindInnerAsync : public ASTWalker { ASTNode AsyncNode; /// Walk only the macro arguments. MacroWalking getMacroWalkingBehavior() const override { return MacroWalking::Arguments; } PreWalkResult walkToExprPre(Expr *expr) override { // If we've found an 'await', record it and terminate the traversal. if (isa(expr)) { AsyncNode = expr; return Action::Stop(); } // Do not recurse into other closures. if (isa(expr)) return Action::SkipNode(expr); return Action::Continue(expr); } PreWalkAction walkToDeclPre(Decl *decl) override { // Do not walk into function or type declarations. if (auto *patternBinding = dyn_cast(decl)) { if (patternBinding->isAsyncLet()) AsyncNode = patternBinding; return Action::Continue(); } return Action::SkipNode(); } PreWalkResult walkToStmtPre(Stmt *stmt) override { if (auto forEach = dyn_cast(stmt)) { if (forEach->getAwaitLoc().isValid()) { AsyncNode = forEach; return Action::Stop(); } } return Action::Continue(stmt); } public: ASTNode getAsyncNode() { return AsyncNode; } }; FindInnerAsync asyncFinder; walk(asyncFinder); return asyncFinder.getAsyncNode(); } static bool hasSingleActiveElement(ArrayRef elts) { return elts.size() == 1; } ASTNode BraceStmt::getSingleActiveElement() const { return hasSingleActiveElement(getElements()) ? getLastElement() : nullptr; } Expr *BraceStmt::getSingleActiveExpression() const { return getSingleActiveElement().dyn_cast(); } Stmt *BraceStmt::getSingleActiveStatement() const { return getSingleActiveElement().dyn_cast(); } IsSingleValueStmtResult Stmt::mayProduceSingleValue(ASTContext &ctx) const { return evaluateOrDefault(ctx.evaluator, IsSingleValueStmtRequest{this, &ctx}, IsSingleValueStmtResult::circularReference()); } SourceLoc ReturnStmt::getStartLoc() const { if (ReturnLoc.isInvalid() && Result) return Result->getStartLoc(); return ReturnLoc; } SourceLoc ReturnStmt::getEndLoc() const { if (Result && Result->getEndLoc().isValid()) return Result->getEndLoc(); return ReturnLoc; } YieldStmt *YieldStmt::create(const ASTContext &ctx, SourceLoc yieldLoc, SourceLoc lpLoc, ArrayRef yields, SourceLoc rpLoc, std::optional implicit) { void *buffer = ctx.Allocate(totalSizeToAlloc(yields.size()), alignof(YieldStmt)); return ::new(buffer) YieldStmt(yieldLoc, lpLoc, yields, rpLoc, implicit); } SourceLoc YieldStmt::getEndLoc() const { return RPLoc.isInvalid() ? getYields()[0]->getEndLoc() : RPLoc; } ThenStmt *ThenStmt::createParsed(ASTContext &ctx, SourceLoc thenLoc, Expr *result) { return new (ctx) ThenStmt(thenLoc, result, /*isImplicit*/ false); } ThenStmt *ThenStmt::createImplicit(ASTContext &ctx, Expr *result) { return new (ctx) ThenStmt(SourceLoc(), result, /*isImplicit*/ true); } SourceRange ThenStmt::getSourceRange() const { return SourceRange::combine(ThenLoc, getResult()->getSourceRange()); } SourceLoc ThrowStmt::getEndLoc() const { return SubExpr->getEndLoc(); } SourceLoc DiscardStmt::getEndLoc() const { return SubExpr->getEndLoc(); } DeferStmt *DeferStmt::create(DeclContext *dc, SourceLoc deferLoc) { ASTContext &ctx = dc->getASTContext(); auto params = ParameterList::createEmpty(ctx); DeclName name(ctx, ctx.getIdentifier("$defer"), params); auto *const funcDecl = FuncDecl::createImplicit( ctx, StaticSpellingKind::None, name, /*NameLoc=*/deferLoc, /*Async=*/false, /*Throws=*/false, /*ThrownType=*/Type(), /*GenericParams=*/nullptr, params, TupleType::getEmpty(ctx), dc); // Form the call, which will be emitted on any path that needs to run the // code. auto DRE = new (ctx) DeclRefExpr(funcDecl, DeclNameLoc(deferLoc), /*Implicit*/ true, AccessSemantics::DirectToStorage); auto call = CallExpr::createImplicitEmpty(ctx, DRE); return new (ctx) DeferStmt(deferLoc, funcDecl, call); } SourceLoc DeferStmt::getEndLoc() const { return tempDecl->getBody()->getEndLoc(); } /// Dig the original user's body of the defer out for AST fidelity. BraceStmt *DeferStmt::getBodyAsWritten() const { return tempDecl->getBody(); } bool LabeledStmt::isPossibleContinueTarget() const { switch (getKind()) { #define LABELED_STMT(ID, PARENT) #define STMT(ID, PARENT) case StmtKind::ID: #include "swift/AST/StmtNodes.def" llvm_unreachable("not a labeled statement"); // Sema has diagnostics with hard-coded expectations about what // statements return false from this method. case StmtKind::If: case StmtKind::Guard: case StmtKind::Switch: return false; case StmtKind::Do: case StmtKind::DoCatch: case StmtKind::RepeatWhile: case StmtKind::ForEach: case StmtKind::While: return true; } llvm_unreachable("statement kind unhandled!"); } bool LabeledStmt::requiresLabelOnJump() const { switch (getKind()) { #define LABELED_STMT(ID, PARENT) #define STMT(ID, PARENT) case StmtKind::ID: #include "swift/AST/StmtNodes.def" llvm_unreachable("not a labeled statement"); case StmtKind::If: case StmtKind::Do: case StmtKind::DoCatch: case StmtKind::Guard: // Guard doesn't allow labels, so no break/continue. return true; case StmtKind::RepeatWhile: case StmtKind::ForEach: case StmtKind::Switch: case StmtKind::While: return false; } llvm_unreachable("statement kind unhandled!"); } void ForEachStmt::setPattern(Pattern *p) { Pat = p; Pat->markOwnedByStatement(this); } Expr *ForEachStmt::getTypeCheckedSequence() const { if (auto *expansion = dyn_cast(getParsedSequence())) return expansion; return iteratorVar ? iteratorVar->getInit(/*index=*/0) : nullptr; } DoCatchStmt *DoCatchStmt::create(DeclContext *dc, LabeledStmtInfo labelInfo, SourceLoc doLoc, SourceLoc throwsLoc, TypeLoc thrownType, Stmt *body, ArrayRef catches, std::optional implicit) { ASTContext &ctx = dc->getASTContext(); void *mem = ctx.Allocate(totalSizeToAlloc(catches.size()), alignof(DoCatchStmt)); return ::new (mem) DoCatchStmt(dc, labelInfo, doLoc, throwsLoc, thrownType, body, catches, implicit); } bool CaseLabelItem::isSyntacticallyExhaustive() const { return getGuardExpr() == nullptr && !getPattern()->isRefutablePattern(); } bool DoCatchStmt::isSyntacticallyExhaustive() const { for (auto clause : getCatches()) { for (auto &LabelItem : clause->getCaseLabelItems()) { if (LabelItem.isSyntacticallyExhaustive()) return true; } } return false; } Type DoCatchStmt::getExplicitCaughtType() const { ASTContext &ctx = DC->getASTContext(); return CatchNode(const_cast(this)).getExplicitCaughtType(ctx); } Type DoCatchStmt::getCaughtErrorType() const { // Check for an explicitly-specified error type. if (Type explicitError = getExplicitCaughtType()) return explicitError; auto firstPattern = getCatches() .front() ->getCaseLabelItems() .front() .getPattern(); if (firstPattern->hasType()) return firstPattern->getType(); return Type(); } void LabeledConditionalStmt::setCond(StmtCondition e) { // When set a condition into a Conditional Statement, inform each of the // variables bound in any patterns that this is the owning statement for the // pattern. for (auto &elt : e) if (auto pat = elt.getPatternOrNull()) pat->markOwnedByStatement(this); Cond = e; } /// Whether or not this conditional stmt rebinds self with a `let self` /// or `let self = self` condition. /// - If `requiresCaptureListRef` is `true`, additionally requires that the /// RHS of the self condition references a var defined in a capture list. /// - If `requireLoadExpr` is `true`, additionally requires that the RHS of /// the self condition is a `LoadExpr`. bool LabeledConditionalStmt::rebindsSelf(ASTContext &Ctx, bool requiresCaptureListRef, bool requireLoadExpr) const { return llvm::any_of(getCond(), [&Ctx, requiresCaptureListRef, requireLoadExpr](const auto &cond) { return cond.rebindsSelf(Ctx, requiresCaptureListRef, requireLoadExpr); }); } /// Whether or not this conditional stmt rebinds self with a `let self` /// or `let self = self` condition. /// - If `requiresCaptureListRef` is `true`, additionally requires that the /// RHS of the self condition references a var defined in a capture list. /// - If `requireLoadExpr` is `true`, additionally requires that the RHS of /// the self condition is a `LoadExpr`. bool StmtConditionElement::rebindsSelf(ASTContext &Ctx, bool requiresCaptureListRef, bool requireLoadExpr) const { auto pattern = getPatternOrNull(); if (!pattern) { return false; } // Check whether or not this pattern defines a new `self` decl bool isSelfRebinding = false; if (pattern->getBoundName() == Ctx.Id_self) { isSelfRebinding = true; } else if (auto OSP = dyn_cast(pattern)) { if (auto subPattern = OSP->getSubPattern()) { isSelfRebinding = subPattern->getBoundName() == Ctx.Id_self; } } if (!isSelfRebinding) { return false; } // Check that the RHS expr is exactly `self` and not something else Expr *exprToCheckForDRE = getInitializerOrNull(); if (!exprToCheckForDRE) { return false; } if (requireLoadExpr && !isa(exprToCheckForDRE)) { return false; } if (auto *load = dyn_cast(exprToCheckForDRE)) { exprToCheckForDRE = load->getSubExpr(); } if (auto *DRE = dyn_cast( exprToCheckForDRE->getSemanticsProvidingExpr())) { auto *decl = DRE->getDecl(); bool definedInCaptureList = false; if (auto varDecl = dyn_cast_or_null(DRE->getDecl())) { definedInCaptureList = varDecl->isCaptureList(); } if (requiresCaptureListRef && !definedInCaptureList) { return false; } bool isVariableNamedSelf = false; if (decl && decl->hasName()) { isVariableNamedSelf = decl->getName().isSimpleName(Ctx.Id_self); } return isVariableNamedSelf; } return false; } SourceRange ConditionalPatternBindingInfo::getSourceRange() const { SourceLoc Start; if (IntroducerLoc.isValid()) Start = IntroducerLoc; else Start = ThePattern->getStartLoc(); SourceLoc End = Initializer->getEndLoc(); if (Start.isValid() && End.isValid()) { return SourceRange(Start, End); } else { return SourceRange(); } } PoundAvailableInfo * PoundAvailableInfo::create(ASTContext &ctx, SourceLoc PoundLoc, SourceLoc LParenLoc, ArrayRef queries, SourceLoc RParenLoc, bool isUnavailability) { unsigned size = totalSizeToAlloc(queries.size()); void *Buffer = ctx.Allocate(size, alignof(PoundAvailableInfo)); return ::new (Buffer) PoundAvailableInfo(PoundLoc, LParenLoc, queries, RParenLoc, isUnavailability); } SemanticAvailabilitySpecs PoundAvailableInfo::getSemanticAvailabilitySpecs( const DeclContext *declContext) const { return SemanticAvailabilitySpecs(getQueries(), declContext); } SourceLoc PoundAvailableInfo::getEndLoc() const { if (RParenLoc.isInvalid()) { if (NumQueries == 0) { if (LParenLoc.isInvalid()) return PoundLoc; return LParenLoc; } return getQueries()[NumQueries - 1]->getSourceRange().End; } return RParenLoc; } SourceRange StmtConditionElement::getSourceRange() const { switch (getKind()) { case StmtConditionElement::CK_Boolean: return getBoolean()->getSourceRange(); case StmtConditionElement::CK_Availability: return getAvailability()->getSourceRange(); case StmtConditionElement::CK_HasSymbol: return getHasSymbolInfo()->getSourceRange(); case StmtConditionElement::CK_PatternBinding: return getPatternBinding()->getSourceRange(); } llvm_unreachable("Unhandled StmtConditionElement in switch."); } PoundHasSymbolInfo *PoundHasSymbolInfo::create(ASTContext &Ctx, SourceLoc PoundLoc, SourceLoc LParenLoc, Expr *SymbolExpr, SourceLoc RParenLoc) { return new (Ctx) PoundHasSymbolInfo(PoundLoc, LParenLoc, SymbolExpr, RParenLoc); } SourceLoc StmtConditionElement::getStartLoc() const { switch (getKind()) { case StmtConditionElement::CK_Boolean: return getBoolean()->getStartLoc(); case StmtConditionElement::CK_Availability: return getAvailability()->getStartLoc(); case StmtConditionElement::CK_PatternBinding: return getPatternBinding()->getStartLoc(); case StmtConditionElement::CK_HasSymbol: return getHasSymbolInfo()->getStartLoc(); } llvm_unreachable("Unhandled StmtConditionElement in switch."); } SourceLoc StmtConditionElement::getEndLoc() const { switch (getKind()) { case StmtConditionElement::CK_Boolean: return getBoolean()->getEndLoc(); case StmtConditionElement::CK_Availability: return getAvailability()->getEndLoc(); case StmtConditionElement::CK_PatternBinding: return getPatternBinding()->getEndLoc(); case StmtConditionElement::CK_HasSymbol: return getHasSymbolInfo()->getEndLoc(); } llvm_unreachable("Unhandled StmtConditionElement in switch."); } static StmtCondition exprToCond(Expr *C, ASTContext &Ctx) { StmtConditionElement Arr[] = { StmtConditionElement(C) }; return Ctx.AllocateCopy(Arr); } IfStmt::IfStmt(SourceLoc IfLoc, Expr *Cond, BraceStmt *Then, SourceLoc ElseLoc, Stmt *Else, std::optional implicit, ASTContext &Ctx) : IfStmt(LabeledStmtInfo(), IfLoc, exprToCond(Cond, Ctx), Then, ElseLoc, Else, implicit) {} ArrayRef IfStmt::getBranches(SmallVectorImpl &scratch) const { assert(scratch.empty()); scratch.push_back(getThenStmt()); auto *elseBranch = getElseStmt(); while (elseBranch) { if (auto *IS = dyn_cast(elseBranch)) { // Look through else ifs. elseBranch = IS->getElseStmt(); scratch.push_back(IS->getThenStmt()); continue; } // An unconditional else, we're done. scratch.push_back(elseBranch); break; } return scratch; } bool IfStmt::isSyntacticallyExhaustive() const { auto *elseBranch = getElseStmt(); while (elseBranch) { // Look through else ifs. if (auto *IS = dyn_cast(elseBranch)) { elseBranch = IS->getElseStmt(); continue; } // An unconditional else. return true; } return false; } GuardStmt::GuardStmt(SourceLoc GuardLoc, Expr *Cond, BraceStmt *Body, std::optional implicit, ASTContext &Ctx) : GuardStmt(GuardLoc, exprToCond(Cond, Ctx), Body, implicit) {} SourceLoc RepeatWhileStmt::getEndLoc() const { return Cond->getEndLoc(); } SourceRange CaseLabelItem::getSourceRange() const { if (auto *E = getGuardExpr()) return { getPattern()->getStartLoc(), E->getEndLoc() }; return getPattern()->getSourceRange(); } SourceLoc CaseLabelItem::getStartLoc() const { return getPattern()->getStartLoc(); } SourceLoc CaseLabelItem::getEndLoc() const { if (auto *E = getGuardExpr()) return E->getEndLoc(); return getPattern()->getEndLoc(); } CaseStmt::CaseStmt(CaseParentKind parentKind, SourceLoc itemIntroducerLoc, ArrayRef caseLabelItems, SourceLoc unknownAttrLoc, SourceLoc itemTerminatorLoc, BraceStmt *body, std::optional> caseBodyVariables, std::optional implicit, NullablePtr fallthroughStmt) : Stmt(StmtKind::Case, getDefaultImplicitFlag(implicit, itemIntroducerLoc)), UnknownAttrLoc(unknownAttrLoc), ItemIntroducerLoc(itemIntroducerLoc), ItemTerminatorLoc(itemTerminatorLoc), ParentKind(parentKind), BodyAndHasFallthrough(body, fallthroughStmt.isNonNull()), CaseBodyVariables(caseBodyVariables) { Bits.CaseStmt.NumPatterns = caseLabelItems.size(); assert(Bits.CaseStmt.NumPatterns > 0 && "case block must have at least one pattern"); assert( !(parentKind == CaseParentKind::DoCatch && fallthroughStmt.isNonNull()) && "Only switch cases can have a fallthrough."); if (hasFallthroughDest()) { *getTrailingObjects() = fallthroughStmt.get(); } MutableArrayRef items{getTrailingObjects(), static_cast(Bits.CaseStmt.NumPatterns)}; // At the beginning mark all of our var decls as being owned by this // statement. In the typechecker we wireup the case stmt var decl list since // we know everything is lined up/typechecked then. for (unsigned i : range(Bits.CaseStmt.NumPatterns)) { new (&items[i]) CaseLabelItem(caseLabelItems[i]); items[i].getPattern()->markOwnedByStatement(this); } for (auto *vd : caseBodyVariables.value_or(MutableArrayRef())) { vd->setParentPatternStmt(this); } } namespace { static MutableArrayRef getCaseVarDecls(ASTContext &ctx, ArrayRef labelItems) { // Grab the first case label item pattern and use it to initialize the case // body var decls. SmallVector tmp; labelItems.front().getPattern()->collectVariables(tmp); return ctx.AllocateTransform( llvm::ArrayRef(tmp), [&](VarDecl *vOld) -> VarDecl * { auto *vNew = new (ctx) VarDecl( /*IsStatic*/ false, vOld->getIntroducer(), vOld->getNameLoc(), vOld->getName(), vOld->getDeclContext()); vNew->setImplicit(); return vNew; }); } struct FallthroughFinder : ASTWalker { FallthroughStmt *result; FallthroughFinder() : result(nullptr) {} MacroWalking getMacroWalkingBehavior() const override { return MacroWalking::Arguments; } // We walk through statements. If we find a fallthrough, then we got what // we came for. PreWalkResult walkToStmtPre(Stmt *s) override { if (auto *f = dyn_cast(s)) { result = f; } return Action::Continue(s); } // Expressions, patterns and decls cannot contain fallthrough statements, so // there is no reason to walk into them. PreWalkResult walkToExprPre(Expr *e) override { return Action::SkipNode(e); } PreWalkResult walkToPatternPre(Pattern *p) override { return Action::SkipNode(p); } PreWalkAction walkToDeclPre(Decl *d) override { return Action::SkipNode(); } PreWalkAction walkToTypeReprPre(TypeRepr *t) override { return Action::SkipNode(); } static FallthroughStmt *findFallthrough(Stmt *s) { FallthroughFinder finder; s->walk(finder); return finder.result; } }; } // namespace CaseStmt * CaseStmt::createParsedSwitchCase(ASTContext &ctx, SourceLoc introducerLoc, ArrayRef caseLabelItems, SourceLoc unknownAttrLoc, SourceLoc colonLoc, BraceStmt *body) { auto caseVarDecls = getCaseVarDecls(ctx, caseLabelItems); auto fallthroughStmt = FallthroughFinder().findFallthrough(body); return create(ctx, CaseParentKind::Switch, introducerLoc, caseLabelItems, unknownAttrLoc, colonLoc, body, caseVarDecls, /*implicit=*/false, fallthroughStmt); } CaseStmt *CaseStmt::createParsedDoCatch(ASTContext &ctx, SourceLoc catchLoc, ArrayRef caseLabelItems, BraceStmt *body) { auto caseVarDecls = getCaseVarDecls(ctx, caseLabelItems); return create(ctx, CaseParentKind::DoCatch, catchLoc, caseLabelItems, /*unknownAttrLoc=*/SourceLoc(), body->getStartLoc(), body, caseVarDecls, /*implicit=*/false, /*fallthroughStmt=*/nullptr); } CaseStmt * CaseStmt::create(ASTContext &ctx, CaseParentKind ParentKind, SourceLoc caseLoc, ArrayRef caseLabelItems, SourceLoc unknownAttrLoc, SourceLoc colonLoc, BraceStmt *body, std::optional> caseVarDecls, std::optional implicit, NullablePtr fallthroughStmt) { void *mem = ctx.Allocate(totalSizeToAlloc( fallthroughStmt.isNonNull(), caseLabelItems.size()), alignof(CaseStmt)); return ::new (mem) CaseStmt(ParentKind, caseLoc, caseLabelItems, unknownAttrLoc, colonLoc, body, caseVarDecls, implicit, fallthroughStmt); } DoStmt *DoStmt::createImplicit(ASTContext &C, LabeledStmtInfo labelInfo, ArrayRef body) { return new (C) DoStmt(labelInfo, /*doLoc=*/SourceLoc(), BraceStmt::createImplicit(C, body), /*implicit=*/true); } SourceLoc DoStmt::getStartLoc() const { if (auto LabelOrDoLoc = getLabelLocOrKeywordLoc(DoLoc)) { return LabelOrDoLoc; } return Body->getStartLoc(); } SourceLoc DoStmt::getEndLoc() const { return Body->getEndLoc(); } namespace { template CaseStmt *findNextCaseStmt( CaseIterator first, CaseIterator last, const CaseStmt *caseStmt) { for(auto caseIter = first; caseIter != last; ++caseIter) { if (*caseIter == caseStmt) { ++caseIter; return caseIter == last ? nullptr : *caseIter; } } return nullptr; } } CaseStmt *CaseStmt::findNextCaseStmt() const { auto parent = getParentStmt(); if (!parent) return nullptr; if (auto switchParent = dyn_cast(parent)) { return ::findNextCaseStmt( switchParent->getCases().begin(), switchParent->getCases().end(), this); } auto doCatchParent = cast(parent); return ::findNextCaseStmt( doCatchParent->getCatches().begin(), doCatchParent->getCatches().end(), this); } SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc, Expr *SubjectExpr, SourceLoc LBraceLoc, ArrayRef Cases, SourceLoc RBraceLoc, SourceLoc EndLoc, ASTContext &C) { void *p = C.Allocate(totalSizeToAlloc(Cases.size()), alignof(SwitchStmt)); SwitchStmt *theSwitch = ::new (p) SwitchStmt(LabelInfo, SwitchLoc, SubjectExpr, LBraceLoc, Cases.size(), RBraceLoc, EndLoc); std::uninitialized_copy(Cases.begin(), Cases.end(), theSwitch->getTrailingObjects()); for (auto *caseStmt : theSwitch->getCases()) caseStmt->setParentStmt(theSwitch); return theSwitch; } LabeledStmt *BreakStmt::getTarget() const { auto &eval = getDeclContext()->getASTContext().evaluator; return evaluateOrDefault(eval, BreakTargetRequest{this}, nullptr); } LabeledStmt *ContinueStmt::getTarget() const { auto &eval = getDeclContext()->getASTContext().evaluator; return evaluateOrDefault(eval, ContinueTargetRequest{this}, nullptr); } FallthroughStmt *FallthroughStmt::createParsed(SourceLoc Loc, DeclContext *DC) { auto &ctx = DC->getASTContext(); return new (ctx) FallthroughStmt(Loc, DC); } CaseStmt *FallthroughStmt::getFallthroughSource() const { auto &eval = getDeclContext()->getASTContext().evaluator; return evaluateOrDefault(eval, FallthroughSourceAndDestRequest{this}, {}) .Source; } CaseStmt *FallthroughStmt::getFallthroughDest() const { auto &eval = getDeclContext()->getASTContext().evaluator; return evaluateOrDefault(eval, FallthroughSourceAndDestRequest{this}, {}) .Dest; } SourceLoc swift::extractNearestSourceLoc(const Stmt *S) { return S->getStartLoc(); } ArrayRef SwitchStmt::getBranches(SmallVectorImpl &scratch) const { assert(scratch.empty()); for (auto *CS : getCases()) scratch.push_back(CS->getBody()); return scratch; } ArrayRef DoCatchStmt::getBranches(SmallVectorImpl &scratch) const { assert(scratch.empty()); scratch.push_back(getBody()); for (auto *CS : getCatches()) scratch.push_back(CS->getBody()); return scratch; } // See swift/Basic/Statistic.h for declaration: this enables tracing Stmts, is // defined here to avoid too much layering violation / circular linkage // dependency. struct StmtTraceFormatter : public UnifiedStatsReporter::TraceFormatter { void traceName(const void *Entity, raw_ostream &OS) const override { if (!Entity) return; const Stmt *S = static_cast(Entity); OS << Stmt::getKindName(S->getKind()); } void traceLoc(const void *Entity, SourceManager *SM, clang::SourceManager *CSM, raw_ostream &OS) const override { if (!Entity) return; const Stmt *S = static_cast(Entity); S->getSourceRange().print(OS, *SM, false); } }; static StmtTraceFormatter TF; template<> const UnifiedStatsReporter::TraceFormatter* FrontendStatsTracer::getTraceFormatter() { return &TF; }