//===--- Refactoring.cpp ---------------------------------------------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// #include "swift/Refactoring/Refactoring.h" #include "swift/AST/ASTContext.h" #include "swift/AST/ASTPrinter.h" #include "swift/AST/Decl.h" #include "swift/AST/DiagnosticsRefactoring.h" #include "swift/AST/Expr.h" #include "swift/AST/ForeignAsyncConvention.h" #include "swift/AST/GenericParamList.h" #include "swift/AST/NameLookup.h" #include "swift/AST/Pattern.h" #include "swift/AST/ProtocolConformance.h" #include "swift/AST/Stmt.h" #include "swift/AST/Types.h" #include "swift/AST/USRGeneration.h" #include "swift/Basic/Edit.h" #include "swift/Basic/StringExtras.h" #include "swift/ClangImporter/ClangImporter.h" #include "swift/Frontend/Frontend.h" #include "swift/IDE/IDERequests.h" #include "swift/Index/Index.h" #include "swift/Parse/Lexer.h" #include "swift/Sema/IDETypeChecking.h" #include "swift/Subsystems.h" #include "clang/Rewrite/Core/RewriteBuffer.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringSet.h" using namespace swift; using namespace swift::ide; using namespace swift::index; namespace { class ContextFinder : public SourceEntityWalker { SourceFile &SF; ASTContext &Ctx; SourceManager &SM; SourceRange Target; std::function IsContext; SmallVector AllContexts; bool contains(ASTNode Enclosing) { auto Result = SM.rangeContains(Enclosing.getSourceRange(), Target); if (Result && IsContext(Enclosing)) AllContexts.push_back(Enclosing); return Result; } public: ContextFinder(SourceFile &SF, ASTNode TargetNode, std::function IsContext = [](ASTNode N) { return true; }) : SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr), Target(TargetNode.getSourceRange()), IsContext(IsContext) {} ContextFinder(SourceFile &SF, SourceLoc TargetLoc, std::function IsContext = [](ASTNode N) { return true; }) : SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr), Target(TargetLoc), IsContext(IsContext) { assert(TargetLoc.isValid() && "Invalid loc to find"); } bool walkToDeclPre(Decl *D, CharSourceRange Range) override { return contains(D); } bool walkToStmtPre(Stmt *S) override { return contains(S); } bool walkToExprPre(Expr *E) override { return contains(E); } void resolve() { walk(SF); } ArrayRef getContexts() const { return llvm::makeArrayRef(AllContexts); } }; class Renamer { protected: const SourceManager &SM; protected: Renamer(const SourceManager &SM, StringRef OldName) : SM(SM), Old(OldName) {} // Implementor's interface. virtual void doRenameLabel(CharSourceRange Label, RefactoringRangeKind RangeKind, unsigned NameIndex) = 0; virtual void doRenameBase(CharSourceRange Range, RefactoringRangeKind RangeKind) = 0; public: const DeclNameViewer Old; public: virtual ~Renamer() {} /// Adds a replacement to rename the given base name range /// \return true if the given range does not match the old name bool renameBase(CharSourceRange Range, RefactoringRangeKind RangeKind) { assert(Range.isValid()); if (stripBackticks(Range).str() != Old.base()) return true; doRenameBase(Range, RangeKind); return false; } /// Adds replacements to rename the given label ranges /// \return true if the label ranges do not match the old name bool renameLabels(ArrayRef LabelRanges, Optional FirstTrailingLabel, LabelRangeType RangeType, bool isCallSite) { if (isCallSite) return renameLabelsLenient(LabelRanges, FirstTrailingLabel, RangeType); assert(!FirstTrailingLabel); ArrayRef OldLabels = Old.args(); if (OldLabels.size() != LabelRanges.size()) return true; size_t Index = 0; for (const auto &LabelRange : LabelRanges) { assert(LabelRange.isValid()); if (!labelRangeMatches(LabelRange, RangeType, OldLabels[Index])) return true; splitAndRenameLabel(LabelRange, RangeType, Index++); } return false; } bool isOperator() const { return Lexer::isOperator(Old.base()); } private: /// Returns the range of the (possibly escaped) identifier at the start of /// \p Range and updates \p IsEscaped to indicate whether it's escaped or not. CharSourceRange getLeadingIdentifierRange(CharSourceRange Range, bool &IsEscaped) { assert(Range.isValid() && Range.getByteLength()); IsEscaped = Range.str().front() == '`'; SourceLoc Start = Range.getStart(); if (IsEscaped) Start = Start.getAdvancedLoc(1); return Lexer::getCharSourceRangeFromSourceRange(SM, Start); } CharSourceRange stripBackticks(CharSourceRange Range) { StringRef Content = Range.str(); if (Content.size() < 3 || Content.front() != '`' || Content.back() != '`') { return Range; } return CharSourceRange(Range.getStart().getAdvancedLoc(1), Range.getByteLength() - 2); } void splitAndRenameLabel(CharSourceRange Range, LabelRangeType RangeType, size_t NameIndex) { switch (RangeType) { case LabelRangeType::CallArg: return splitAndRenameCallArg(Range, NameIndex); case LabelRangeType::Param: return splitAndRenameParamLabel(Range, NameIndex, /*IsCollapsible=*/true); case LabelRangeType::NoncollapsibleParam: return splitAndRenameParamLabel(Range, NameIndex, /*IsCollapsible=*/false); case LabelRangeType::Selector: return doRenameLabel( Range, RefactoringRangeKind::SelectorArgumentLabel, NameIndex); case LabelRangeType::None: llvm_unreachable("expected a label range"); } } void splitAndRenameParamLabel(CharSourceRange Range, size_t NameIndex, bool IsCollapsible) { // Split parameter range foo([a b]: Int) into decl argument label [a] and // parameter name [b] or noncollapsible parameter name [b] if IsCollapsible // is false (as for subscript decls). If we have only foo([a]: Int), then we // add an empty range for the local name, or for the decl argument label if // IsCollapsible is false. StringRef Content = Range.str(); size_t ExternalNameEnd = Content.find_first_of(" \t\n\v\f\r/"); if (ExternalNameEnd == StringRef::npos) { // foo([a]: Int) if (IsCollapsible) { doRenameLabel(Range, RefactoringRangeKind::DeclArgumentLabel, NameIndex); doRenameLabel(CharSourceRange{Range.getEnd(), 0}, RefactoringRangeKind::ParameterName, NameIndex); } else { doRenameLabel(CharSourceRange{Range.getStart(), 0}, RefactoringRangeKind::DeclArgumentLabel, NameIndex); doRenameLabel(Range, RefactoringRangeKind::NoncollapsibleParameterName, NameIndex); } } else { // foo([a b]: Int) CharSourceRange Ext{Range.getStart(), unsigned(ExternalNameEnd)}; // Note: we consider the leading whitespace part of the parameter name // if the parameter is collapsible, since if the parameter is collapsed // into a matching argument label, we want to remove the whitespace too. // FIXME: handle comments foo(a /*...*/b: Int). size_t LocalNameStart = Content.find_last_of(" \t\n\v\f\r/"); assert(LocalNameStart != StringRef::npos); if (!IsCollapsible) ++LocalNameStart; auto LocalLoc = Range.getStart().getAdvancedLocOrInvalid(LocalNameStart); CharSourceRange Local{LocalLoc, unsigned(Content.size() - LocalNameStart)}; doRenameLabel(Ext, RefactoringRangeKind::DeclArgumentLabel, NameIndex); if (IsCollapsible) { doRenameLabel(Local, RefactoringRangeKind::ParameterName, NameIndex); } else { doRenameLabel(Local, RefactoringRangeKind::NoncollapsibleParameterName, NameIndex); } } } void splitAndRenameCallArg(CharSourceRange Range, size_t NameIndex) { // Split call argument foo([a: ]1) into argument name [a] and the remainder // [: ]. StringRef Content = Range.str(); size_t Colon = Content.find(':'); // FIXME: leading whitespace? if (Colon == StringRef::npos) { assert(Content.empty()); doRenameLabel(Range, RefactoringRangeKind::CallArgumentCombined, NameIndex); return; } // Include any whitespace before the ':'. assert(Colon == Content.substr(0, Colon).size()); Colon = Content.substr(0, Colon).rtrim().size(); CharSourceRange Arg{Range.getStart(), unsigned(Colon)}; doRenameLabel(Arg, RefactoringRangeKind::CallArgumentLabel, NameIndex); auto ColonLoc = Range.getStart().getAdvancedLocOrInvalid(Colon); assert(ColonLoc.isValid()); CharSourceRange Rest{ColonLoc, unsigned(Content.size() - Colon)}; doRenameLabel(Rest, RefactoringRangeKind::CallArgumentColon, NameIndex); } bool labelRangeMatches(CharSourceRange Range, LabelRangeType RangeType, StringRef Expected) { if (Range.getByteLength()) { bool IsEscaped = false; CharSourceRange ExistingLabelRange = getLeadingIdentifierRange(Range, IsEscaped); StringRef ExistingLabel = ExistingLabelRange.str(); bool IsSingleName = Range == ExistingLabelRange || (IsEscaped && Range.getByteLength() == ExistingLabel.size() + 2); switch (RangeType) { case LabelRangeType::NoncollapsibleParam: if (IsSingleName && Expected.empty()) // subscript([x]: Int) return true; LLVM_FALLTHROUGH; case LabelRangeType::CallArg: case LabelRangeType::Param: case LabelRangeType::Selector: return ExistingLabel == (Expected.empty() ? "_" : Expected); case LabelRangeType::None: llvm_unreachable("Unhandled label range type"); } } return Expected.empty(); } bool renameLabelsLenient(ArrayRef LabelRanges, Optional FirstTrailingLabel, LabelRangeType RangeType) { ArrayRef OldNames = Old.args(); // First, match trailing closure arguments in reverse if (FirstTrailingLabel) { auto TrailingLabels = LabelRanges.drop_front(*FirstTrailingLabel); LabelRanges = LabelRanges.take_front(*FirstTrailingLabel); for (auto LabelIndex: llvm::reverse(indices(TrailingLabels))) { CharSourceRange Label = TrailingLabels[LabelIndex]; if (Label.getByteLength()) { if (OldNames.empty()) return true; while (!labelRangeMatches(Label, LabelRangeType::Selector, OldNames.back())) { if ((OldNames = OldNames.drop_back()).empty()) return true; } splitAndRenameLabel(Label, LabelRangeType::Selector, OldNames.size() - 1); OldNames = OldNames.drop_back(); continue; } // empty labelled trailing closure label if (LabelIndex) { if (OldNames.empty()) return true; while (!OldNames.back().empty()) { if ((OldNames = OldNames.drop_back()).empty()) return true; } splitAndRenameLabel(Label, LabelRangeType::Selector, OldNames.size() - 1); OldNames = OldNames.drop_back(); continue; } // unlabelled trailing closure label OldNames = OldNames.drop_back(); continue; } } // Next, match the non-trailing arguments. size_t NameIndex = 0; for (CharSourceRange Label : LabelRanges) { // empty label if (!Label.getByteLength()) { // first name pos if (!NameIndex) { while (!OldNames[NameIndex].empty()) { if (++NameIndex >= OldNames.size()) return true; } splitAndRenameLabel(Label, RangeType, NameIndex++); continue; } // other name pos if (NameIndex >= OldNames.size() || !OldNames[NameIndex].empty()) { // FIXME: only allow one variadic param continue; // allow for variadic } splitAndRenameLabel(Label, RangeType, NameIndex++); continue; } // non-empty label if (NameIndex >= OldNames.size()) return true; while (!labelRangeMatches(Label, RangeType, OldNames[NameIndex])) { if (++NameIndex >= OldNames.size()) return true; }; splitAndRenameLabel(Label, RangeType, NameIndex++); } return false; } static RegionType getSyntacticRenameRegionType(const ResolvedLoc &Resolved) { if (Resolved.Node.isNull()) return RegionType::Comment; if (Expr *E = Resolved.Node.getAsExpr()) { if (isa(E)) return RegionType::String; } if (Resolved.IsInSelector) return RegionType::Selector; if (Resolved.IsActive) return RegionType::ActiveCode; return RegionType::InactiveCode; } public: RegionType addSyntacticRenameRanges(const ResolvedLoc &Resolved, const RenameLoc &Config) { if (!Resolved.Range.isValid()) return RegionType::Unmatched; auto RegionKind = getSyntacticRenameRegionType(Resolved); // Don't include unknown references coming from active code; if we don't // have a semantic NameUsage for them, then they're likely unrelated symbols // that happen to have the same name. if (RegionKind == RegionType::ActiveCode && Config.Usage == NameUsage::Unknown) return RegionType::Unmatched; assert(Config.Usage != NameUsage::Call || Config.IsFunctionLike); // FIXME: handle escaped keyword names `init` bool IsSubscript = Old.base() == "subscript" && Config.IsFunctionLike; bool IsInit = Old.base() == "init" && Config.IsFunctionLike; // FIXME: this should only be treated specially for instance methods. bool IsCallAsFunction = Old.base() == "callAsFunction" && Config.IsFunctionLike; bool IsSpecialBase = IsInit || IsSubscript || IsCallAsFunction; // Filter out non-semantic special basename locations with no labels. // We've already filtered out those in active code, so these are // any appearance of just 'init', 'subscript', or 'callAsFunction' in // strings, comments, and inactive code. if (IsSpecialBase && (Config.Usage == NameUsage::Unknown && Resolved.LabelType == LabelRangeType::None)) return RegionType::Unmatched; if (!Config.IsFunctionLike || !IsSpecialBase) { if (renameBase(Resolved.Range, RefactoringRangeKind::BaseName)) return RegionType::Mismatch; } else if (IsInit || IsCallAsFunction) { if (renameBase(Resolved.Range, RefactoringRangeKind::KeywordBaseName)) { // The base name doesn't need to match (but may) for calls, but // it should for definitions and references. if (Config.Usage == NameUsage::Definition || Config.Usage == NameUsage::Reference) { return RegionType::Mismatch; } } } else if (IsSubscript && Config.Usage == NameUsage::Definition) { if (renameBase(Resolved.Range, RefactoringRangeKind::KeywordBaseName)) return RegionType::Mismatch; } bool HandleLabels = false; if (Config.IsFunctionLike) { switch (Config.Usage) { case NameUsage::Call: HandleLabels = !isOperator(); break; case NameUsage::Definition: HandleLabels = true; break; case NameUsage::Reference: HandleLabels = Resolved.LabelType == LabelRangeType::Selector || IsSubscript; break; case NameUsage::Unknown: HandleLabels = Resolved.LabelType != LabelRangeType::None; break; } } if (HandleLabels) { bool isCallSite = Config.Usage != NameUsage::Definition && (Config.Usage != NameUsage::Reference || IsSubscript) && Resolved.LabelType == LabelRangeType::CallArg; if (renameLabels(Resolved.LabelRanges, Resolved.FirstTrailingLabel, Resolved.LabelType, isCallSite)) return Config.Usage == NameUsage::Unknown ? RegionType::Unmatched : RegionType::Mismatch; } return RegionKind; } }; class RenameRangeDetailCollector : public Renamer { void doRenameLabel(CharSourceRange Label, RefactoringRangeKind RangeKind, unsigned NameIndex) override { Ranges.push_back({Label, RangeKind, NameIndex}); } void doRenameBase(CharSourceRange Range, RefactoringRangeKind RangeKind) override { Ranges.push_back({Range, RangeKind, None}); } public: RenameRangeDetailCollector(const SourceManager &SM, StringRef OldName) : Renamer(SM, OldName) {} std::vector Ranges; }; class TextReplacementsRenamer : public Renamer { llvm::StringSet<> &ReplaceTextContext; std::vector Replacements; public: const DeclNameViewer New; private: StringRef registerText(StringRef Text) { if (Text.empty()) return Text; return ReplaceTextContext.insert(Text).first->getKey(); } StringRef getCallArgLabelReplacement(StringRef OldLabelRange, StringRef NewLabel) { return NewLabel.empty() ? "" : NewLabel; } StringRef getCallArgColonReplacement(StringRef OldLabelRange, StringRef NewLabel) { // Expected OldLabelRange: foo( []3, a[: ]2, b[ : ]3 ...) // FIXME: Preserve comments: foo([a/*:*/ : /*:*/ ]2, ...) if (NewLabel.empty()) return ""; if (OldLabelRange.empty()) return ": "; return registerText(OldLabelRange); } StringRef getCallArgCombinedReplacement(StringRef OldArgLabel, StringRef NewArgLabel) { // This case only happens when going from foo([]1) to foo([a: ]1). assert(OldArgLabel.empty()); if (NewArgLabel.empty()) return ""; return registerText((Twine(NewArgLabel) + ": ").str()); } StringRef getParamNameReplacement(StringRef OldParam, StringRef OldArgLabel, StringRef NewArgLabel) { // We don't want to get foo(a a: Int), so drop the parameter name if the // argument label will match the original name. // Note: the leading whitespace is part of the parameter range. if (!NewArgLabel.empty() && OldParam.ltrim() == NewArgLabel) return ""; // If we're renaming foo(x: Int) to foo(_:), then use the original argument // label as the parameter name so as to not break references in the body. if (NewArgLabel.empty() && !OldArgLabel.empty() && OldParam.empty()) return registerText((Twine(" ") + OldArgLabel).str()); return registerText(OldParam); } StringRef getDeclArgumentLabelReplacement(StringRef OldLabelRange, StringRef NewArgLabel) { // OldLabelRange is subscript([]a: Int), foo([a]: Int) or foo([a] b: Int) if (NewArgLabel.empty()) return OldLabelRange.empty() ? "" : "_"; if (OldLabelRange.empty()) return registerText((Twine(NewArgLabel) + " ").str()); return registerText(NewArgLabel); } StringRef getReplacementText(StringRef LabelRange, RefactoringRangeKind RangeKind, StringRef OldLabel, StringRef NewLabel) { switch (RangeKind) { case RefactoringRangeKind::CallArgumentLabel: return getCallArgLabelReplacement(LabelRange, NewLabel); case RefactoringRangeKind::CallArgumentColon: return getCallArgColonReplacement(LabelRange, NewLabel); case RefactoringRangeKind::CallArgumentCombined: return getCallArgCombinedReplacement(LabelRange, NewLabel); case RefactoringRangeKind::ParameterName: return getParamNameReplacement(LabelRange, OldLabel, NewLabel); case RefactoringRangeKind::NoncollapsibleParameterName: return LabelRange; case RefactoringRangeKind::DeclArgumentLabel: return getDeclArgumentLabelReplacement(LabelRange, NewLabel); case RefactoringRangeKind::SelectorArgumentLabel: return NewLabel.empty() ? "_" : registerText(NewLabel); default: llvm_unreachable("label range type is none but there are labels"); } } void addReplacement(CharSourceRange LabelRange, RefactoringRangeKind RangeKind, StringRef OldLabel, StringRef NewLabel) { StringRef ExistingLabel = LabelRange.str(); StringRef Text = getReplacementText(ExistingLabel, RangeKind, OldLabel, NewLabel); if (Text != ExistingLabel) Replacements.push_back({LabelRange, Text, {}}); } void doRenameLabel(CharSourceRange Label, RefactoringRangeKind RangeKind, unsigned NameIndex) override { addReplacement(Label, RangeKind, Old.args()[NameIndex], New.args()[NameIndex]); } void doRenameBase(CharSourceRange Range, RefactoringRangeKind) override { if (Old.base() != New.base()) Replacements.push_back({Range, registerText(New.base()), {}}); } public: TextReplacementsRenamer(const SourceManager &SM, StringRef OldName, StringRef NewName, llvm::StringSet<> &ReplaceTextContext) : Renamer(SM, OldName), ReplaceTextContext(ReplaceTextContext), New(NewName) { assert(Old.isValid() && New.isValid()); assert(Old.partsCount() == New.partsCount()); } std::vector getReplacements() const { return std::move(Replacements); } }; static const ValueDecl *getRelatedSystemDecl(const ValueDecl *VD) { if (VD->getModuleContext()->isSystemModule()) return VD; for (auto *Req : VD->getSatisfiedProtocolRequirements()) { if (Req->getModuleContext()->isSystemModule()) return Req; } for (auto Over = VD->getOverriddenDecl(); Over; Over = Over->getOverriddenDecl()) { if (Over->getModuleContext()->isSystemModule()) return Over; } return nullptr; } static Optional getAvailableRenameForDecl(const ValueDecl *VD, Optional RefInfo) { SmallVector Infos; collectRenameAvailabilityInfo(VD, RefInfo, Infos); for (auto &Info : Infos) { if (Info.AvailableKind == RenameAvailableKind::Available) return Info.Kind; } return None; } class RenameRangeCollector : public IndexDataConsumer { public: RenameRangeCollector(StringRef USR, StringRef newName) : USR(USR.str()), newName(newName.str()) {} RenameRangeCollector(const ValueDecl *D, StringRef newName) : newName(newName.str()) { llvm::raw_string_ostream OS(USR); printValueDeclUSR(D, OS); } ArrayRef results() const { return locations; } private: bool indexLocals() override { return true; } void failed(StringRef error) override {} bool startDependency(StringRef name, StringRef path, bool isClangModule, bool isSystem) override { return true; } bool finishDependency(bool isClangModule) override { return true; } Action startSourceEntity(const IndexSymbol &symbol) override { if (symbol.USR == USR) { if (auto loc = indexSymbolToRenameLoc(symbol, newName)) { // Inside capture lists like `{ [test] in }`, 'test' refers to both the // newly declared, captured variable and the referenced variable it is // initialized from. Make sure to only rename it once. auto existingLoc = llvm::find_if(locations, [&](RenameLoc searchLoc) { return searchLoc.Line == loc->Line && searchLoc.Column == loc->Column; }); if (existingLoc == locations.end()) { locations.push_back(std::move(*loc)); } else { assert(existingLoc->OldName == loc->OldName && existingLoc->NewName == loc->NewName && existingLoc->IsFunctionLike == loc->IsFunctionLike && existingLoc->IsNonProtocolType == loc->IsNonProtocolType && "Asked to do a different rename for the same location?"); } } } return IndexDataConsumer::Continue; } bool finishSourceEntity(SymbolInfo symInfo, SymbolRoleSet roles) override { return true; } Optional indexSymbolToRenameLoc(const index::IndexSymbol &symbol, StringRef NewName); private: std::string USR; std::string newName; StringScratchSpace stringStorage; std::vector locations; }; Optional RenameRangeCollector::indexSymbolToRenameLoc(const index::IndexSymbol &symbol, StringRef newName) { if (symbol.roles & (unsigned)index::SymbolRole::Implicit) { return None; } NameUsage usage = NameUsage::Unknown; if (symbol.roles & (unsigned)index::SymbolRole::Call) { usage = NameUsage::Call; } else if (symbol.roles & (unsigned)index::SymbolRole::Definition) { usage = NameUsage::Definition; } else if (symbol.roles & (unsigned)index::SymbolRole::Reference) { usage = NameUsage::Reference; } else { llvm_unreachable("unexpected role"); } bool isFunctionLike = false; bool isNonProtocolType = false; switch (symbol.symInfo.Kind) { case index::SymbolKind::EnumConstant: case index::SymbolKind::Function: case index::SymbolKind::Constructor: case index::SymbolKind::ConversionFunction: case index::SymbolKind::InstanceMethod: case index::SymbolKind::ClassMethod: case index::SymbolKind::StaticMethod: isFunctionLike = true; break; case index::SymbolKind::Class: case index::SymbolKind::Enum: case index::SymbolKind::Struct: isNonProtocolType = true; break; default: break; } StringRef oldName = stringStorage.copyString(symbol.name); return RenameLoc{symbol.line, symbol.column, usage, oldName, newName, isFunctionLike, isNonProtocolType}; } ArrayRef collectSourceFiles(ModuleDecl *MD, SmallVectorImpl &Scratch) { for (auto Unit : MD->getFiles()) { if (auto SF = dyn_cast(Unit)) { Scratch.push_back(SF); } } return llvm::makeArrayRef(Scratch); } /// Get the source file that contains the given range and belongs to the module. SourceFile *getContainingFile(ModuleDecl *M, RangeConfig Range) { SmallVector Files; for (auto File : collectSourceFiles(M, Files)) { if (File->getBufferID()) { if (File->getBufferID().value() == Range.BufferId) { return File; } } } return nullptr; } class RefactoringAction { protected: ModuleDecl *MD; SourceFile *TheFile; SourceEditConsumer &EditConsumer; ASTContext &Ctx; SourceManager &SM; DiagnosticEngine DiagEngine; SourceLoc StartLoc; StringRef PreferredName; public: RefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts, SourceEditConsumer &EditConsumer, DiagnosticConsumer &DiagConsumer); virtual ~RefactoringAction() = default; virtual bool performChange() = 0; }; RefactoringAction:: RefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts, SourceEditConsumer &EditConsumer, DiagnosticConsumer &DiagConsumer): MD(MD), TheFile(getContainingFile(MD, Opts.Range)), EditConsumer(EditConsumer), Ctx(MD->getASTContext()), SM(MD->getASTContext().SourceMgr), DiagEngine(SM), StartLoc(Lexer::getLocForStartOfToken(SM, Opts.Range.getStart(SM))), PreferredName(Opts.PreferredName) { DiagEngine.addConsumer(DiagConsumer); } /// Different from RangeBasedRefactoringAction, TokenBasedRefactoringAction takes /// the input of a given token, e.g., a name or an "if" key word. Contextual /// refactoring kinds can suggest applicable refactorings on that token, e.g. /// rename or reverse if statement. class TokenBasedRefactoringAction : public RefactoringAction { protected: ResolvedCursorInfo CursorInfo; public: TokenBasedRefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts, SourceEditConsumer &EditConsumer, DiagnosticConsumer &DiagConsumer) : RefactoringAction(MD, Opts, EditConsumer, DiagConsumer) { // Resolve the sema token and save it for later use. CursorInfo = evaluateOrDefault(TheFile->getASTContext().evaluator, CursorInfoRequest{ CursorInfoOwner(TheFile, StartLoc)}, ResolvedCursorInfo()); } }; #define CURSOR_REFACTORING(KIND, NAME, ID) \ class RefactoringAction##KIND: public TokenBasedRefactoringAction { \ public: \ RefactoringAction##KIND(ModuleDecl *MD, RefactoringOptions &Opts, \ SourceEditConsumer &EditConsumer, \ DiagnosticConsumer &DiagConsumer) : \ TokenBasedRefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {} \ bool performChange() override; \ static bool isApplicable(const ResolvedCursorInfo &Info, \ DiagnosticEngine &Diag); \ bool isApplicable() { \ return RefactoringAction##KIND::isApplicable(CursorInfo, DiagEngine) ; \ } \ }; #include "swift/Refactoring/RefactoringKinds.def" class RangeBasedRefactoringAction : public RefactoringAction { protected: ResolvedRangeInfo RangeInfo; public: RangeBasedRefactoringAction(ModuleDecl *MD, RefactoringOptions &Opts, SourceEditConsumer &EditConsumer, DiagnosticConsumer &DiagConsumer) : RefactoringAction(MD, Opts, EditConsumer, DiagConsumer), RangeInfo(evaluateOrDefault(MD->getASTContext().evaluator, RangeInfoRequest(RangeInfoOwner(TheFile, Opts.Range.getStart(SM), Opts.Range.getEnd(SM))), ResolvedRangeInfo())) {} }; #define RANGE_REFACTORING(KIND, NAME, ID) \ class RefactoringAction##KIND: public RangeBasedRefactoringAction { \ public: \ RefactoringAction##KIND(ModuleDecl *MD, RefactoringOptions &Opts, \ SourceEditConsumer &EditConsumer, \ DiagnosticConsumer &DiagConsumer) : \ RangeBasedRefactoringAction(MD, Opts, EditConsumer, DiagConsumer) {} \ bool performChange() override; \ static bool isApplicable(const ResolvedRangeInfo &Info, \ DiagnosticEngine &Diag); \ bool isApplicable() { \ return RefactoringAction##KIND::isApplicable(RangeInfo, DiagEngine) ; \ } \ }; #include "swift/Refactoring/RefactoringKinds.def" bool RefactoringActionLocalRename:: isApplicable(const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) { auto ValueRefInfo = dyn_cast(&CursorInfo); if (!ValueRefInfo) return false; Optional RefInfo; if (ValueRefInfo->isRef()) RefInfo = {CursorInfo.getSourceFile(), CursorInfo.getLoc(), ValueRefInfo->isKeywordArgument()}; auto RenameOp = getAvailableRenameForDecl(ValueRefInfo->getValueD(), RefInfo); return RenameOp.has_value() && RenameOp.value() == RefactoringKind::LocalRename; } static void analyzeRenameScope(ValueDecl *VD, Optional RefInfo, DiagnosticEngine &Diags, SmallVectorImpl &Scopes) { Scopes.clear(); if (!getAvailableRenameForDecl(VD, RefInfo).has_value()) { Diags.diagnose(SourceLoc(), diag::value_decl_no_loc, VD->getName()); return; } auto *Scope = VD->getDeclContext(); // There may be sibling decls that the renamed symbol is visible from. switch (Scope->getContextKind()) { case DeclContextKind::GenericTypeDecl: case DeclContextKind::ExtensionDecl: case DeclContextKind::TopLevelCodeDecl: case DeclContextKind::SubscriptDecl: case DeclContextKind::EnumElementDecl: case DeclContextKind::AbstractFunctionDecl: Scope = Scope->getParent(); break; case DeclContextKind::AbstractClosureExpr: case DeclContextKind::Initializer: case DeclContextKind::SerializedLocal: case DeclContextKind::Module: case DeclContextKind::FileUnit: case DeclContextKind::MacroDecl: break; } Scopes.push_back(Scope); } bool RefactoringActionLocalRename::performChange() { if (StartLoc.isInvalid()) { DiagEngine.diagnose(SourceLoc(), diag::invalid_location); return true; } if (!DeclNameViewer(PreferredName).isValid()) { DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName); return true; } if (!TheFile) { DiagEngine.diagnose(StartLoc, diag::location_module_mismatch, MD->getNameStr()); return true; } CursorInfo = evaluateOrDefault(TheFile->getASTContext().evaluator, CursorInfoRequest{CursorInfoOwner(TheFile, StartLoc)}, ResolvedCursorInfo()); auto ValueRefCursorInfo = dyn_cast(&CursorInfo); if (ValueRefCursorInfo && ValueRefCursorInfo->getValueD()) { ValueDecl *VD = ValueRefCursorInfo->typeOrValue(); // The index always uses the outermost shadow for references if (!ValueRefCursorInfo->getShorthandShadowedDecls().empty()) { VD = ValueRefCursorInfo->getShorthandShadowedDecls().back(); } SmallVector Scopes; Optional RefInfo; if (ValueRefCursorInfo->isRef()) RefInfo = {CursorInfo.getSourceFile(), CursorInfo.getLoc(), ValueRefCursorInfo->isKeywordArgument()}; analyzeRenameScope(VD, RefInfo, DiagEngine, Scopes); if (Scopes.empty()) return true; RenameRangeCollector rangeCollector(VD, PreferredName); for (DeclContext *DC : Scopes) indexDeclContext(DC, rangeCollector); auto consumers = DiagEngine.takeConsumers(); assert(consumers.size() == 1); return syntacticRename(TheFile, rangeCollector.results(), EditConsumer, *consumers[0]); } else { DiagEngine.diagnose(StartLoc, diag::unresolved_location); return true; } } StringRef getDefaultPreferredName(RefactoringKind Kind) { switch(Kind) { case RefactoringKind::None: llvm_unreachable("Should be a valid refactoring kind"); case RefactoringKind::GlobalRename: case RefactoringKind::LocalRename: return "newName"; case RefactoringKind::ExtractExpr: case RefactoringKind::ExtractRepeatedExpr: return "extractedExpr"; case RefactoringKind::ExtractFunction: return "extractedFunc"; default: return ""; } } enum class CannotExtractReason { Literal, VoidType, }; class ExtractCheckResult { bool KnownFailure; SmallVector AllReasons; public: ExtractCheckResult(): KnownFailure(true) {} ExtractCheckResult(ArrayRef AllReasons): KnownFailure(false), AllReasons(AllReasons.begin(), AllReasons.end()) {} bool success() { return success({}); } bool success(ArrayRef ExpectedReasons) { if (KnownFailure) return false; bool Result = true; // Check if any reasons aren't covered by the list of expected reasons // provided by the client. for (auto R: AllReasons) { Result &= llvm::is_contained(ExpectedReasons, R); } return Result; } }; /// Check whether a given range can be extracted. /// Return true on successful condition checking,. /// Return false on failed conditions. ExtractCheckResult checkExtractConditions(const ResolvedRangeInfo &RangeInfo, DiagnosticEngine &DiagEngine) { SmallVector AllReasons; // If any declared declaration is refered out of the given range, return false. auto Declared = RangeInfo.DeclaredDecls; auto It = std::find_if(Declared.begin(), Declared.end(), [](DeclaredDecl DD) { return DD.ReferredAfterRange; }); if (It != Declared.end()) { DiagEngine.diagnose(It->VD->getLoc(), diag::value_decl_referenced_out_of_range, It->VD->getName()); return ExtractCheckResult(); } // We cannot extract a range with multi entry points. if (!RangeInfo.HasSingleEntry) { DiagEngine.diagnose(SourceLoc(), diag::multi_entry_range); return ExtractCheckResult(); } // We cannot extract code that is not sure to exit or not. if (RangeInfo.exit() == ExitState::Unsure) { return ExtractCheckResult(); } // We cannot extract expressions of l-value type. if (auto Ty = RangeInfo.getType()) { if (Ty->hasLValueType() || Ty->is()) return ExtractCheckResult(); // Disallow extracting error type expressions/statements // FIXME: diagnose what happened? if (Ty->hasError()) return ExtractCheckResult(); if (Ty->isVoid()) { AllReasons.emplace_back(CannotExtractReason::VoidType); } } // We cannot extract a range with orphaned loop keyword. switch (RangeInfo.Orphan) { case swift::ide::OrphanKind::Continue: DiagEngine.diagnose(SourceLoc(), diag::orphan_loop_keyword, "continue"); return ExtractCheckResult(); case swift::ide::OrphanKind::Break: DiagEngine.diagnose(SourceLoc(), diag::orphan_loop_keyword, "break"); return ExtractCheckResult(); case swift::ide::OrphanKind::None: break; } // Guard statement can not be extracted. if (llvm::any_of(RangeInfo.ContainedNodes, [](ASTNode N) { return N.isStmt(StmtKind::Guard); })) { return ExtractCheckResult(); } // Disallow extracting certain kinds of statements. if (RangeInfo.Kind == RangeKind::SingleStatement) { Stmt *S = RangeInfo.ContainedNodes[0].get(); // These aren't independent statement. if (isa(S) || isa(S)) return ExtractCheckResult(); } // Disallow extracting literals. if (RangeInfo.Kind == RangeKind::SingleExpression) { Expr *E = RangeInfo.ContainedNodes[0].get(); // Until implementing the performChange() part of extracting trailing // closures, we disable them for now. if (isa(E)) return ExtractCheckResult(); if (isa(E)) AllReasons.emplace_back(CannotExtractReason::Literal); } switch (RangeInfo.RangeContext->getContextKind()) { case swift::DeclContextKind::Initializer: case swift::DeclContextKind::SubscriptDecl: case swift::DeclContextKind::EnumElementDecl: case swift::DeclContextKind::AbstractFunctionDecl: case swift::DeclContextKind::AbstractClosureExpr: case swift::DeclContextKind::TopLevelCodeDecl: break; case swift::DeclContextKind::SerializedLocal: case swift::DeclContextKind::Module: case swift::DeclContextKind::FileUnit: case swift::DeclContextKind::GenericTypeDecl: case swift::DeclContextKind::ExtensionDecl: case swift::DeclContextKind::MacroDecl: return ExtractCheckResult(); } return ExtractCheckResult(AllReasons); } bool RefactoringActionExtractFunction:: isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { switch (Info.Kind) { case RangeKind::PartOfExpression: case RangeKind::SingleDecl: case RangeKind::MultiTypeMemberDecl: case RangeKind::Invalid: return false; case RangeKind::SingleExpression: case RangeKind::SingleStatement: case RangeKind::MultiStatement: { return checkExtractConditions(Info, Diag). success({CannotExtractReason::VoidType}); } } llvm_unreachable("unhandled kind"); } static StringRef correctNameInternal(ASTContext &Ctx, StringRef Name, ArrayRef AllVisibles) { // If we find the collision. bool FoundCollision = false; // The suffixes we cannot use by appending to the original given name. llvm::StringSet<> UsedSuffixes; for (auto VD : AllVisibles) { StringRef S = VD->getBaseName().userFacingName(); if (!S.startswith(Name)) continue; StringRef Suffix = S.substr(Name.size()); if (Suffix.empty()) FoundCollision = true; else UsedSuffixes.insert(Suffix); } if (!FoundCollision) return Name; // Find the first suffix we can use. std::string SuffixToUse; for (unsigned I = 1; ; I ++) { SuffixToUse = std::to_string(I); if (UsedSuffixes.count(SuffixToUse) == 0) break; } return Ctx.getIdentifier((llvm::Twine(Name) + SuffixToUse).str()).str(); } static StringRef correctNewDeclName(DeclContext *DC, StringRef Name) { // Collect all visible decls in the decl context. llvm::SmallVector AllVisibles; VectorDeclConsumer Consumer(AllVisibles); ASTContext &Ctx = DC->getASTContext(); lookupVisibleDecls(Consumer, DC, true); return correctNameInternal(Ctx, Name, AllVisibles); } static Type sanitizeType(Type Ty) { // Transform lvalue type to inout type so that we can print it properly. return Ty.transform([](Type Ty) { if (Ty->is()) { return Type(InOutType::get(Ty->getRValueType()->getCanonicalType())); } return Ty; }); } static SourceLoc getNewFuncInsertLoc(DeclContext *DC, DeclContext*& InsertToContext) { if (auto D = DC->getInnermostDeclarationDeclContext()) { // If extracting from a getter/setter, we should skip both the immediate // getter/setter function and the individual var decl. The pattern binding // decl is the position before which we should insert the newly extracted // function. if (auto *FD = dyn_cast(D)) { ValueDecl *SD = FD->getStorage(); switch (SD->getKind()) { case DeclKind::Var: if (auto *PBD = cast(SD)->getParentPatternBinding()) D = PBD; break; case DeclKind::Subscript: D = SD; break; default: break; } } auto Result = D->getStartLoc(); assert(Result.isValid()); // The insert loc should be before every decl attributes. for (auto Attr : D->getAttrs()) { auto Loc = Attr->getRangeWithAt().Start; if (Loc.isValid() && Loc.getOpaquePointerValue() < Result.getOpaquePointerValue()) Result = Loc; } // The insert loc should be before the doc comments associated with this decl. if (!D->getRawComment().Comments.empty()) { auto Loc = D->getRawComment().Comments.front().Range.getStart(); if (Loc.isValid() && Loc.getOpaquePointerValue() < Result.getOpaquePointerValue()) { Result = Loc; } } InsertToContext = D->getDeclContext(); return Result; } return SourceLoc(); } static std::vector getNotableRegions(StringRef SourceText, unsigned NameOffset, StringRef Name, bool IsFunctionLike = false, bool IsNonProtocolType = false) { auto InputBuffer = llvm::MemoryBuffer::getMemBufferCopy(SourceText,""); CompilerInvocation Invocation{}; Invocation.getFrontendOptions().InputsAndOutputs.addInput( InputFile("", true, InputBuffer.get(), file_types::TY_Swift)); Invocation.getFrontendOptions().ModuleName = "extract"; Invocation.getLangOptions().DisablePoundIfEvaluation = true; auto Instance = std::make_unique(); std::string InstanceSetupError; if (Instance->setup(Invocation, InstanceSetupError)) llvm_unreachable(InstanceSetupError.c_str()); unsigned BufferId = Instance->getPrimarySourceFile()->getBufferID().value(); SourceManager &SM = Instance->getSourceMgr(); SourceLoc NameLoc = SM.getLocForOffset(BufferId, NameOffset); auto LineAndCol = SM.getLineAndColumnInBuffer(NameLoc); UnresolvedLoc UnresoledName{NameLoc, true}; NameMatcher Matcher(*Instance->getPrimarySourceFile()); auto Resolved = Matcher.resolve(llvm::makeArrayRef(UnresoledName), None); assert(!Resolved.empty() && "Failed to resolve generated func name loc"); RenameLoc RenameConfig = { LineAndCol.first, LineAndCol.second, NameUsage::Definition, /*OldName=*/Name, /*NewName=*/"", IsFunctionLike, IsNonProtocolType }; RenameRangeDetailCollector Renamer(SM, Name); Renamer.addSyntacticRenameRanges(Resolved.back(), RenameConfig); auto Ranges = Renamer.Ranges; std::vector NoteRegions(Renamer.Ranges.size()); llvm::transform( Ranges, NoteRegions.begin(), [&SM](RenameRangeDetail &Detail) -> NoteRegion { auto Start = SM.getLineAndColumnInBuffer(Detail.Range.getStart()); auto End = SM.getLineAndColumnInBuffer(Detail.Range.getEnd()); return {Detail.RangeKind, Start.first, Start.second, End.first, End.second, Detail.Index}; }); return NoteRegions; } bool RefactoringActionExtractFunction::performChange() { // Check if the new name is ok. if (!Lexer::isIdentifier(PreferredName)) { DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName); return true; } DeclContext *DC = RangeInfo.RangeContext; DeclContext *InsertToDC = nullptr; SourceLoc InsertLoc = getNewFuncInsertLoc(DC, InsertToDC); // Complain about no inserting position. if (InsertLoc.isInvalid()) { DiagEngine.diagnose(SourceLoc(), diag::no_insert_position); return true; } // Correct the given name if collision happens. PreferredName = correctNewDeclName(InsertToDC, PreferredName); // Collect the paramters to pass down to the new function. std::vector Parameters; for (auto &RD: RangeInfo.ReferencedDecls) { // If the referenced decl is declared elsewhere, no need to pass as parameter if (RD.VD->getDeclContext() != DC) continue; // We don't need to pass down implicitly declared variables, e.g. error in // a catch block. if (RD.VD->isImplicit()) { SourceLoc Loc = RD.VD->getStartLoc(); if (Loc.isValid() && SM.isBeforeInBuffer(RangeInfo.ContentRange.getStart(), Loc) && SM.isBeforeInBuffer(Loc, RangeInfo.ContentRange.getEnd())) continue; } // If the referenced decl is declared inside the range, no need to pass // as parameter. if (RangeInfo.DeclaredDecls.end() != std::find_if(RangeInfo.DeclaredDecls.begin(), RangeInfo.DeclaredDecls.end(), [RD](DeclaredDecl DD) { return RD.VD == DD.VD; })) continue; // We don't need to pass down self. if (auto PD = dyn_cast(RD.VD)) { if (PD->isSelfParameter()) { continue; } } Parameters.emplace_back(RD.VD, sanitizeType(RD.Ty)); } SmallString<64> Buffer; unsigned FuncBegin = Buffer.size(); unsigned FuncNameOffset; { llvm::raw_svector_ostream OS(Buffer); if (!InsertToDC->isLocalContext()) { // Default to be file private. OS << tok::kw_fileprivate << " "; } // Inherit static if the containing function is. if (DC->getContextKind() == DeclContextKind::AbstractFunctionDecl) { if (auto FD = dyn_cast(static_cast(DC))) { if (FD->isStatic()) { OS << tok::kw_static << " "; } } } OS << tok::kw_func << " "; FuncNameOffset = Buffer.size() - FuncBegin; OS << PreferredName; OS << "("; for (auto &RD : Parameters) { OS << "_ " << RD.VD->getBaseName().userFacingName() << ": "; RD.Ty->reconstituteSugar(/*Recursive*/true)->print(OS); if (&RD != &Parameters.back()) OS << ", "; } OS << ")"; if (RangeInfo.UnhandledEffects.contains(EffectKind::Async)) OS << " async"; if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws)) OS << " " << tok::kw_throws; bool InsertedReturnType = false; if (auto Ty = RangeInfo.getType()) { // If the type of the range is not void, specify the return type. if (!Ty->isVoid()) { OS << " " << tok::arrow << " "; sanitizeType(Ty)->reconstituteSugar(/*Recursive*/true)->print(OS); InsertedReturnType = true; } } OS << " {\n"; // Add "return" if the extracted entity is an expression. if (RangeInfo.Kind == RangeKind::SingleExpression && InsertedReturnType) OS << tok::kw_return << " "; OS << RangeInfo.ContentRange.str() << "\n}\n\n"; } unsigned FuncEnd = Buffer.size(); unsigned ReplaceBegin = Buffer.size(); unsigned CallNameOffset; { llvm::raw_svector_ostream OS(Buffer); if (RangeInfo.exit() == ExitState::Positive) OS << tok::kw_return <<" "; if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws)) OS << tok::kw_try << " "; if (RangeInfo.UnhandledEffects.contains(EffectKind::Async)) OS << "await "; CallNameOffset = Buffer.size() - ReplaceBegin; OS << PreferredName << "("; for (auto &RD : Parameters) { // Inout argument needs "&". if (RD.Ty->is()) OS << "&"; OS << RD.VD->getBaseName().userFacingName(); if (&RD != &Parameters.back()) OS << ", "; } OS << ")"; } unsigned ReplaceEnd = Buffer.size(); std::string ExtractedFuncName = PreferredName.str() + "("; for (size_t i = 0; i < Parameters.size(); ++i) { ExtractedFuncName += "_:"; } ExtractedFuncName += ")"; StringRef DeclStr(Buffer.begin() + FuncBegin, FuncEnd - FuncBegin); auto NotableFuncRegions = getNotableRegions(DeclStr, FuncNameOffset, ExtractedFuncName, /*IsFunctionLike=*/true); StringRef CallStr(Buffer.begin() + ReplaceBegin, ReplaceEnd - ReplaceBegin); auto NotableCallRegions = getNotableRegions(CallStr, CallNameOffset, ExtractedFuncName, /*IsFunctionLike=*/true); // Insert the new function's declaration. EditConsumer.accept(SM, InsertLoc, DeclStr, NotableFuncRegions); // Replace the code to extract with the function call. EditConsumer.accept(SM, RangeInfo.ContentRange, CallStr, NotableCallRegions); return false; } class RefactoringActionExtractExprBase { SourceFile *TheFile; ResolvedRangeInfo RangeInfo; DiagnosticEngine &DiagEngine; const bool ExtractRepeated; StringRef PreferredName; SourceEditConsumer &EditConsumer; ASTContext &Ctx; SourceManager &SM; public: RefactoringActionExtractExprBase(SourceFile *TheFile, ResolvedRangeInfo RangeInfo, DiagnosticEngine &DiagEngine, bool ExtractRepeated, StringRef PreferredName, SourceEditConsumer &EditConsumer) : TheFile(TheFile), RangeInfo(RangeInfo), DiagEngine(DiagEngine), ExtractRepeated(ExtractRepeated), PreferredName(PreferredName), EditConsumer(EditConsumer), Ctx(TheFile->getASTContext()), SM(Ctx.SourceMgr){} bool performChange(); }; /// This is to ensure all decl references in two expressions are identical. struct ReferenceCollector: public SourceEntityWalker { SmallVector References; ReferenceCollector(Expr *E) { walk(E); } bool visitDeclReference(ValueDecl *D, CharSourceRange Range, TypeDecl *CtorTyRef, ExtensionDecl *ExtTyRef, Type T, ReferenceMetaData Data) override { References.emplace_back(D); return true; } bool operator==(const ReferenceCollector &Other) const { if (References.size() != Other.References.size()) return false; return std::equal(References.begin(), References.end(), Other.References.begin()); } }; struct SimilarExprCollector: public SourceEntityWalker { SourceManager &SM; /// The expression under selection. Expr *SelectedExpr; ArrayRef AllTokens; llvm::SetVector &Bucket; /// The tokens included in the expression under selection. ArrayRef SelectedTokens; /// The referenced decls in the expression under selection. ReferenceCollector SelectedReferences; bool compareTokenContent(ArrayRef Left, ArrayRef Right) { if (Left.size() != Right.size()) return false; return std::equal(Left.begin(), Left.end(), Right.begin(), [](const Token &L, const Token& R) { return L.getText() == R.getText(); }); } /// Find all tokens included by an expression. ArrayRef getExprSlice(Expr *E) { return slice_token_array(AllTokens, E->getStartLoc(), E->getEndLoc()); } SimilarExprCollector(SourceManager &SM, Expr *SelectedExpr, ArrayRef AllTokens, llvm::SetVector &Bucket): SM(SM), SelectedExpr(SelectedExpr), AllTokens(AllTokens), Bucket(Bucket), SelectedTokens(getExprSlice(SelectedExpr)), SelectedReferences(SelectedExpr){} bool walkToExprPre(Expr *E) override { // We don't extract implicit expressions. if (E->isImplicit()) return true; if (E->getKind() != SelectedExpr->getKind()) return true; // First check the underlying token arrays have the same content. if (compareTokenContent(getExprSlice(E), SelectedTokens)) { ReferenceCollector CurrentReferences(E); // Next, check the referenced decls are same. if (CurrentReferences == SelectedReferences) Bucket.insert(E); } return true; } }; bool RefactoringActionExtractExprBase::performChange() { // Check if the new name is ok. if (!Lexer::isIdentifier(PreferredName)) { DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName); return true; } // Find the enclosing brace statement; ContextFinder Finder(*TheFile, RangeInfo.ContainedNodes.front(), [](ASTNode N) { return N.isStmt(StmtKind::Brace); }); auto *SelectedExpr = RangeInfo.ContainedNodes[0].get(); Finder.resolve(); SourceLoc InsertLoc; llvm::SetVector AllVisibleDecls; struct DeclCollector: public SourceEntityWalker { llvm::SetVector &Bucket; DeclCollector(llvm::SetVector &Bucket): Bucket(Bucket) {} bool walkToDeclPre(Decl *D, CharSourceRange Range) override { if (auto *VD = dyn_cast(D)) Bucket.insert(VD); return true; } } Collector(AllVisibleDecls); llvm::SetVector AllExpressions; if (!Finder.getContexts().empty()) { // Get the innermost brace statement. auto BS = static_cast(Finder.getContexts().back().get()); // Collect all value decls inside the brace statement. Collector.walk(BS); if (ExtractRepeated) { // Collect all expressions we are going to extract. SimilarExprCollector(SM, SelectedExpr, slice_token_array(TheFile->getAllTokens(), BS->getStartLoc(), BS->getEndLoc()), AllExpressions).walk(BS); } else { AllExpressions.insert(SelectedExpr); } assert(!AllExpressions.empty() && "at least one expression is extracted."); for (auto Ele : BS->getElements()) { // Find the element that encloses the first expression under extraction. if (SM.rangeContains(Ele.getSourceRange(), (*AllExpressions.begin())->getSourceRange())) { // Insert before the enclosing element. InsertLoc = Ele.getStartLoc(); } } } // Complain about no inserting position. if (InsertLoc.isInvalid()) { DiagEngine.diagnose(SourceLoc(), diag::no_insert_position); return true; } // Correct name if collision happens. PreferredName = correctNameInternal(TheFile->getASTContext(), PreferredName, AllVisibleDecls.getArrayRef()); // Print the type name of this expression. SmallString<16> TyBuffer; // We are not sure about the type of repeated expressions. if (!ExtractRepeated) { if (auto Ty = RangeInfo.getType()) { llvm::raw_svector_ostream OS(TyBuffer); OS << ": "; Ty->getRValueType()->reconstituteSugar(true)->print(OS); } } SmallString<64> DeclBuffer; llvm::raw_svector_ostream OS(DeclBuffer); unsigned StartOffset, EndOffset; OS << tok::kw_let << " "; StartOffset = DeclBuffer.size(); OS << PreferredName; EndOffset = DeclBuffer.size(); OS << TyBuffer.str() << " = " << RangeInfo.ContentRange.str() << "\n"; NoteRegion DeclNameRegion{ RefactoringRangeKind::BaseName, /*StartLine=*/1, /*StartColumn=*/StartOffset + 1, /*EndLine=*/1, /*EndColumn=*/EndOffset + 1, /*ArgIndex*/None }; // Perform code change. EditConsumer.accept(SM, InsertLoc, DeclBuffer.str(), {DeclNameRegion}); // Replace all occurrences of the extracted expression. for (auto *E : AllExpressions) { EditConsumer.accept(SM, Lexer::getCharSourceRangeFromSourceRange(SM, E->getSourceRange()), PreferredName, {{ RefactoringRangeKind::BaseName, /*StartLine=*/1, /*StartColumn-*/1, /*EndLine=*/1, /*EndColumn=*/static_cast(PreferredName.size() + 1), /*ArgIndex*/None }}); } return false; } bool RefactoringActionExtractExpr:: isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { switch (Info.Kind) { case RangeKind::SingleExpression: // We disallow extract literal expression for two reasons: // (1) since we print the type for extracted expression, the type of a // literal may print as "int2048" where it is not typically users' choice; // (2) Extracting one literal provides little value for users. return checkExtractConditions(Info, Diag).success(); case RangeKind::PartOfExpression: case RangeKind::SingleDecl: case RangeKind::MultiTypeMemberDecl: case RangeKind::SingleStatement: case RangeKind::MultiStatement: case RangeKind::Invalid: return false; } llvm_unreachable("unhandled kind"); } bool RefactoringActionExtractExpr::performChange() { return RefactoringActionExtractExprBase(TheFile, RangeInfo, DiagEngine, false, PreferredName, EditConsumer).performChange(); } bool RefactoringActionExtractRepeatedExpr:: isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { switch (Info.Kind) { case RangeKind::SingleExpression: return checkExtractConditions(Info, Diag). success({CannotExtractReason::Literal}); case RangeKind::PartOfExpression: case RangeKind::SingleDecl: case RangeKind::MultiTypeMemberDecl: case RangeKind::SingleStatement: case RangeKind::MultiStatement: case RangeKind::Invalid: return false; } llvm_unreachable("unhandled kind"); } bool RefactoringActionExtractRepeatedExpr::performChange() { return RefactoringActionExtractExprBase(TheFile, RangeInfo, DiagEngine, true, PreferredName, EditConsumer).performChange(); } bool RefactoringActionMoveMembersToExtension::isApplicable( const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { switch (Info.Kind) { case RangeKind::SingleDecl: case RangeKind::MultiTypeMemberDecl: { DeclContext *DC = Info.RangeContext; // The the common decl context is not a nomial type, we cannot create an // extension for it if (!DC || !DC->getInnermostDeclarationDeclContext() || !isa(DC->getInnermostDeclarationDeclContext())) return false; // Members of types not declared at top file level cannot be extracted // to an extension at top file level if (DC->getParent()->getContextKind() != DeclContextKind::FileUnit) return false; // Check if contained nodes are all allowed decls. for (auto Node : Info.ContainedNodes) { Decl *D = Node.dyn_cast(); if (!D) return false; if (isa(D) || isa(D) || isa(D) || isa(D)) return false; } // We should not move instance variables with storage into the extension // because they are not allowed to be declared there for (auto DD : Info.DeclaredDecls) { if (auto ASD = dyn_cast(DD.VD)) { // Only disallow storages in the common decl context, allow them in // any subtypes if (ASD->hasStorage() && ASD->getDeclContext() == DC) { return false; } } } return true; } case RangeKind::SingleExpression: case RangeKind::PartOfExpression: case RangeKind::SingleStatement: case RangeKind::MultiStatement: case RangeKind::Invalid: return false; } llvm_unreachable("unhandled kind"); } bool RefactoringActionMoveMembersToExtension::performChange() { DeclContext *DC = RangeInfo.RangeContext; auto CommonTypeDecl = dyn_cast(DC->getInnermostDeclarationDeclContext()); assert(CommonTypeDecl && "Not applicable if common parent is no nomial type"); SmallString<64> Buffer; llvm::raw_svector_ostream OS(Buffer); OS << "\n\n"; OS << "extension " << CommonTypeDecl->getName() << " {\n"; OS << RangeInfo.ContentRange.str().trim(); OS << "\n}"; // Insert extension after the type declaration EditConsumer.insertAfter(SM, CommonTypeDecl->getEndLoc(), Buffer); EditConsumer.remove(SM, RangeInfo.ContentRange); return false; } namespace { // A SingleDecl range may not include all decls actually declared in that range: // a var decl has accessors that aren't included. This will find those missing // decls. class FindAllSubDecls : public SourceEntityWalker { SmallPtrSetImpl &Found; public: FindAllSubDecls(SmallPtrSetImpl &found) : Found(found) {} bool walkToDeclPre(Decl *D, CharSourceRange range) override { // Record this Decl, and skip its contents if we've already touched it. if (!Found.insert(D).second) return false; if (auto ASD = dyn_cast(D)) { ASD->visitParsedAccessors([&](AccessorDecl *accessor) { Found.insert(accessor); }); } return true; } }; } bool RefactoringActionReplaceBodiesWithFatalError::isApplicable( const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { switch (Info.Kind) { case RangeKind::SingleDecl: case RangeKind::MultiTypeMemberDecl: { SmallPtrSet Found; for (auto decl : Info.DeclaredDecls) { FindAllSubDecls(Found).walk(decl.VD); } for (auto decl : Found) { auto AFD = dyn_cast(decl); if (AFD && !AFD->isImplicit()) return true; } return false; } case RangeKind::SingleExpression: case RangeKind::PartOfExpression: case RangeKind::SingleStatement: case RangeKind::MultiStatement: case RangeKind::Invalid: return false; } llvm_unreachable("unhandled kind"); } bool RefactoringActionReplaceBodiesWithFatalError::performChange() { const StringRef replacement = "{\nfatalError()\n}"; SmallPtrSet Found; for (auto decl : RangeInfo.DeclaredDecls) { FindAllSubDecls(Found).walk(decl.VD); } for (auto decl : Found) { auto AFD = dyn_cast(decl); if (!AFD || AFD->isImplicit()) continue; auto range = AFD->getBodySourceRange(); // If we're in replacement mode (i.e. have an edit consumer), we can // rewrite the function body. auto charRange = Lexer::getCharSourceRangeFromSourceRange(SM, range); EditConsumer.accept(SM, charRange, replacement); } return false; } static std::pair findCollapseNestedIfTarget(const ResolvedCursorInfo &CursorInfo) { auto StmtStartInfo = dyn_cast(&CursorInfo); if (!StmtStartInfo) return {}; // Ensure the statement is 'if' statement. It must not have 'else' clause. IfStmt *OuterIf = dyn_cast(StmtStartInfo->getTrailingStmt()); if (!OuterIf) return {}; if (OuterIf->getElseStmt()) return {}; // The body must contain a sole inner 'if' statement. auto Body = dyn_cast_or_null(OuterIf->getThenStmt()); if (!Body || Body->getNumElements() != 1) return {}; IfStmt *InnerIf = dyn_cast_or_null(Body->getFirstElement().dyn_cast()); if (!InnerIf) return {}; // Inner 'if' statement also cannot have 'else' clause. if (InnerIf->getElseStmt()) return {}; return {OuterIf, InnerIf}; } bool RefactoringActionCollapseNestedIfStmt:: isApplicable(const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) { return findCollapseNestedIfTarget(CursorInfo).first; } bool RefactoringActionCollapseNestedIfStmt::performChange() { auto Target = findCollapseNestedIfTarget(CursorInfo); if (!Target.first) return true; auto OuterIf = Target.first; auto InnerIf = Target.second; EditorConsumerInsertStream OS( EditConsumer, SM, Lexer::getCharSourceRangeFromSourceRange(SM, OuterIf->getSourceRange())); OS << tok::kw_if << " "; // Emit conditions. bool first = true; for (auto &C : llvm::concat(OuterIf->getCond(), InnerIf->getCond())) { if (first) first = false; else OS << ", "; OS << Lexer::getCharSourceRangeFromSourceRange(SM, C.getSourceRange()) .str(); } // Emit body. OS << " "; OS << Lexer::getCharSourceRangeFromSourceRange( SM, InnerIf->getThenStmt()->getSourceRange()) .str(); return false; } static std::unique_ptr> findConcatenatedExpressions(const ResolvedRangeInfo &Info, ASTContext &Ctx) { Expr *E = nullptr; switch (Info.Kind) { case RangeKind::SingleExpression: E = Info.ContainedNodes[0].get(); break; case RangeKind::PartOfExpression: E = Info.CommonExprParent; break; default: return nullptr; } assert(E); struct StringInterpolationExprFinder: public SourceEntityWalker { std::unique_ptr> Bucket = std::make_unique>(); ASTContext &Ctx; bool IsValidInterpolation = true; StringInterpolationExprFinder(ASTContext &Ctx): Ctx(Ctx) {} bool isConcatenationExpr(DeclRefExpr* Expr) { if (!Expr) return false; auto *FD = dyn_cast(Expr->getDecl()); if (FD == nullptr || (FD != Ctx.getPlusFunctionOnString() && FD != Ctx.getPlusFunctionOnRangeReplaceableCollection())) { return false; } return true; } bool walkToExprPre(Expr *E) override { if (E->isImplicit()) return true; // FIXME: we should have ErrorType instead of null. if (E->getType().isNull()) return true; //Only binary concatenation operators should exist in expression if (E->getKind() == ExprKind::Binary) { auto *BE = dyn_cast(E); auto *OperatorDeclRef = BE->getSemanticFn()->getMemberOperatorRef(); if (!(isConcatenationExpr(OperatorDeclRef) && E->getType()->isString())) { IsValidInterpolation = false; return false; } return true; } // Everything that evaluates to string should be gathered. if (E->getType()->isString()) { Bucket->insert(E); return false; } if (auto *DR = dyn_cast(E)) { // Checks whether all function references in expression are concatenations. auto *FD = dyn_cast(DR->getDecl()); auto IsConcatenation = isConcatenationExpr(DR); if (FD && IsConcatenation) { return false; } } // There was non-expected expression, it's not valid interpolation then. IsValidInterpolation = false; return false; } } Walker(Ctx); Walker.walk(E); // There should be two or more expressions to convert. if (!Walker.IsValidInterpolation || Walker.Bucket->size() < 2) return nullptr; return std::move(Walker.Bucket); } static void interpolatedExpressionForm(Expr *E, SourceManager &SM, llvm::raw_ostream &OS) { if (auto *Literal = dyn_cast(E)) { OS << Literal->getValue(); return; } auto ExpStr = Lexer::getCharSourceRangeFromSourceRange(SM, E->getSourceRange()).str().str(); if (isa(E)) { ExpStr.erase(0, 1); ExpStr.pop_back(); OS << ExpStr; return; } OS << "\\(" << ExpStr << ")"; } bool RefactoringActionConvertStringsConcatenationToInterpolation:: isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { auto RangeContext = Info.RangeContext; if (RangeContext) { auto &Ctx = Info.RangeContext->getASTContext(); return findConcatenatedExpressions(Info, Ctx) != nullptr; } return false; } bool RefactoringActionConvertStringsConcatenationToInterpolation::performChange() { auto Expressions = findConcatenatedExpressions(RangeInfo, Ctx); if (!Expressions) return true; EditorConsumerInsertStream OS(EditConsumer, SM, RangeInfo.ContentRange); OS << "\""; for (auto It = Expressions->begin(); It != Expressions->end(); ++It) { interpolatedExpressionForm(*It, SM, OS); } OS << "\""; return false; } /// Abstract helper class containing info about a TernaryExpr /// that can be expanded into an IfStmt. class ExpandableTernaryExprInfo { public: virtual ~ExpandableTernaryExprInfo() {} virtual TernaryExpr *getTernary() = 0; virtual SourceRange getNameRange() = 0; virtual Type getType() = 0; virtual bool shouldDeclareNameAndType() { return !getType().isNull(); } virtual bool isValid() { //Ensure all public properties are non-nil and valid if (!getTernary() || !getNameRange().isValid()) return false; if (shouldDeclareNameAndType() && getType().isNull()) return false; return true; //valid } CharSourceRange getNameCharRange(const SourceManager &SM) { return Lexer::getCharSourceRangeFromSourceRange(SM, getNameRange()); } }; /// Concrete subclass containing info about an AssignExpr /// where the source is the expandable TernaryExpr. class ExpandableAssignTernaryExprInfo: public ExpandableTernaryExprInfo { public: ExpandableAssignTernaryExprInfo(AssignExpr *Assign): Assign(Assign) {} TernaryExpr *getTernary() override { if (!Assign) return nullptr; return dyn_cast_or_null(Assign->getSrc()); } SourceRange getNameRange() override { auto Invalid = SourceRange(); if (!Assign) return Invalid; if (auto dest = Assign->getDest()) return dest->getSourceRange(); return Invalid; } Type getType() override { return nullptr; } private: AssignExpr *Assign = nullptr; }; /// Concrete subclass containing info about a PatternBindingDecl /// where the pattern initializer is the expandable TernaryExpr. class ExpandableBindingTernaryExprInfo: public ExpandableTernaryExprInfo { public: ExpandableBindingTernaryExprInfo(PatternBindingDecl *Binding): Binding(Binding) {} TernaryExpr *getTernary() override { if (Binding && Binding->getNumPatternEntries() == 1) { if (auto *Init = Binding->getInit(0)) { return dyn_cast(Init); } } return nullptr; } SourceRange getNameRange() override { if (auto Pattern = getNamePattern()) return Pattern->getSourceRange(); return SourceRange(); } Type getType() override { if (auto Pattern = getNamePattern()) return Pattern->getType(); return nullptr; } private: Pattern *getNamePattern() { if (!Binding || Binding->getNumPatternEntries() != 1) return nullptr; auto Pattern = Binding->getPattern(0); if (!Pattern) return nullptr; if (auto TyPattern = dyn_cast(Pattern)) Pattern = TyPattern->getSubPattern(); return Pattern; } PatternBindingDecl *Binding = nullptr; }; std::unique_ptr findExpandableTernaryExpression(const ResolvedRangeInfo &Info) { if (Info.Kind != RangeKind::SingleDecl && Info.Kind != RangeKind:: SingleExpression) return nullptr; if (Info.ContainedNodes.size() != 1) return nullptr; if (auto D = Info.ContainedNodes[0].dyn_cast()) if (auto Binding = dyn_cast(D)) return std::make_unique(Binding); if (auto E = Info.ContainedNodes[0].dyn_cast()) if (auto Assign = dyn_cast(E)) return std::make_unique(Assign); return nullptr; } bool RefactoringActionExpandTernaryExpr:: isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { auto Target = findExpandableTernaryExpression(Info); return Target && Target->isValid(); } bool RefactoringActionExpandTernaryExpr::performChange() { auto Target = findExpandableTernaryExpression(RangeInfo); if (!Target || !Target->isValid()) return true; //abort auto NameCharRange = Target->getNameCharRange(SM); auto IfRange = Target->getTernary()->getSourceRange(); auto IfCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, IfRange); auto CondRange = Target->getTernary()->getCondExpr()->getSourceRange(); auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, CondRange); auto ThenRange = Target->getTernary()->getThenExpr()->getSourceRange(); auto ThenCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ThenRange); auto ElseRange = Target->getTernary()->getElseExpr()->getSourceRange(); auto ElseCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ElseRange); SmallString<64> DeclBuffer; llvm::raw_svector_ostream OS(DeclBuffer); StringRef Space = " "; StringRef NewLine = "\n"; if (Target->shouldDeclareNameAndType()) { //Specifier will not be replaced; append after specifier OS << NameCharRange.str() << tok::colon << Space; OS << Target->getType() << NewLine; } OS << tok::kw_if << Space; OS << CondCharRange.str() << Space; OS << tok::l_brace << NewLine; OS << NameCharRange.str() << Space; OS << tok::equal << Space; OS << ThenCharRange.str() << NewLine; OS << tok::r_brace << Space; OS << tok::kw_else << Space; OS << tok::l_brace << NewLine; OS << NameCharRange.str() << Space; OS << tok::equal << Space; OS << ElseCharRange.str() << NewLine; OS << tok::r_brace; //Start replacement with name range, skip the specifier auto ReplaceRange(NameCharRange); ReplaceRange.widen(IfCharRange); EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str()); return false; //don't abort } bool RefactoringActionConvertIfLetExprToGuardExpr:: isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { if (Info.Kind != RangeKind::SingleStatement && Info.Kind != RangeKind::MultiStatement) return false; if (Info.ContainedNodes.empty()) return false; IfStmt *If = nullptr; if (Info.ContainedNodes.size() == 1) { if (auto S = Info.ContainedNodes[0].dyn_cast()) { If = dyn_cast(S); } } if (!If) return false; auto CondList = If->getCond(); if (CondList.size() == 1) { auto E = CondList[0]; auto P = E.getKind(); if (P == swift::StmtConditionElement::CK_PatternBinding) { auto Body = dyn_cast_or_null(If->getThenStmt()); if (Body) return true; } } return false; } bool RefactoringActionConvertIfLetExprToGuardExpr::performChange() { auto S = RangeInfo.ContainedNodes[0].dyn_cast(); IfStmt *If = dyn_cast(S); auto CondList = If->getCond(); // Get if-let condition SourceRange range = CondList[0].getSourceRange(); SourceManager &SM = RangeInfo.RangeContext->getASTContext().SourceMgr; auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, range); auto Body = dyn_cast_or_null(If->getThenStmt()); // Get if-let then body. auto firstElement = Body->getFirstElement(); auto lastElement = Body->getLastElement(); SourceRange bodyRange = firstElement.getSourceRange(); bodyRange.widen(lastElement.getSourceRange()); auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, bodyRange); SmallString<64> DeclBuffer; llvm::raw_svector_ostream OS(DeclBuffer); StringRef Space = " "; StringRef NewLine = "\n"; OS << tok::kw_guard << Space; OS << CondCharRange.str().str() << Space; OS << tok::kw_else << Space; OS << tok::l_brace << NewLine; // Get if-let else body. if (auto *ElseBody = dyn_cast_or_null(If->getElseStmt())) { auto firstElseElement = ElseBody->getFirstElement(); auto lastElseElement = ElseBody->getLastElement(); SourceRange elseBodyRange = firstElseElement.getSourceRange(); elseBodyRange.widen(lastElseElement.getSourceRange()); auto ElseBodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, elseBodyRange); OS << ElseBodyCharRange.str().str() << NewLine; } OS << tok::kw_return << NewLine; OS << tok::r_brace << NewLine; OS << BodyCharRange.str().str(); // Replace if-let to guard auto ReplaceRange = RangeInfo.ContentRange; EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str()); return false; } bool RefactoringActionConvertGuardExprToIfLetExpr:: isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { if (Info.Kind != RangeKind::SingleStatement && Info.Kind != RangeKind::MultiStatement) return false; if (Info.ContainedNodes.empty()) return false; GuardStmt *guardStmt = nullptr; if (Info.ContainedNodes.size() > 0) { if (auto S = Info.ContainedNodes[0].dyn_cast()) { guardStmt = dyn_cast(S); } } if (!guardStmt) return false; auto CondList = guardStmt->getCond(); if (CondList.size() == 1) { auto E = CondList[0]; auto P = E.getPatternOrNull(); if (P && E.getKind() == swift::StmtConditionElement::CK_PatternBinding) return true; } return false; } bool RefactoringActionConvertGuardExprToIfLetExpr::performChange() { // Get guard stmt auto S = RangeInfo.ContainedNodes[0].dyn_cast(); GuardStmt *Guard = dyn_cast(S); // Get guard condition auto CondList = Guard->getCond(); // Get guard condition source SourceRange range = CondList[0].getSourceRange(); SourceManager &SM = RangeInfo.RangeContext->getASTContext().SourceMgr; auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, range); SmallString<64> DeclBuffer; llvm::raw_svector_ostream OS(DeclBuffer); StringRef Space = " "; StringRef NewLine = "\n"; OS << tok::kw_if << Space; OS << CondCharRange.str().str() << Space; OS << tok::l_brace << NewLine; // Get nodes after guard to place them at if-let body if (RangeInfo.ContainedNodes.size() > 1) { auto S = RangeInfo.ContainedNodes[1].getSourceRange(); S.widen(RangeInfo.ContainedNodes.back().getSourceRange()); auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, S); OS << BodyCharRange.str().str() << NewLine; } OS << tok::r_brace; // Get guard body auto Body = dyn_cast_or_null(Guard->getBody()); if (Body && Body->getNumElements() > 1) { auto firstElement = Body->getFirstElement(); auto lastElement = Body->getLastElement(); SourceRange bodyRange = firstElement.getSourceRange(); bodyRange.widen(lastElement.getSourceRange()); auto BodyCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, bodyRange); OS << Space << tok::kw_else << Space << tok::l_brace << NewLine; OS << BodyCharRange.str().str() << NewLine; OS << tok::r_brace; } // Replace guard to if-let auto ReplaceRange = RangeInfo.ContentRange; EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str()); return false; } bool RefactoringActionConvertToSwitchStmt:: isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { class ConditionalChecker : public ASTWalker { public: bool ParamsUseSameVars = true; bool ConditionUseOnlyAllowedFunctions = false; StringRef ExpectName; PostWalkResult walkToExprPost(Expr *E) override { if (E->getKind() != ExprKind::DeclRef) return Action::Continue(E); auto D = dyn_cast(E)->getDecl(); if (D->getKind() == DeclKind::Var || D->getKind() == DeclKind::Param) ParamsUseSameVars = checkName(dyn_cast(D)); if (D->getKind() == DeclKind::Func) ConditionUseOnlyAllowedFunctions = checkName(dyn_cast(D)); if (allCheckPassed()) return Action::Continue(E); return Action::Stop(); } bool allCheckPassed() { return ParamsUseSameVars && ConditionUseOnlyAllowedFunctions; } private: bool checkName(VarDecl *VD) { auto Name = VD->getName().str(); if (ExpectName.empty()) ExpectName = Name; return Name == ExpectName; } bool checkName(FuncDecl *FD) { const auto Name = FD->getBaseIdentifier().str(); return Name == "~=" || Name == "==" || Name == "__derived_enum_equals" || Name == "__derived_struct_equals" || Name == "||" || Name == "..."; } }; class SwitchConvertable { public: SwitchConvertable(const ResolvedRangeInfo &Info) : Info(Info) { } bool isApplicable() { if (Info.Kind != RangeKind::SingleStatement) return false; if (!findIfStmt()) return false; return checkEachCondition(); } private: const ResolvedRangeInfo &Info; IfStmt *If = nullptr; ConditionalChecker checker; bool findIfStmt() { if (Info.ContainedNodes.size() != 1) return false; if (auto S = Info.ContainedNodes.front().dyn_cast()) If = dyn_cast(S); return If != nullptr; } bool checkEachCondition() { checker = ConditionalChecker(); do { if (!checkEachElement()) return false; } while ((If = dyn_cast_or_null(If->getElseStmt()))); return true; } bool checkEachElement() { bool result = true; auto ConditionalList = If->getCond(); for (auto Element : ConditionalList) { result &= check(Element); } return result; } bool check(StmtConditionElement ConditionElement) { if (ConditionElement.getKind() == StmtConditionElement::CK_Availability) return false; if (ConditionElement.getKind() == StmtConditionElement::CK_PatternBinding) checker.ConditionUseOnlyAllowedFunctions = true; ConditionElement.walk(checker); return checker.allCheckPassed(); } }; return SwitchConvertable(Info).isApplicable(); } bool RefactoringActionConvertToSwitchStmt::performChange() { class VarNameFinder : public ASTWalker { public: std::string VarName; PostWalkResult walkToExprPost(Expr *E) override { if (E->getKind() != ExprKind::DeclRef) return Action::Continue(E); auto D = dyn_cast(E)->getDecl(); if (D->getKind() != DeclKind::Var && D->getKind() != DeclKind::Param) return Action::Continue(E); VarName = dyn_cast(D)->getName().str().str(); return Action::Stop(); } }; class ConditionalPatternFinder : public ASTWalker { public: ConditionalPatternFinder(SourceManager &SM) : SM(SM) {} SmallString<64> ConditionalPattern = SmallString<64>(); PostWalkResult walkToExprPost(Expr *E) override { auto *BE = dyn_cast(E); if (!BE) return Action::Continue(E); if (isFunctionNameAllowed(BE)) appendPattern(BE->getLHS(), BE->getRHS()); return Action::Continue(E); } PreWalkResult walkToPatternPre(Pattern *P) override { ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, P->getSourceRange()).str()); if (P->getKind() == PatternKind::OptionalSome) ConditionalPattern.append("?"); return Action::Stop(); } private: SourceManager &SM; bool isFunctionNameAllowed(BinaryExpr *E) { Expr *Fn = E->getFn(); if (auto DotSyntaxCall = dyn_cast_or_null(Fn)) { Fn = DotSyntaxCall->getFn(); } DeclRefExpr *DeclRef = dyn_cast_or_null(Fn); if (!DeclRef) { return false; } auto FunctionDeclaration = dyn_cast_or_null(DeclRef->getDecl()); if (!FunctionDeclaration) { return false; } auto &ASTCtx = FunctionDeclaration->getASTContext(); const auto FunctionName = FunctionDeclaration->getBaseIdentifier(); return FunctionName == ASTCtx.Id_MatchOperator || FunctionName == ASTCtx.Id_EqualsOperator || FunctionName == ASTCtx.Id_derived_enum_equals || FunctionName == ASTCtx.Id_derived_struct_equals; } void appendPattern(Expr *LHS, Expr *RHS) { auto *PatternArgument = RHS; if (PatternArgument->getKind() == ExprKind::DeclRef) PatternArgument = LHS; if (ConditionalPattern.size() > 0) ConditionalPattern.append(", "); ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, PatternArgument->getSourceRange()).str()); } }; class ConverterToSwitch { public: ConverterToSwitch(const ResolvedRangeInfo &Info, SourceManager &SM) : Info(Info), SM(SM) { } void performConvert(SmallString<64> &Out) { If = findIf(); OptionalLabel = If->getLabelInfo().Name.str().str(); ControlExpression = findControlExpression(); findPatternsAndBodies(PatternsAndBodies); DefaultStatements = findDefaultStatements(); makeSwitchStatement(Out); } private: const ResolvedRangeInfo &Info; SourceManager &SM; IfStmt *If; IfStmt *PreviousIf; std::string OptionalLabel; std::string ControlExpression; SmallVector, 16> PatternsAndBodies; std::string DefaultStatements; IfStmt *findIf() { auto S = Info.ContainedNodes[0].dyn_cast(); return dyn_cast(S); } std::string findControlExpression() { auto ConditionElement = If->getCond().front(); auto Finder = VarNameFinder(); ConditionElement.walk(Finder); return Finder.VarName; } void findPatternsAndBodies(SmallVectorImpl> &Out) { do { auto pattern = findPattern(); auto body = findBodyStatements(); Out.push_back(std::make_pair(pattern, body)); PreviousIf = If; } while ((If = dyn_cast_or_null(If->getElseStmt()))); } std::string findPattern() { auto ConditionElement = If->getCond().front(); auto Finder = ConditionalPatternFinder(SM); ConditionElement.walk(Finder); return Finder.ConditionalPattern.str().str(); } std::string findBodyStatements() { return findBodyWithoutBraces(If->getThenStmt()); } std::string findDefaultStatements() { auto ElseBody = dyn_cast_or_null(PreviousIf->getElseStmt()); if (!ElseBody) return getTokenText(tok::kw_break).str(); return findBodyWithoutBraces(ElseBody); } std::string findBodyWithoutBraces(Stmt *body) { auto BS = dyn_cast(body); if (!BS) return Lexer::getCharSourceRangeFromSourceRange(SM, body->getSourceRange()).str().str(); if (BS->getElements().empty()) return getTokenText(tok::kw_break).str(); SourceRange BodyRange = BS->getElements().front().getSourceRange(); BodyRange.widen(BS->getElements().back().getSourceRange()); return Lexer::getCharSourceRangeFromSourceRange(SM, BodyRange).str().str(); } void makeSwitchStatement(SmallString<64> &Out) { StringRef Space = " "; StringRef NewLine = "\n"; llvm::raw_svector_ostream OS(Out); if (OptionalLabel.size() > 0) OS << OptionalLabel << ":" << Space; OS << tok::kw_switch << Space << ControlExpression << Space << tok::l_brace << NewLine; for (auto &pair : PatternsAndBodies) { OS << tok::kw_case << Space << pair.first << tok::colon << NewLine; OS << pair.second << NewLine; } OS << tok::kw_default << tok::colon << NewLine; OS << DefaultStatements << NewLine; OS << tok::r_brace; } }; SmallString<64> result; ConverterToSwitch(RangeInfo, SM).performConvert(result); EditConsumer.accept(SM, RangeInfo.ContentRange, result.str()); return false; } /// Struct containing info about an IfStmt that can be converted into a /// TernaryExpr. struct ConvertToTernaryExprInfo { ConvertToTernaryExprInfo() {} Expr *AssignDest() { if (!Then || !Then->getDest() || !Else || !Else->getDest()) return nullptr; auto ThenDest = Then->getDest(); auto ElseDest = Else->getDest(); if (ThenDest->getKind() != ElseDest->getKind()) return nullptr; switch (ThenDest->getKind()) { case ExprKind::DeclRef: { auto ThenRef = dyn_cast(Then->getDest()); auto ElseRef = dyn_cast(Else->getDest()); if (!ThenRef || !ThenRef->getDecl() || !ElseRef || !ElseRef->getDecl()) return nullptr; const auto ThenName = ThenRef->getDecl()->getName(); const auto ElseName = ElseRef->getDecl()->getName(); if (ThenName.compare(ElseName) != 0) return nullptr; return Then->getDest(); } case ExprKind::Tuple: { auto ThenTuple = dyn_cast(Then->getDest()); auto ElseTuple = dyn_cast(Else->getDest()); if (!ThenTuple || !ElseTuple) return nullptr; auto ThenNames = ThenTuple->getElementNames(); auto ElseNames = ElseTuple->getElementNames(); if (!ThenNames.equals(ElseNames)) return nullptr; return ThenTuple; } default: return nullptr; } } Expr *ThenSrc() { if (!Then) return nullptr; return Then->getSrc(); } Expr *ElseSrc() { if (!Else) return nullptr; return Else->getSrc(); } bool isValid() { if (!Cond || !AssignDest() || !ThenSrc() || !ElseSrc() || !IfRange.isValid()) return false; return true; } PatternBindingDecl *Binding = nullptr; //optional Expr *Cond = nullptr; //required AssignExpr *Then = nullptr; //required AssignExpr *Else = nullptr; //required SourceRange IfRange; }; ConvertToTernaryExprInfo findConvertToTernaryExpression(const ResolvedRangeInfo &Info) { auto notFound = ConvertToTernaryExprInfo(); if (Info.Kind != RangeKind::SingleStatement && Info.Kind != RangeKind::MultiStatement) return notFound; if (Info.ContainedNodes.empty()) return notFound; struct AssignExprFinder: public SourceEntityWalker { AssignExpr *Assign = nullptr; AssignExprFinder(Stmt* S) { if (S) walk(S); } virtual bool walkToExprPre(Expr *E) override { Assign = dyn_cast(E); return false; } }; ConvertToTernaryExprInfo Target; IfStmt *If = nullptr; if (Info.ContainedNodes.size() == 1) { if (auto S = Info.ContainedNodes[0].dyn_cast()) If = dyn_cast(S); } if (Info.ContainedNodes.size() == 2) { if (auto D = Info.ContainedNodes[0].dyn_cast()) Target.Binding = dyn_cast(D); if (auto S = Info.ContainedNodes[1].dyn_cast()) If = dyn_cast(S); } if (!If) return notFound; auto CondList = If->getCond(); if (CondList.size() != 1) return notFound; Target.Cond = CondList[0].getBooleanOrNull(); Target.IfRange = If->getSourceRange(); Target.Then = AssignExprFinder(If->getThenStmt()).Assign; Target.Else = AssignExprFinder(If->getElseStmt()).Assign; return Target; } bool RefactoringActionConvertToTernaryExpr:: isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { return findConvertToTernaryExpression(Info).isValid(); } bool RefactoringActionConvertToTernaryExpr::performChange() { auto Target = findConvertToTernaryExpression(RangeInfo); if (!Target.isValid()) return true; //abort SmallString<64> DeclBuffer; llvm::raw_svector_ostream OS(DeclBuffer); StringRef Space = " "; auto IfRange = Target.IfRange; auto ReplaceRange = Lexer::getCharSourceRangeFromSourceRange(SM, IfRange); auto CondRange = Target.Cond->getSourceRange(); auto CondCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, CondRange); auto ThenRange = Target.ThenSrc()->getSourceRange(); auto ThenCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ThenRange); auto ElseRange = Target.ElseSrc()->getSourceRange(); auto ElseCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, ElseRange); CharSourceRange DestCharRange; if (Target.Binding) { auto DestRange = Target.Binding->getSourceRange(); DestCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, DestRange); ReplaceRange.widen(DestCharRange); } else { auto DestRange = Target.AssignDest()->getSourceRange(); DestCharRange = Lexer::getCharSourceRangeFromSourceRange(SM, DestRange); } OS << DestCharRange.str() << Space << tok::equal << Space; OS << CondCharRange.str() << Space << tok::question_postfix << Space; OS << ThenCharRange.str() << Space << tok::colon << Space; OS << ElseCharRange.str(); EditConsumer.accept(SM, ReplaceRange, DeclBuffer.str()); return false; //don't abort } /// The helper class analyzes a given nominal decl or an extension decl to /// decide whether stubs are required to filled in and the context in which /// these stubs should be filled. class FillProtocolStubContext { std::vector getUnsatisfiedRequirements(const IterableDeclContext *IDC); /// Context in which the content should be filled; this could be either a /// nominal type declaraion or an extension declaration. DeclContext *DC; /// The type that adopts the required protocol stubs. For nominal type decl, this /// should be the declared type itself; for extension decl, this should be the /// extended type at hand. Type Adopter; /// The start location of the decl, either nominal type or extension, for the /// printer to figure out the right indentation. SourceLoc StartLoc; /// The location of '{' for the decl, thus we know where to insert the filling /// stubs. SourceLoc BraceStartLoc; /// The value decls that should be satisfied; this could be either function /// decls, property decls, or required type alias. std::vector FillingContents; public: FillProtocolStubContext(ExtensionDecl *ED) : DC(ED), Adopter(ED->getExtendedType()), StartLoc(ED->getStartLoc()), BraceStartLoc(ED->getBraces().Start), FillingContents(getUnsatisfiedRequirements(ED)) {}; FillProtocolStubContext(NominalTypeDecl *ND) : DC(ND), Adopter(ND->getDeclaredType()), StartLoc(ND->getStartLoc()), BraceStartLoc(ND->getBraces().Start), FillingContents(getUnsatisfiedRequirements(ND)) {}; FillProtocolStubContext() : DC(nullptr), Adopter(), FillingContents({}) {}; static FillProtocolStubContext getContextFromCursorInfo( const ResolvedCursorInfo &Tok); ArrayRef getFillingContents() const { return llvm::makeArrayRef(FillingContents); } DeclContext *getFillingContext() const { return DC; } bool canProceed() const { return StartLoc.isValid() && BraceStartLoc.isValid() && !getFillingContents().empty(); } Type getAdopter() const { return Adopter; } SourceLoc getContextStartLoc() const { return StartLoc; } SourceLoc getBraceStartLoc() const { return BraceStartLoc; } }; FillProtocolStubContext FillProtocolStubContext:: getContextFromCursorInfo(const ResolvedCursorInfo &CursorInfo) { if(!CursorInfo.isValid()) return FillProtocolStubContext(); auto ValueRefInfo = dyn_cast(&CursorInfo); if (!ValueRefInfo) { return FillProtocolStubContext(); } if (!ValueRefInfo->isRef()) { // If the type name is on the declared nominal, e.g. "class A {}" if (auto ND = dyn_cast(ValueRefInfo->getValueD())) { return FillProtocolStubContext(ND); } } else if (auto *ED = ValueRefInfo->getExtTyRef()) { // If the type ref is on a declared extension, e.g. "extension A {}" return FillProtocolStubContext(ED); } return FillProtocolStubContext(); } std::vector FillProtocolStubContext:: getUnsatisfiedRequirements(const IterableDeclContext *IDC) { // The results to return. std::vector NonWitnessedReqs; // For each conformance of the extended nominal. for(ProtocolConformance *Con : IDC->getLocalConformances()) { // Collect non-witnessed requirements. Con->forEachNonWitnessedRequirement( [&](ValueDecl *VD) { NonWitnessedReqs.push_back(VD); }); } return NonWitnessedReqs; } bool RefactoringActionFillProtocolStub:: isApplicable(const ResolvedCursorInfo &Tok, DiagnosticEngine &Diag) { return FillProtocolStubContext::getContextFromCursorInfo(Tok).canProceed(); } bool RefactoringActionFillProtocolStub::performChange() { // Get the filling protocol context from the input token. FillProtocolStubContext Context = FillProtocolStubContext:: getContextFromCursorInfo(CursorInfo); assert(Context.canProceed()); assert(!Context.getFillingContents().empty()); assert(Context.getFillingContext()); SmallString<128> Text; { llvm::raw_svector_ostream SS(Text); Type Adopter = Context.getAdopter(); SourceLoc Loc = Context.getContextStartLoc(); auto Contents = Context.getFillingContents(); // For each unsatisfied requirement, print the stub to the buffer. std::for_each(Contents.begin(), Contents.end(), [&](ValueDecl *VD) { printRequirementStub(VD, Context.getFillingContext(), Adopter, Loc, SS); }); } // Insert all stubs after '{' in the extension/nominal type decl. EditConsumer.insertAfter(SM, Context.getBraceStartLoc(), Text); return false; } static void collectAvailableRefactoringsAtCursor( SourceFile *SF, unsigned Line, unsigned Column, SmallVectorImpl &Kinds, ArrayRef DiagConsumers) { // Prepare the tool box. ASTContext &Ctx = SF->getASTContext(); SourceManager &SM = Ctx.SourceMgr; DiagnosticEngine DiagEngine(SM); std::for_each(DiagConsumers.begin(), DiagConsumers.end(), [&](DiagnosticConsumer *Con) { DiagEngine.addConsumer(*Con); }); SourceLoc Loc = SM.getLocForLineCol(SF->getBufferID().value(), Line, Column); if (Loc.isInvalid()) return; ResolvedCursorInfo Tok = evaluateOrDefault(SF->getASTContext().evaluator, CursorInfoRequest{CursorInfoOwner(SF, Lexer::getLocForStartOfToken(SM, Loc))}, ResolvedCursorInfo()); collectAvailableRefactorings(Tok, Kinds, /*Exclude rename*/ false); } static EnumDecl* getEnumDeclFromSwitchStmt(SwitchStmt *SwitchS) { if (auto SubjectTy = SwitchS->getSubjectExpr()->getType()) { // FIXME: Support more complex subject like '(Enum1, Enum2)'. return dyn_cast_or_null(SubjectTy->getAnyNominal()); } return nullptr; } static bool performCasesExpansionInSwitchStmt(SwitchStmt *SwitchS, DiagnosticEngine &DiagEngine, SourceLoc ExpandedStmtLoc, EditorConsumerInsertStream &OS ) { // Assume enum elements are not handled in the switch statement. auto EnumDecl = getEnumDeclFromSwitchStmt(SwitchS); assert(EnumDecl); llvm::DenseSet UnhandledElements; EnumDecl->getAllElements(UnhandledElements); for (auto Current : SwitchS->getCases()) { if (Current->isDefault()) { continue; } // For each handled enum element, remove it from the bucket. for (auto Item : Current->getCaseLabelItems()) { if (auto *EEP = dyn_cast_or_null(Item.getPattern())) { UnhandledElements.erase(EEP->getElementDecl()); } } } // If all enum elements are handled in the switch statement, issue error. if (UnhandledElements.empty()) { DiagEngine.diagnose(ExpandedStmtLoc, diag::no_remaining_cases); return true; } printEnumElementsAsCases(UnhandledElements, OS); return false; } // Finds SwitchStmt that contains given CaseStmt. static SwitchStmt* findEnclosingSwitchStmt(CaseStmt *CS, SourceFile *SF, DiagnosticEngine &DiagEngine) { auto IsSwitch = [](ASTNode Node) { return Node.is() && Node.get()->getKind() == StmtKind::Switch; }; ContextFinder Finder(*SF, CS, IsSwitch); Finder.resolve(); // If failed to find the switch statement, issue error. if (Finder.getContexts().empty()) { DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch); return nullptr; } auto *SwitchS = static_cast(Finder.getContexts().back(). get()); // Make sure that CaseStmt is included in switch that was found. auto Cases = SwitchS->getCases(); auto Default = std::find(Cases.begin(), Cases.end(), CS); if (Default == Cases.end()) { DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch); return nullptr; } return SwitchS; } bool RefactoringActionExpandDefault:: isApplicable(const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) { auto Exit = [&](bool Applicable) { if (!Applicable) Diag.diagnose(SourceLoc(), diag::invalid_default_location); return Applicable; }; auto StmtStartInfo = dyn_cast(&CursorInfo); if (!StmtStartInfo) return Exit(false); if (auto *CS = dyn_cast(StmtStartInfo->getTrailingStmt())) { auto EnclosingSwitchStmt = findEnclosingSwitchStmt(CS, CursorInfo.getSourceFile(), Diag); if (!EnclosingSwitchStmt) return false; auto EnumD = getEnumDeclFromSwitchStmt(EnclosingSwitchStmt); auto IsApplicable = CS->isDefault() && EnumD != nullptr; return IsApplicable; } return Exit(false); } bool RefactoringActionExpandDefault::performChange() { // If we've not seen the default statement inside the switch statement, issue // error. auto StmtStartInfo = cast(CursorInfo); auto *CS = static_cast(StmtStartInfo.getTrailingStmt()); auto *SwitchS = findEnclosingSwitchStmt(CS, TheFile, DiagEngine); assert(SwitchS); EditorConsumerInsertStream OS(EditConsumer, SM, Lexer::getCharSourceRangeFromSourceRange(SM, CS->getLabelItemsRange())); return performCasesExpansionInSwitchStmt(SwitchS, DiagEngine, CS->getStartLoc(), OS); } bool RefactoringActionExpandSwitchCases:: isApplicable(const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &DiagEngine) { auto StmtStartInfo = dyn_cast(&CursorInfo); if (!StmtStartInfo || !StmtStartInfo->getTrailingStmt()) return false; if (auto *Switch = dyn_cast(StmtStartInfo->getTrailingStmt())) { return getEnumDeclFromSwitchStmt(Switch); } return false; } bool RefactoringActionExpandSwitchCases::performChange() { auto StmtStartInfo = cast(CursorInfo); auto *SwitchS = dyn_cast(StmtStartInfo.getTrailingStmt()); assert(SwitchS); auto InsertRange = CharSourceRange(); auto Cases = SwitchS->getCases(); auto Default = std::find_if(Cases.begin(), Cases.end(), [](CaseStmt *Stmt) { return Stmt->isDefault(); }); if (Default != Cases.end()) { auto DefaultRange = (*Default)->getLabelItemsRange(); InsertRange = Lexer::getCharSourceRangeFromSourceRange(SM, DefaultRange); } else { auto RBraceLoc = SwitchS->getRBraceLoc(); InsertRange = CharSourceRange(SM, RBraceLoc, RBraceLoc); } EditorConsumerInsertStream OS(EditConsumer, SM, InsertRange); if (SM.getLineAndColumnInBuffer(SwitchS->getLBraceLoc()).first == SM.getLineAndColumnInBuffer(SwitchS->getRBraceLoc()).first) { OS << "\n"; } auto Result = performCasesExpansionInSwitchStmt(SwitchS, DiagEngine, SwitchS->getStartLoc(), OS); return Result; } static Expr *findLocalizeTarget(const ResolvedCursorInfo &CursorInfo) { auto ExprStartInfo = dyn_cast(&CursorInfo); if (!ExprStartInfo) return nullptr; struct StringLiteralFinder: public SourceEntityWalker { SourceLoc StartLoc; Expr *Target; StringLiteralFinder(SourceLoc StartLoc): StartLoc(StartLoc), Target(nullptr) {} bool walkToExprPre(Expr *E) override { if (E->getStartLoc() != StartLoc) return false; if (E->getKind() == ExprKind::InterpolatedStringLiteral) return false; if (E->getKind() == ExprKind::StringLiteral) { Target = E; return false; } return true; } } Walker(ExprStartInfo->getTrailingExpr()->getStartLoc()); Walker.walk(ExprStartInfo->getTrailingExpr()); return Walker.Target; } bool RefactoringActionLocalizeString:: isApplicable(const ResolvedCursorInfo &Tok, DiagnosticEngine &Diag) { return findLocalizeTarget(Tok); } bool RefactoringActionLocalizeString::performChange() { Expr* Target = findLocalizeTarget(CursorInfo); if (!Target) return true; EditConsumer.accept(SM, Target->getStartLoc(), "NSLocalizedString("); EditConsumer.insertAfter(SM, Target->getEndLoc(), ", comment: \"\")"); return false; } struct MemberwiseParameter { CharSourceRange NameRange; Type MemberType; Expr *DefaultExpr; MemberwiseParameter(CharSourceRange nameRange, Type type, Expr *initialExpr) : NameRange(nameRange), MemberType(type), DefaultExpr(initialExpr) {} }; static void generateMemberwiseInit(SourceEditConsumer &EditConsumer, SourceManager &SM, ArrayRef memberVector, SourceLoc targetLocation) { EditConsumer.accept(SM, targetLocation, "\ninternal init("); auto insertMember = [&SM](const MemberwiseParameter &memberData, raw_ostream &OS, bool wantsSeparator) { { OS << SM.extractText(memberData.NameRange) << ": "; // Unconditionally print '@escaping' if we print out a function type - // the assignments we generate below will escape this parameter. if (isa(memberData.MemberType->getCanonicalType())) { OS << "@" << TypeAttributes::getAttrName(TAK_escaping) << " "; } OS << memberData.MemberType.getString(); } bool HasAddedDefault = false; if (auto *expr = memberData.DefaultExpr) { if (expr->getSourceRange().isValid()) { auto range = Lexer::getCharSourceRangeFromSourceRange( SM, expr->getSourceRange()); OS << " = " << SM.extractText(range); HasAddedDefault = true; } } if (!HasAddedDefault && memberData.MemberType->isOptional()) { OS << " = nil"; } if (wantsSeparator) { OS << ", "; } }; // Process the initial list of members, inserting commas as appropriate. std::string Buffer; llvm::raw_string_ostream OS(Buffer); for (const auto &memberData : llvm::enumerate(memberVector)) { bool wantsSeparator = (memberData.index() != memberVector.size() - 1); insertMember(memberData.value(), OS, wantsSeparator); } // Synthesize the body. OS << ") {\n"; for (auto &member : memberVector) { // self. = auto name = SM.extractText(member.NameRange); OS << "self." << name << " = " << name << "\n"; } OS << "}\n"; // Accept the entire edit. EditConsumer.accept(SM, targetLocation, OS.str()); } static SourceLoc collectMembersForInit(const ResolvedCursorInfo &CursorInfo, SmallVectorImpl &memberVector) { auto ValueRefInfo = dyn_cast(&CursorInfo); if (!ValueRefInfo || !ValueRefInfo->getValueD()) return SourceLoc(); NominalTypeDecl *nominalDecl = dyn_cast(ValueRefInfo->getValueD()); if (!nominalDecl || nominalDecl->getStoredProperties().empty() || ValueRefInfo->isRef()) { return SourceLoc(); } SourceLoc bracesStart = nominalDecl->getBraces().Start; if (!bracesStart.isValid()) return SourceLoc(); SourceLoc targetLocation = bracesStart.getAdvancedLoc(1); if (!targetLocation.isValid()) return SourceLoc(); SourceManager &SM = nominalDecl->getASTContext().SourceMgr; for (auto member : nominalDecl->getMembers()) { auto varDecl = dyn_cast(member); if (!varDecl) { continue; } if (varDecl->getAttrs().hasAttribute()) { // Exclude lazy members from the memberwise initializer. This is // inconsistent with the implicitly synthesized memberwise initializer but // we think it makes more sense because otherwise the lazy variable's // initializer gets evaluated eagerly. continue; } if (!varDecl->isMemberwiseInitialized(/*preferDeclaredProperties=*/true)) { continue; } auto patternBinding = varDecl->getParentPatternBinding(); if (!patternBinding) continue; const auto i = patternBinding->getPatternEntryIndexForVarDecl(varDecl); Expr *defaultInit = nullptr; if (patternBinding->isExplicitlyInitialized(i) || patternBinding->isDefaultInitializable()) { defaultInit = patternBinding->getOriginalInit(i); } auto NameRange = Lexer::getCharSourceRangeFromSourceRange(SM, varDecl->getNameLoc()); memberVector.emplace_back(NameRange, varDecl->getType(), defaultInit); } return targetLocation; } bool RefactoringActionMemberwiseInitLocalRefactoring:: isApplicable(const ResolvedCursorInfo &Tok, DiagnosticEngine &Diag) { SmallVector memberVector; return collectMembersForInit(Tok, memberVector).isValid(); } bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() { SmallVector memberVector; SourceLoc targetLocation = collectMembersForInit(CursorInfo, memberVector); if (targetLocation.isInvalid()) return true; generateMemberwiseInit(EditConsumer, SM, memberVector, targetLocation); return false; } /// If \p NTD is a protocol, return all the protocols it inherits from. If it's /// a type, return all the protocols it conforms to. static SmallVector getAllProtocols(NominalTypeDecl *NTD) { if (auto Proto = dyn_cast(NTD)) { return SmallVector( Proto->getInheritedProtocols().begin(), Proto->getInheritedProtocols().end()); } else { return NTD->getAllProtocols(); } } class AddEquatableContext { /// Declaration context DeclContext *DC; /// Adopter type Type Adopter; /// Start location of declaration context brace SourceLoc StartLoc; /// Array of all inherited protocols' locations ArrayRef ProtocolsLocations; /// Array of all conformed protocols SmallVector Protocols; /// Start location of declaration, /// a place to write protocol name SourceLoc ProtInsertStartLoc; /// Stored properties of extending adopter ArrayRef StoredProperties; /// Range of internal members in declaration DeclRange Range; bool conformsToEquatableProtocol() { for (ProtocolDecl *Protocol : Protocols) { if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) { return true; } } return false; } bool isRequirementValid() { auto Reqs = getProtocolRequirements(); if (Reqs.empty()) { return false; } auto Req = dyn_cast(Reqs[0]); return Req && Req->getParameters()->size() == 2; } bool isPropertiesListValid() { return !getUserAccessibleProperties().empty(); } void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params); std::vector getProtocolRequirements(); std::vector getUserAccessibleProperties(); public: AddEquatableContext(NominalTypeDecl *Decl) : DC(Decl), Adopter(Decl->getDeclaredType()), StartLoc(Decl->getBraces().Start), ProtocolsLocations(Decl->getInherited()), Protocols(getAllProtocols(Decl)), ProtInsertStartLoc(Decl->getNameLoc()), StoredProperties(Decl->getStoredProperties()), Range(Decl->getMembers()){}; AddEquatableContext(ExtensionDecl *Decl) : DC(Decl), Adopter(Decl->getExtendedType()), StartLoc(Decl->getBraces().Start), ProtocolsLocations(Decl->getInherited()), Protocols(getAllProtocols(Decl->getExtendedNominal())), ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()), StoredProperties(Decl->getExtendedNominal()->getStoredProperties()), Range(Decl->getMembers()){}; AddEquatableContext() : DC(nullptr), Adopter(), ProtocolsLocations(), Protocols(), StoredProperties(), Range(nullptr, nullptr) {}; static AddEquatableContext getDeclarationContextFromInfo(const ResolvedCursorInfo &Info); std::string getInsertionTextForProtocol(); std::string getInsertionTextForFunction(SourceManager &SM); bool isValid() { // FIXME: Allow to generate explicit == method for declarations which already have // compiler-generated == method return StartLoc.isValid() && ProtInsertStartLoc.isValid() && !conformsToEquatableProtocol() && isPropertiesListValid() && isRequirementValid(); } SourceLoc getStartLocForProtocolDecl() { if (ProtocolsLocations.empty()) { return ProtInsertStartLoc; } return ProtocolsLocations.back().getSourceRange().Start; } bool isMembersRangeEmpty() { return Range.empty(); } SourceLoc getInsertStartLoc(); }; SourceLoc AddEquatableContext:: getInsertStartLoc() { SourceLoc MaxLoc = StartLoc; for (auto Mem : Range) { if (Mem->getEndLoc().getOpaquePointerValue() > MaxLoc.getOpaquePointerValue()) { MaxLoc = Mem->getEndLoc(); } } return MaxLoc; } std::string AddEquatableContext:: getInsertionTextForProtocol() { StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable); std::string Buffer; llvm::raw_string_ostream OS(Buffer); if (ProtocolsLocations.empty()) { OS << ": " << ProtocolName; return Buffer; } OS << ", " << ProtocolName; return Buffer; } std::string AddEquatableContext:: getInsertionTextForFunction(SourceManager &SM) { auto Reqs = getProtocolRequirements(); auto Req = dyn_cast(Reqs[0]); auto Params = Req->getParameters(); StringRef ExtraIndent; StringRef CurrentIndent = Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent); std::string Indent; if (isMembersRangeEmpty()) { Indent = (CurrentIndent + ExtraIndent).str(); } else { Indent = CurrentIndent.str(); } PrintOptions Options = PrintOptions::printVerbose(); Options.PrintDocumentationComments = false; Options.setBaseType(Adopter); Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) { Printer << " {"; Printer.printNewline(); printFunctionBody(Printer, ExtraIndent, Params); Printer.printNewline(); Printer << "}"; }; std::string Buffer; llvm::raw_string_ostream OS(Buffer); ExtraIndentStreamPrinter Printer(OS, Indent); Printer.printNewline(); if (!isMembersRangeEmpty()) { Printer.printNewline(); } Reqs[0]->print(Printer, Options); return Buffer; } std::vector AddEquatableContext:: getUserAccessibleProperties() { std::vector PublicProperties; for (VarDecl *Decl : StoredProperties) { if (Decl->Decl::isUserAccessible()) { PublicProperties.push_back(Decl); } } return PublicProperties; } std::vector AddEquatableContext:: getProtocolRequirements() { std::vector Collection; auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable); for (auto Member : Proto->getMembers()) { auto Req = dyn_cast(Member); if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) { continue; } Collection.push_back(Req); } return Collection; } AddEquatableContext AddEquatableContext:: getDeclarationContextFromInfo(const ResolvedCursorInfo &Info) { auto ValueRefInfo = dyn_cast(&Info); if (!ValueRefInfo) { return AddEquatableContext(); } if (!ValueRefInfo->isRef()) { if (auto *NomDecl = dyn_cast(ValueRefInfo->getValueD())) { return AddEquatableContext(NomDecl); } } else if (auto *ExtDecl = ValueRefInfo->getExtTyRef()) { if (ExtDecl->getExtendedNominal()) { return AddEquatableContext(ExtDecl); } } return AddEquatableContext(); } void AddEquatableContext:: printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) { SmallString<128> Return; llvm::raw_svector_ostream SS(Return); SS << tok::kw_return; StringRef Space = " "; StringRef AdditionalSpace = " "; StringRef Point = "."; StringRef Join = " == "; StringRef And = " &&"; auto Props = getUserAccessibleProperties(); auto FParam = Params->get(0)->getName(); auto SParam = Params->get(1)->getName(); auto Prop = Props[0]->getName(); Printer << ExtraIndent << Return << Space << FParam << Point << Prop << Join << SParam << Point << Prop; if (Props.size() > 1) { std::for_each(Props.begin() + 1, Props.end(), [&](VarDecl *VD){ auto Name = VD->getName(); Printer << And; Printer.printNewline(); Printer << ExtraIndent << AdditionalSpace << FParam << Point << Name << Join << SParam << Point << Name; }); } } bool RefactoringActionAddEquatableConformance:: isApplicable(const ResolvedCursorInfo &Tok, DiagnosticEngine &Diag) { return AddEquatableContext::getDeclarationContextFromInfo(Tok).isValid(); } bool RefactoringActionAddEquatableConformance:: performChange() { auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo); EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(), Context.getInsertionTextForProtocol()); EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), Context.getInsertionTextForFunction(SM)); return false; } class AddCodableContext { /// Declaration context DeclContext *DC; /// Start location of declaration context brace SourceLoc StartLoc; /// Array of all conformed protocols SmallVector Protocols; /// Range of internal members in declaration DeclRange Range; bool conformsToCodableProtocol() { for (ProtocolDecl *Protocol : Protocols) { if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Encodable || Protocol->getKnownProtocolKind() == KnownProtocolKind::Decodable) { return true; } } return false; } public: AddCodableContext(NominalTypeDecl *Decl) : DC(Decl), StartLoc(Decl->getBraces().Start), Protocols(getAllProtocols(Decl)), Range(Decl->getMembers()){}; AddCodableContext(ExtensionDecl *Decl) : DC(Decl), StartLoc(Decl->getBraces().Start), Protocols(getAllProtocols(Decl->getExtendedNominal())), Range(Decl->getMembers()){}; AddCodableContext() : DC(nullptr), Protocols(), Range(nullptr, nullptr){}; static AddCodableContext getDeclarationContextFromInfo(const ResolvedCursorInfo &Info); void printInsertionText(const ResolvedCursorInfo &CursorInfo, SourceManager &SM, llvm::raw_ostream &OS); bool isValid() { return StartLoc.isValid() && conformsToCodableProtocol(); } SourceLoc getInsertStartLoc(); }; SourceLoc AddCodableContext::getInsertStartLoc() { SourceLoc MaxLoc = StartLoc; for (auto Mem : Range) { if (Mem->getEndLoc().getOpaquePointerValue() > MaxLoc.getOpaquePointerValue()) { MaxLoc = Mem->getEndLoc(); } } return MaxLoc; } /// Walks an AST and prints the synthesized Codable implementation. class SynthesizedCodablePrinter : public ASTWalker { private: ASTPrinter &Printer; public: SynthesizedCodablePrinter(ASTPrinter &Printer) : Printer(Printer) {} PreWalkAction walkToDeclPre(Decl *D) override { auto *VD = dyn_cast(D); if (!VD) return Action::SkipChildren(); if (!VD->isSynthesized()) { return Action::Continue(); } SmallString<32> Scratch; auto name = VD->getName().getString(Scratch); // Print all synthesized enums, // since Codable can synthesize multiple enums (for associated values). auto shouldPrint = isa(VD) || name == "init(from:)" || name == "encode(to:)"; if (!shouldPrint) { // Some other synthesized decl that we don't want to print. return Action::SkipChildren(); } Printer.printNewline(); if (auto enumDecl = dyn_cast(D)) { // Manually print enum here, since we don't want to print synthesized // functions. Printer << "enum " << enumDecl->getNameStr(); PrintOptions Options; Options.PrintSpaceBeforeInheritance = false; enumDecl->printInherited(Printer, Options); Printer << " {"; for (Decl *EC : enumDecl->getAllElements()) { Printer.printNewline(); Printer << " "; EC->print(Printer, Options); } Printer.printNewline(); Printer << "}"; return Action::SkipChildren(); } PrintOptions Options; Options.SynthesizeSugarOnTypes = true; Options.FunctionDefinitions = true; Options.VarInitializers = true; Options.PrintExprs = true; Options.TypeDefinitions = true; Options.ExcludeAttrList.push_back(DAK_HasInitialValue); Printer.printNewline(); D->print(Printer, Options); return Action::SkipChildren(); } }; void AddCodableContext::printInsertionText(const ResolvedCursorInfo &CursorInfo, SourceManager &SM, llvm::raw_ostream &OS) { StringRef ExtraIndent; StringRef CurrentIndent = Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent); std::string Indent; if (getInsertStartLoc() == StartLoc) { Indent = (CurrentIndent + ExtraIndent).str(); } else { Indent = CurrentIndent.str(); } ExtraIndentStreamPrinter Printer(OS, Indent); Printer.printNewline(); SynthesizedCodablePrinter Walker(Printer); DC->getAsDecl()->walk(Walker); } AddCodableContext AddCodableContext::getDeclarationContextFromInfo( const ResolvedCursorInfo &Info) { auto ValueRefInfo = dyn_cast(&Info); if (!ValueRefInfo) { return AddCodableContext(); } if (!ValueRefInfo->isRef()) { if (auto *NomDecl = dyn_cast(ValueRefInfo->getValueD())) { return AddCodableContext(NomDecl); } } // TODO: support extensions // (would need to get synthesized nodes from the main decl, // and only if it's in the same file?) return AddCodableContext(); } bool RefactoringActionAddExplicitCodableImplementation::isApplicable( const ResolvedCursorInfo &Tok, DiagnosticEngine &Diag) { return AddCodableContext::getDeclarationContextFromInfo(Tok).isValid(); } bool RefactoringActionAddExplicitCodableImplementation::performChange() { auto Context = AddCodableContext::getDeclarationContextFromInfo(CursorInfo); SmallString<64> Buffer; llvm::raw_svector_ostream OS(Buffer); Context.printInsertionText(CursorInfo, SM, OS); EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), OS.str()); return false; } static CharSourceRange findSourceRangeToWrapInCatch(const ResolvedExprStartCursorInfo &CursorInfo, SourceFile *TheFile, SourceManager &SM) { Expr *E = CursorInfo.getTrailingExpr(); if (!E) return CharSourceRange(); auto Node = ASTNode(E); auto NodeChecker = [](ASTNode N) { return N.isStmt(StmtKind::Brace); }; ContextFinder Finder(*TheFile, Node, NodeChecker); Finder.resolve(); auto Contexts = Finder.getContexts(); if (Contexts.empty()) return CharSourceRange(); auto TargetNode = Contexts.back(); BraceStmt *BStmt = dyn_cast(TargetNode.dyn_cast()); auto ConvertToCharRange = [&SM](SourceRange SR) { return Lexer::getCharSourceRangeFromSourceRange(SM, SR); }; assert(BStmt); auto ExprRange = ConvertToCharRange(E->getSourceRange()); // Check elements of the deepest BraceStmt, pick one that covers expression. for (auto Elem: BStmt->getElements()) { auto ElemRange = ConvertToCharRange(Elem.getSourceRange()); if (ElemRange.contains(ExprRange)) TargetNode = Elem; } return ConvertToCharRange(TargetNode.getSourceRange()); } bool RefactoringActionConvertToDoCatch:: isApplicable(const ResolvedCursorInfo &Tok, DiagnosticEngine &Diag) { auto ExprStartInfo = dyn_cast(&Tok); if (!ExprStartInfo || !ExprStartInfo->getTrailingExpr()) return false; return isa(ExprStartInfo->getTrailingExpr()); } bool RefactoringActionConvertToDoCatch::performChange() { auto ExprStartInfo = cast(CursorInfo); auto *TryExpr = dyn_cast(ExprStartInfo.getTrailingExpr()); assert(TryExpr); auto Range = findSourceRangeToWrapInCatch(ExprStartInfo, TheFile, SM); if (!Range.isValid()) return true; // Wrap given range in do catch block. EditConsumer.accept(SM, Range.getStart(), "do {\n"); EditorConsumerInsertStream OS(EditConsumer, SM, Range.getEnd()); OS << "\n} catch {\n" << getCodePlaceholder() << "\n}"; // Delete ! from try! expression auto ExclaimLen = getKeywordLen(tok::exclaim_postfix); auto ExclaimRange = CharSourceRange(TryExpr->getExclaimLoc(), ExclaimLen); EditConsumer.remove(SM, ExclaimRange); return false; } /// Given a cursor position, this function tries to collect a number literal /// expression immediately following the cursor. static NumberLiteralExpr *getTrailingNumberLiteral( const ResolvedCursorInfo &Tok) { // This cursor must point to the start of an expression. auto ExprStartInfo = dyn_cast(&Tok); if (!ExprStartInfo) return nullptr; // For every sub-expression, try to find the literal expression that matches // our criteria. class FindLiteralNumber : public ASTWalker { Expr * const parent; public: NumberLiteralExpr *found = nullptr; explicit FindLiteralNumber(Expr *parent) : parent(parent) { } PreWalkResult walkToExprPre(Expr *expr) override { if (auto *literal = dyn_cast(expr)) { // The sub-expression must have the same start loc with the outermost // expression, i.e. the cursor position. if (!found && parent->getStartLoc().getOpaquePointerValue() == expr->getStartLoc().getOpaquePointerValue()) { found = literal; } } return Action::SkipChildrenIf(found, expr); } }; auto parent = ExprStartInfo->getTrailingExpr(); FindLiteralNumber finder(parent); parent->walk(finder); return finder.found; } static std::string insertUnderscore(StringRef Text) { SmallString<64> Buffer; llvm::raw_svector_ostream OS(Buffer); for (auto It = Text.begin(); It != Text.end(); ++It) { unsigned Distance = It - Text.begin(); if (Distance && !(Distance % 3)) { OS << '_'; } OS << *It; } return OS.str().str(); } void insertUnderscoreInDigits(StringRef Digits, raw_ostream &OS) { StringRef BeforePointRef, AfterPointRef; std::tie(BeforePointRef, AfterPointRef) = Digits.split('.'); std::string BeforePoint(BeforePointRef); std::string AfterPoint(AfterPointRef); // Insert '_' for the part before the decimal point. std::reverse(BeforePoint.begin(), BeforePoint.end()); BeforePoint = insertUnderscore(BeforePoint); std::reverse(BeforePoint.begin(), BeforePoint.end()); OS << BeforePoint; // Insert '_' for the part after the decimal point, if necessary. if (!AfterPoint.empty()) { OS << '.'; OS << insertUnderscore(AfterPoint); } } bool RefactoringActionSimplifyNumberLiteral:: isApplicable(const ResolvedCursorInfo &Tok, DiagnosticEngine &Diag) { if (auto *Literal = getTrailingNumberLiteral(Tok)) { SmallString<64> Buffer; llvm::raw_svector_ostream OS(Buffer); StringRef Digits = Literal->getDigitsText(); insertUnderscoreInDigits(Digits, OS); // If inserting '_' results in a different digit sequence, this refactoring // is applicable. return OS.str() != Digits; } return false; } bool RefactoringActionSimplifyNumberLiteral::performChange() { if (auto *Literal = getTrailingNumberLiteral(CursorInfo)) { EditorConsumerInsertStream OS(EditConsumer, SM, CharSourceRange(SM, Literal->getDigitsLoc(), Lexer::getLocForEndOfToken(SM, Literal->getEndLoc()))); StringRef Digits = Literal->getDigitsText(); insertUnderscoreInDigits(Digits, OS); return false; } return true; } static CallExpr *findTrailingClosureTarget( SourceManager &SM, const ResolvedCursorInfo &CursorInfo) { if (CursorInfo.getKind() == CursorInfoKind::StmtStart) // StmtStart postion can't be a part of CallExpr. return nullptr; // Find inner most CallExpr ContextFinder Finder( *CursorInfo.getSourceFile(), CursorInfo.getLoc(), [](ASTNode N) { return N.isStmt(StmtKind::Brace) || N.isExpr(ExprKind::Call); }); Finder.resolve(); auto contexts = Finder.getContexts(); if (contexts.empty()) return nullptr; // If the innermost context is a statement (which will be a BraceStmt per // the filtering condition above), drop it. if (contexts.back().is()) { contexts = contexts.drop_back(); } if (contexts.empty() || !contexts.back().is()) return nullptr; CallExpr *CE = cast(contexts.back().get()); // The last argument is a non-trailing closure? auto *Args = CE->getArgs(); if (Args->empty() || Args->hasAnyTrailingClosures()) return nullptr; auto *LastArg = Args->back().getExpr(); if (auto *ICE = dyn_cast(LastArg)) LastArg = ICE->getSyntacticSubExpr(); if (isa(LastArg) || isa(LastArg)) return CE; return nullptr; } bool RefactoringActionTrailingClosure:: isApplicable(const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) { SourceManager &SM = CursorInfo.getSourceFile()->getASTContext().SourceMgr; return findTrailingClosureTarget(SM, CursorInfo); } bool RefactoringActionTrailingClosure::performChange() { auto *CE = findTrailingClosureTarget(SM, CursorInfo); if (!CE) return true; auto *ArgList = CE->getArgs()->getOriginalArgs(); auto LParenLoc = ArgList->getLParenLoc(); auto RParenLoc = ArgList->getRParenLoc(); if (LParenLoc.isInvalid() || RParenLoc.isInvalid()) return true; auto NumArgs = ArgList->size(); if (NumArgs == 0) return true; auto *ClosureArg = ArgList->getExpr(NumArgs - 1); if (auto *ICE = dyn_cast(ClosureArg)) ClosureArg = ICE->getSyntacticSubExpr(); // Replace: // * Open paren with ' ' if the closure is sole argument. // * Comma with ') ' otherwise. if (NumArgs > 1) { auto *PrevArg = ArgList->getExpr(NumArgs - 2); CharSourceRange PreRange( SM, Lexer::getLocForEndOfToken(SM, PrevArg->getEndLoc()), ClosureArg->getStartLoc()); EditConsumer.accept(SM, PreRange, ") "); } else { CharSourceRange PreRange(SM, LParenLoc, ClosureArg->getStartLoc()); EditConsumer.accept(SM, PreRange, " "); } // Remove original closing paren. CharSourceRange PostRange( SM, Lexer::getLocForEndOfToken(SM, ClosureArg->getEndLoc()), Lexer::getLocForEndOfToken(SM, RParenLoc)); EditConsumer.remove(SM, PostRange); return false; } static bool collectRangeStartRefactorings(const ResolvedRangeInfo &Info) { switch (Info.Kind) { case RangeKind::SingleExpression: case RangeKind::SingleStatement: case RangeKind::SingleDecl: case RangeKind::PartOfExpression: return true; case RangeKind::MultiStatement: case RangeKind::MultiTypeMemberDecl: case RangeKind::Invalid: return false; } } bool RefactoringActionConvertToComputedProperty:: isApplicable(const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { if (Info.Kind != RangeKind::SingleDecl) { return false; } if (Info.ContainedNodes.size() != 1) { return false; } auto D = Info.ContainedNodes[0].dyn_cast(); if (!D) { return false; } auto Binding = dyn_cast(D); if (!Binding) { return false; } auto SV = Binding->getSingleVar(); if (!SV) { return false; } // willSet, didSet cannot be provided together with a getter for (auto AD : SV->getAllAccessors()) { if (AD->isObservingAccessor()) { return false; } } // 'lazy' must not be used on a computed property // NSCopying and IBOutlet attribute requires property to be mutable auto Attributies = SV->getAttrs(); if (Attributies.hasAttribute() || Attributies.hasAttribute() || Attributies.hasAttribute()) { return false; } // Property wrapper cannot be applied to a computed property if (SV->hasAttachedPropertyWrapper()) { return false; } // has an initializer return Binding->hasInitStringRepresentation(0); } bool RefactoringActionConvertToComputedProperty::performChange() { // Get an initialization auto D = RangeInfo.ContainedNodes[0].dyn_cast(); auto Binding = dyn_cast(D); SmallString<128> scratch; auto Init = Binding->getInitStringRepresentation(0, scratch); // Get type auto SV = Binding->getSingleVar(); auto SVType = SV->getType(); auto TR = SV->getTypeReprOrParentPatternTypeRepr(); SmallString<64> DeclBuffer; llvm::raw_svector_ostream OS(DeclBuffer); StringRef Space = " "; StringRef NewLine = "\n"; OS << tok::kw_var << Space; // Add var name OS << SV->getNameStr().str() << ":" << Space; // For computed property must write a type of var if (TR) { OS << Lexer::getCharSourceRangeFromSourceRange(SM, TR->getSourceRange()).str(); } else { SVType.print(OS); } OS << Space << tok::l_brace << NewLine; // Add an initialization OS << tok::kw_return << Space << Init.str() << NewLine; OS << tok::r_brace; // Replace initializer to computed property auto ReplaceStartLoc = Binding->getLoc(); auto ReplaceEndLoc = Binding->getSourceRange().End; auto ReplaceRange = SourceRange(ReplaceStartLoc, ReplaceEndLoc); auto ReplaceCharSourceRange = Lexer::getCharSourceRangeFromSourceRange(SM, ReplaceRange); EditConsumer.accept(SM, ReplaceCharSourceRange, DeclBuffer.str()); return false; // success } namespace asyncrefactorings { // TODO: Should probably split the refactorings into separate files /// Whether the given type is (or conforms to) the stdlib Error type bool isErrorType(Type Ty, ModuleDecl *MD) { if (!Ty) return false; return !MD->conformsToProtocol(Ty, Ty->getASTContext().getErrorDecl()) .isInvalid(); } // The single Decl* subject of a switch statement, or nullptr if none Decl *singleSwitchSubject(const SwitchStmt *Switch) { if (auto *DRE = dyn_cast(Switch->getSubjectExpr())) return DRE->getDecl(); return nullptr; } /// A more aggressive variant of \c Expr::getReferencedDecl that also looks /// through autoclosures created to pass the \c self parameter to a member funcs ValueDecl *getReferencedDecl(const Expr *Fn) { Fn = Fn->getSemanticsProvidingExpr(); if (auto *DRE = dyn_cast(Fn)) return DRE->getDecl(); if (auto ApplyE = dyn_cast(Fn)) return getReferencedDecl(ApplyE->getFn()); if (auto *ACE = dyn_cast(Fn)) { if (auto *Unwrapped = ACE->getUnwrappedCurryThunkExpr()) return getReferencedDecl(Unwrapped); } return nullptr; } FuncDecl *getUnderlyingFunc(const Expr *Fn) { return dyn_cast_or_null(getReferencedDecl(Fn)); } /// Find the outermost call of the given location CallExpr *findOuterCall(const ResolvedCursorInfo &CursorInfo) { auto IncludeInContext = [](ASTNode N) { if (auto *E = N.dyn_cast()) return !E->isImplicit(); return false; }; // TODO: Bit pointless using the "ContextFinder" here. Ideally we would have // already generated a slice of the AST for anything that contains // the cursor location ContextFinder Finder(*CursorInfo.getSourceFile(), CursorInfo.getLoc(), IncludeInContext); Finder.resolve(); auto Contexts = Finder.getContexts(); if (Contexts.empty()) return nullptr; CallExpr *CE = dyn_cast(Contexts[0].get()); if (!CE) return nullptr; SourceManager &SM = CursorInfo.getSourceFile()->getASTContext().SourceMgr; if (!SM.rangeContains(CE->getFn()->getSourceRange(), CursorInfo.getLoc())) return nullptr; return CE; } /// Find the function matching the given location if it is not an accessor and /// either has a body or is a member of a protocol FuncDecl *findFunction(const ResolvedCursorInfo &CursorInfo) { auto IncludeInContext = [](ASTNode N) { if (auto *D = N.dyn_cast()) return !D->isImplicit(); return false; }; ContextFinder Finder(*CursorInfo.getSourceFile(), CursorInfo.getLoc(), IncludeInContext); Finder.resolve(); auto Contexts = Finder.getContexts(); if (Contexts.empty()) return nullptr; if (Contexts.back().isDecl(DeclKind::Param)) Contexts = Contexts.drop_back(); auto *FD = dyn_cast_or_null(Contexts.back().get()); if (!FD || isa(FD)) return nullptr; auto *Body = FD->getBody(); if (!Body && !isa(FD->getDeclContext())) return nullptr; SourceManager &SM = CursorInfo.getSourceFile()->getASTContext().SourceMgr; SourceLoc DeclEnd = Body ? Body->getLBraceLoc() : FD->getEndLoc(); if (!SM.rangeContains(SourceRange(FD->getStartLoc(), DeclEnd), CursorInfo.getLoc())) return nullptr; return FD; } FuncDecl *isOperator(const BinaryExpr *BE) { auto *AE = dyn_cast(BE->getFn()); if (AE) { auto *Callee = AE->getCalledValue(); if (Callee && Callee->isOperator() && isa(Callee)) return cast(Callee); } return nullptr; } /// Describes the expressions to be kept from a call to the handler in a /// function that has (or will have ) and async alternative. Eg. /// ``` /// func toBeAsync(completion: (String?, Error?) -> Void) { /// ... /// completion("something", nil) // Result = ["something"], IsError = false /// ... /// completion(nil, MyError.Bad) // Result = [MyError.Bad], IsError = true /// } class HandlerResult { SmallVector Args; bool IsError = false; public: HandlerResult() {} HandlerResult(ArrayRef ArgsRef) : Args(ArgsRef.begin(), ArgsRef.end()) {} HandlerResult(Argument Arg, bool IsError) : IsError(IsError) { Args.push_back(Arg); } bool isError() { return IsError; } ArrayRef args() { return Args; } }; /// The type of the handler, ie. whether it takes regular parameters or a /// single parameter of `Result` type. enum class HandlerType { INVALID, PARAMS, RESULT }; /// A single return type of a refactored async function. If the async function /// returns a tuple, each element of the tuple (represented by a \c /// LabeledReturnType) might have a label, otherwise the \p Label is empty. struct LabeledReturnType { Identifier Label; swift::Type Ty; LabeledReturnType(Identifier Label, swift::Type Ty) : Label(Label), Ty(Ty) {} }; /// Given a function with an async alternative (or one that *could* have an /// async alternative), stores information about the completion handler. /// The completion handler can be either a variable (which includes a parameter) /// or a function struct AsyncHandlerDesc { PointerUnion Handler = nullptr; HandlerType Type = HandlerType::INVALID; bool HasError = false; static AsyncHandlerDesc get(const ValueDecl *Handler, bool RequireName) { AsyncHandlerDesc HandlerDesc; if (auto Var = dyn_cast(Handler)) { HandlerDesc.Handler = Var; } else if (auto Func = dyn_cast(Handler)) { HandlerDesc.Handler = Func; } else { // The handler must be a variable or function return AsyncHandlerDesc(); } // Callback must have a completion-like name if (RequireName && !isCompletionHandlerParamName(HandlerDesc.getNameStr())) return AsyncHandlerDesc(); // Callback must be a function type and return void. Doesn't need to have // any parameters - may just be a "I'm done" callback auto *HandlerTy = HandlerDesc.getType()->getAs(); if (!HandlerTy || !HandlerTy->getResult()->isVoid()) return AsyncHandlerDesc(); // Find the type of result in the handler (eg. whether it's a Result<...>, // just parameters, or nothing). auto HandlerParams = HandlerTy->getParams(); if (HandlerParams.size() == 1) { auto ParamTy = HandlerParams.back().getPlainType()->getAs(); if (ParamTy && ParamTy->isResult()) { auto GenericArgs = ParamTy->getGenericArgs(); assert(GenericArgs.size() == 2 && "Result should have two params"); HandlerDesc.Type = HandlerType::RESULT; HandlerDesc.HasError = !GenericArgs.back()->isUninhabited(); } } if (HandlerDesc.Type != HandlerType::RESULT) { // Only handle non-result parameters for (auto &Param : HandlerParams) { if (Param.getPlainType() && Param.getPlainType()->isResult()) return AsyncHandlerDesc(); } HandlerDesc.Type = HandlerType::PARAMS; if (!HandlerParams.empty()) { auto LastParamTy = HandlerParams.back().getParameterType(); HandlerDesc.HasError = isErrorType(LastParamTy->getOptionalObjectType(), Handler->getModuleContext()); } } return HandlerDesc; } bool isValid() const { return Type != HandlerType::INVALID; } /// Return the declaration of the completion handler as a \c ValueDecl. /// In practice, the handler will always be a \c VarDecl or \c /// AbstractFunctionDecl. /// \c getNameStr and \c getType provide access functions that are available /// for both variables and functions, but not on \c ValueDecls. const ValueDecl *getHandler() const { if (!Handler) { return nullptr; } if (auto Var = Handler.dyn_cast()) { return Var; } else if (auto Func = Handler.dyn_cast()) { return Func; } else { llvm_unreachable("Unknown handler type"); } } /// Return the name of the completion handler. If it is a variable, the /// variable name, if it's a function, the function base name. StringRef getNameStr() const { if (auto Var = Handler.dyn_cast()) { return Var->getNameStr(); } else if (auto Func = Handler.dyn_cast()) { return Func->getNameStr(); } else { llvm_unreachable("Unknown handler type"); } } HandlerType getHandlerType() const { return Type; } /// Get the type of the completion handler. swift::Type getType() const { if (auto Var = Handler.dyn_cast()) { return Var->getType(); } else if (auto Func = Handler.dyn_cast()) { auto Type = Func->getInterfaceType(); // Undo the self curry thunk if we are referencing a member function. if (Func->hasImplicitSelfDecl()) { assert(Type->is()); Type = Type->getAs()->getResult(); } return Type; } else { llvm_unreachable("Unknown handler type"); } } ArrayRef params() const { auto Ty = getType()->getAs(); assert(Ty && "Type must be a function type"); return Ty->getParams(); } /// Retrieve the parameters relevant to a successful return from the /// completion handler. This drops the Error parameter if present. ArrayRef getSuccessParams() const { if (HasError && Type == HandlerType::PARAMS) return params().drop_back(); return params(); } /// If the completion handler has an Error parameter, return it. Optional getErrorParam() const { if (HasError && Type == HandlerType::PARAMS) return params().back(); return None; } /// Get the type of the error that will be thrown by the \c async method or \c /// None if the completion handler doesn't accept an error parameter. /// This may be more specialized than the generic 'Error' type if the /// completion handler of the converted function takes a more specialized /// error type. Optional getErrorType() const { if (HasError) { switch (Type) { case HandlerType::INVALID: return None; case HandlerType::PARAMS: // The last parameter of the completion handler is the error param return params().back().getPlainType()->lookThroughSingleOptionalType(); case HandlerType::RESULT: assert( params().size() == 1 && "Result handler should have the Result type as the only parameter"); auto ResultType = params().back().getPlainType()->getAs(); auto GenericArgs = ResultType->getGenericArgs(); assert(GenericArgs.size() == 2 && "Result should have two params"); // The second (last) generic parameter of the Result type is the error // type. return GenericArgs.back(); } } else { return None; } } /// The `CallExpr` if the given node is a call to the `Handler` CallExpr *getAsHandlerCall(ASTNode Node) const { if (!isValid()) return nullptr; if (auto E = Node.dyn_cast()) { if (auto *CE = dyn_cast(E->getSemanticsProvidingExpr())) { if (CE->getFn()->getReferencedDecl().getDecl() == getHandler()) { return CE; } } } return nullptr; } /// Returns \c true if the call to the completion handler contains possibly /// non-nil values for both the success and error parameters, e.g. /// \code /// completion(result, error) /// \endcode /// This can only happen if the completion handler is a params handler. bool isAmbiguousCallToParamHandler(const CallExpr *CE) const { if (!HasError || Type != HandlerType::PARAMS) { // Only param handlers with an error can pass both an error AND a result. return false; } auto Args = CE->getArgs()->getArgExprs(); if (!isa(Args.back())) { // We've got an error parameter. If any of the success params is not nil, // the call is ambiguous. for (auto &Arg : Args.drop_back()) { if (!isa(Arg)) { return true; } } } return false; } /// Given a call to the `Handler`, extract the expressions to be returned or /// thrown, taking care to remove the `.success`/`.failure` if it's a /// `RESULT` handler type. /// If the call is ambiguous (contains potentially non-nil arguments to both /// the result and the error parameters), the \p ReturnErrorArgsIfAmbiguous /// determines whether the success or error parameters are passed. HandlerResult extractResultArgs(const CallExpr *CE, bool ReturnErrorArgsIfAmbiguous) const { auto *ArgList = CE->getArgs(); SmallVector Scratch(ArgList->begin(), ArgList->end()); auto Args = llvm::makeArrayRef(Scratch); if (Type == HandlerType::PARAMS) { bool IsErrorResult; if (isAmbiguousCallToParamHandler(CE)) { IsErrorResult = ReturnErrorArgsIfAmbiguous; } else { // If there's an error parameter and the user isn't passing nil to it, // assume this is the error path. IsErrorResult = (HasError && !isa(Args.back().getExpr())); } if (IsErrorResult) return HandlerResult(Args.back(), true); // We can drop the args altogether if they're just Void. if (willAsyncReturnVoid()) return HandlerResult(); return HandlerResult(HasError ? Args.drop_back() : Args); } else if (Type == HandlerType::RESULT) { if (Args.size() != 1) return HandlerResult(Args); auto *ResultCE = dyn_cast(Args[0].getExpr()); if (!ResultCE) return HandlerResult(Args); auto *DSC = dyn_cast(ResultCE->getFn()); if (!DSC) return HandlerResult(Args); auto *D = dyn_cast( DSC->getFn()->getReferencedDecl().getDecl()); if (!D) return HandlerResult(Args); auto ResultArgList = ResultCE->getArgs(); auto isFailure = D->getNameStr() == StringRef("failure"); // We can drop the arg altogether if it's just Void. if (!isFailure && willAsyncReturnVoid()) return HandlerResult(); // Otherwise the arg gets the .success() or .failure() call dropped. return HandlerResult(ResultArgList->get(0), isFailure); } llvm_unreachable("Unhandled result type"); } // Convert the type of a success parameter in the completion handler function // to a return type suitable for an async function. If there is an error // parameter present e.g (T?, Error?) -> Void, this unwraps a level of // optionality from T?. If this is a Result type, returns the success // type T. swift::Type getSuccessParamAsyncReturnType(swift::Type Ty) const { switch (Type) { case HandlerType::PARAMS: { // If there's an Error parameter in the handler, the success branch can // be unwrapped. if (HasError) Ty = Ty->lookThroughSingleOptionalType(); return Ty; } case HandlerType::RESULT: { // Result maps to T. return Ty->castTo()->getGenericArgs()[0]; } case HandlerType::INVALID: llvm_unreachable("Invalid handler type"); } } /// If the async function returns a tuple, the label of the \p Index -th /// element in the returned tuple. If the function doesn't return a tuple or /// the element is unlabeled, an empty identifier is returned. Identifier getAsyncReturnTypeLabel(size_t Index) const { assert(Index < getSuccessParams().size()); if (getSuccessParams().size() <= 1) { // There can't be any labels if the async function doesn't return a tuple. return Identifier(); } else { return getSuccessParams()[Index].getInternalLabel(); } } /// Gets the return value types for the async equivalent of this handler. ArrayRef getAsyncReturnTypes(SmallVectorImpl &Scratch) const { for (size_t I = 0; I < getSuccessParams().size(); ++I) { auto Ty = getSuccessParams()[I].getParameterType(); Scratch.emplace_back(getAsyncReturnTypeLabel(I), getSuccessParamAsyncReturnType(Ty)); } return Scratch; } /// Whether the async equivalent of this handler returns Void. bool willAsyncReturnVoid() const { // If all of the success params will be converted to Void return types, // this will be a Void async function. return llvm::all_of(getSuccessParams(), [&](auto ¶m) { auto Ty = param.getParameterType(); return getSuccessParamAsyncReturnType(Ty)->isVoid(); }); } // TODO: If we have an async alternative we should check its result types // for whether to unwrap or not bool shouldUnwrap(swift::Type Ty) const { return HasError && Ty->isOptional(); } }; /// Given a completion handler that is part of a function signature, stores /// information about that completion handler and its index within the function /// declaration. struct AsyncHandlerParamDesc : public AsyncHandlerDesc { /// Enum to represent the position of the completion handler param within /// the parameter list. Given `(A, B, C, D)`: /// - A is `First` /// - B and C are `Middle` /// - D is `Last` /// The position is `Only` if there's a single parameter that is the /// completion handler and `None` if there is no handler. enum class Position { First, Middle, Last, Only, None }; /// The function the completion handler is a parameter of. const FuncDecl *Func = nullptr; /// The index of the completion handler in the function that declares it. unsigned Index = 0; /// The async alternative, if one is found. const AbstractFunctionDecl *Alternative = nullptr; AsyncHandlerParamDesc() : AsyncHandlerDesc() {} AsyncHandlerParamDesc(const AsyncHandlerDesc &Handler, const FuncDecl *Func, unsigned Index, const AbstractFunctionDecl *Alternative) : AsyncHandlerDesc(Handler), Func(Func), Index(Index), Alternative(Alternative) {} static AsyncHandlerParamDesc find(const FuncDecl *FD, bool RequireAttributeOrName) { if (!FD || FD->hasAsync() || FD->hasThrows() || !FD->getResultInterfaceType()->isVoid()) return AsyncHandlerParamDesc(); const auto *Alternative = FD->getAsyncAlternative(); Optional Index = FD->findPotentialCompletionHandlerParam(Alternative); if (!Index) return AsyncHandlerParamDesc(); bool RequireName = RequireAttributeOrName && !Alternative; return AsyncHandlerParamDesc( AsyncHandlerDesc::get(FD->getParameters()->get(*Index), RequireName), FD, *Index, Alternative); } /// Build an @available attribute with the name of the async alternative as /// the \c renamed argument, followed by a newline. SmallString<128> buildRenamedAttribute() const { SmallString<128> AvailabilityAttr; llvm::raw_svector_ostream OS(AvailabilityAttr); // If there's an alternative then there must already be an attribute, // don't add another. if (!isValid() || Alternative) return AvailabilityAttr; DeclName Name = Func->getName(); OS << "@available(*, renamed: \"" << Name.getBaseName() << "("; ArrayRef ArgNames = Name.getArgumentNames(); for (size_t I = 0; I < ArgNames.size(); ++I) { if (I != Index) { OS << ArgNames[I] << tok::colon; } } OS << ")\")\n"; return AvailabilityAttr; } /// Retrieves the parameter decl for the completion handler parameter, or /// \c nullptr if no valid completion parameter is present. const ParamDecl *getHandlerParam() const { if (!isValid()) return nullptr; return cast(getHandler()); } /// See \c Position Position handlerParamPosition() const { if (!isValid()) return Position::None; const auto *Params = Func->getParameters(); if (Params->size() == 1) return Position::Only; if (Index == 0) return Position::First; if (Index == Params->size() - 1) return Position::Last; return Position::Middle; } bool operator==(const AsyncHandlerParamDesc &Other) const { return Handler == Other.Handler && Type == Other.Type && HasError == Other.HasError && Index == Other.Index; } bool alternativeIsAccessor() const { return isa_and_nonnull(Alternative); } }; /// The type of a condition in a conditional statement. enum class ConditionType { NIL, // == nil NOT_NIL, // != nil IS_TRUE, // if b IS_FALSE, // if !b SUCCESS_PATTERN, // case .success FAILURE_PATTEN // case .failure }; /// Indicates whether a condition describes a success or failure path. For /// example, a check for whether an error parameter is present is a failure /// path. A check for a nil error parameter is a success path. This is distinct /// from ConditionType, as it relies on contextual information about what values /// need to be checked for success or failure. enum class ConditionPath { SUCCESS, FAILURE }; static ConditionPath flippedConditionPath(ConditionPath Path) { switch (Path) { case ConditionPath::SUCCESS: return ConditionPath::FAILURE; case ConditionPath::FAILURE: return ConditionPath::SUCCESS; } llvm_unreachable("Unhandled case in switch!"); } /// Finds the `Subject` being compared to in various conditions. Also finds any /// pattern that may have a bound name. struct CallbackCondition { Optional Type; const Decl *Subject = nullptr; const Pattern *BindPattern = nullptr; /// Initializes a `CallbackCondition` with a `!=` or `==` comparison of /// an `Optional` typed `Subject` to `nil`, or a `Bool` typed `Subject` to a /// boolean literal, ie. /// - ` != nil` /// - ` == nil` /// - ` != true` /// - ` == false` CallbackCondition(const BinaryExpr *BE, const FuncDecl *Operator) { bool FoundNil = false; BooleanLiteralExpr *FoundBool = nullptr; bool DidUnwrapOptional = false; for (auto *Operand : {BE->getLHS(), BE->getRHS()}) { Operand = Operand->getSemanticsProvidingExpr(); if (auto *IIOE = dyn_cast(Operand)) { Operand = IIOE->getSubExpr()->getSemanticsProvidingExpr(); DidUnwrapOptional = true; } if (isa(Operand)) { FoundNil = true; } else if (auto *BLE = dyn_cast(Operand)) { FoundBool = BLE; } else if (auto *DRE = dyn_cast(Operand)) { Subject = DRE->getDecl(); } } if (!Subject) return; if (FoundNil) { if (Operator->getBaseName() == "==") { Type = ConditionType::NIL; } else if (Operator->getBaseName() == "!=") { Type = ConditionType::NOT_NIL; } } else if (FoundBool) { if (Operator->getBaseName() == "==") { Type = FoundBool->getValue() ? ConditionType::IS_TRUE : ConditionType::IS_FALSE; } else if (Operator->getBaseName() == "!=" && !DidUnwrapOptional) { // Note that we don't consider this case if we unwrapped an optional, // as e.g optBool != false is a check for true *or* nil. Type = FoundBool->getValue() ? ConditionType::IS_FALSE : ConditionType::IS_TRUE; } } } /// A bool condition expression. explicit CallbackCondition(const Expr *E) { // FIXME: Sema should produce ErrorType. if (!E->getType() || !E->getType()->isBool()) return; auto CondType = ConditionType::IS_TRUE; E = E->getSemanticsProvidingExpr(); // If we have a prefix negation operator, this is a check for false. if (auto *PrefixOp = dyn_cast(E)) { auto *Callee = PrefixOp->getCalledValue(); if (Callee && Callee->isOperator() && Callee->getBaseName() == "!") { CondType = ConditionType::IS_FALSE; E = PrefixOp->getOperand()->getSemanticsProvidingExpr(); } } auto *DRE = dyn_cast(E); if (!DRE) return; Subject = DRE->getDecl(); Type = CondType; } /// Initializes a `CallbackCondition` with binding of an `Optional` or /// `Result` typed `Subject`, ie. /// - `let bind = ` /// - `case .success(let bind) = ` /// - `case .failure(let bind) = ` /// - `let bind = try? .get()` CallbackCondition(const Pattern *P, const Expr *Init) { Init = Init->getSemanticsProvidingExpr(); P = P->getSemanticsProvidingPattern(); if (auto *DRE = dyn_cast(Init)) { if (auto *OSP = dyn_cast(P)) { // `let bind = ` Type = ConditionType::NOT_NIL; Subject = DRE->getDecl(); BindPattern = OSP->getSubPattern(); } else if (auto *EEP = dyn_cast(P)) { // `case .(let ) = ` initFromEnumPattern(DRE->getDecl(), EEP); } } else if (auto *OTE = dyn_cast(Init)) { // `let bind = try? .get()` if (auto *OSP = dyn_cast(P)) initFromOptionalTry(OSP->getSubPattern(), OTE); } } /// Initializes a `CallbackCondtion` from a case statement inside a switch /// on `Subject` with `Result` type, ie. /// ``` /// switch { /// case .success(let bind): /// case .failure(let bind): /// } /// ``` CallbackCondition(const Decl *Subject, const CaseLabelItem *CaseItem) { if (auto *EEP = dyn_cast( CaseItem->getPattern()->getSemanticsProvidingPattern())) { // `case .(let )` initFromEnumPattern(Subject, EEP); } } bool isValid() const { return Type.has_value(); } private: void initFromEnumPattern(const Decl *D, const EnumElementPattern *EEP) { if (auto *EED = EEP->getElementDecl()) { auto eedTy = EED->getParentEnum()->getDeclaredType(); if (!eedTy || !eedTy->isResult()) return; if (EED->getNameStr() == StringRef("failure")) { Type = ConditionType::FAILURE_PATTEN; } else { Type = ConditionType::SUCCESS_PATTERN; } Subject = D; BindPattern = EEP->getSubPattern(); } } void initFromOptionalTry(const class Pattern *P, const OptionalTryExpr *OTE) { auto *ICE = dyn_cast(OTE->getSubExpr()); if (!ICE) return; auto *CE = dyn_cast(ICE->getSyntacticSubExpr()); if (!CE) return; auto *DSC = dyn_cast(CE->getFn()); if (!DSC) return; auto *BaseDRE = dyn_cast(DSC->getBase()); if (!BaseDRE->getType() || !BaseDRE->getType()->isResult()) return; auto *FnDRE = dyn_cast(DSC->getFn()); if (!FnDRE) return; auto *FD = dyn_cast(FnDRE->getDecl()); if (!FD || FD->getNameStr() != StringRef("get")) return; Type = ConditionType::NOT_NIL; Subject = BaseDRE->getDecl(); BindPattern = P; } }; /// A CallbackCondition with additional semantic information about whether it /// is for a success path or failure path. struct ClassifiedCondition : public CallbackCondition { ConditionPath Path; /// Whether this represents an Obj-C style boolean flag check for success. bool IsObjCStyleFlagCheck; explicit ClassifiedCondition(CallbackCondition Cond, ConditionPath Path, bool IsObjCStyleFlagCheck) : CallbackCondition(Cond), Path(Path), IsObjCStyleFlagCheck(IsObjCStyleFlagCheck) {} }; /// A wrapper for a map of parameter decls to their classified conditions, or /// \c None if they are not present in any conditions. struct ClassifiedCallbackConditions final : llvm::MapVector { Optional lookup(const Decl *D) const { auto Res = find(D); if (Res == end()) return None; return Res->second; } }; /// A list of nodes to print, along with a list of locations that may have /// preceding comments attached, which also need printing. For example: /// /// \code /// if .random() { /// // a /// print("hello") /// // b /// } /// \endcode /// /// To print out the contents of the if statement body, we'll include the AST /// node for the \c print call. This will also include the preceding comment /// \c a, but won't include the comment \c b. To ensure the comment \c b gets /// printed, the SourceLoc for the closing brace \c } is added as a possible /// comment loc. class NodesToPrint { SmallVector Nodes; SmallVector PossibleCommentLocs; public: NodesToPrint() {} NodesToPrint(ArrayRef Nodes, ArrayRef PossibleCommentLocs) : Nodes(Nodes.begin(), Nodes.end()), PossibleCommentLocs(PossibleCommentLocs.begin(), PossibleCommentLocs.end()) {} ArrayRef getNodes() const { return Nodes; } ArrayRef getPossibleCommentLocs() const { return PossibleCommentLocs; } /// Add an AST node to print. void addNode(ASTNode Node) { // Note we skip vars as they'll be printed as a part of their // PatternBindingDecl. if (!Node.isDecl(DeclKind::Var)) Nodes.push_back(Node); } /// Add a SourceLoc which may have a preceding comment attached. If so, the /// comment will be printed out at the appropriate location. void addPossibleCommentLoc(SourceLoc Loc) { if (Loc.isValid()) PossibleCommentLocs.push_back(Loc); } /// Add all the nodes in the brace statement to the list of nodes to print. /// This should be preferred over adding the nodes manually as it picks up the /// end location of the brace statement as a possible comment loc, ensuring /// that we print any trailing comments in the brace statement. void addNodesInBraceStmt(BraceStmt *Brace) { for (auto Node : Brace->getElements()) addNode(Node); // Ignore the end locations of implicit braces, as they're likely bogus. // e.g for a case statement, the r-brace loc points to the last token of the // last node in the body. if (!Brace->isImplicit()) addPossibleCommentLoc(Brace->getRBraceLoc()); } /// Add the nodes and comment locs from another NodesToPrint. void addNodes(NodesToPrint OtherNodes) { Nodes.append(OtherNodes.Nodes.begin(), OtherNodes.Nodes.end()); PossibleCommentLocs.append(OtherNodes.PossibleCommentLocs.begin(), OtherNodes.PossibleCommentLocs.end()); } /// Whether the last recorded node is an explicit return or break statement. bool hasTrailingReturnOrBreak() const { if (Nodes.empty()) return false; return (Nodes.back().isStmt(StmtKind::Return) || Nodes.back().isStmt(StmtKind::Break)) && !Nodes.back().isImplicit(); } /// If the last recorded node is an explicit return or break statement that /// can be safely dropped, drop it from the list. void dropTrailingReturnOrBreakIfPossible() { if (!hasTrailingReturnOrBreak()) return; auto *Node = Nodes.back().get(); // If this is a return statement with return expression, let's preserve it. if (auto *RS = dyn_cast(Node)) { if (RS->hasResult()) return; } // Remove the node from the list, but make sure to add it as a possible // comment loc to preserve any of its attached comments. Nodes.pop_back(); addPossibleCommentLoc(Node->getStartLoc()); } /// Returns a list of nodes to print in a brace statement. This picks up the /// end location of the brace statement as a possible comment loc, ensuring /// that we print any trailing comments in the brace statement. static NodesToPrint inBraceStmt(BraceStmt *stmt) { NodesToPrint Nodes; Nodes.addNodesInBraceStmt(stmt); return Nodes; } }; /// The statements within the closure of call to a function taking a callback /// are split into a `SuccessBlock` and `ErrorBlock` (`ClassifiedBlocks`). /// This class stores the nodes for each block, as well as a mapping of /// decls to any patterns they are used in. class ClassifiedBlock { NodesToPrint Nodes; // A mapping of closure params to a list of patterns that bind them. using ParamPatternBindingsMap = llvm::MapVector>; ParamPatternBindingsMap ParamPatternBindings; public: const NodesToPrint &nodesToPrint() const { return Nodes; } /// Attempt to retrieve an existing bound name for a closure parameter, or /// an empty string if there's no suitable existing binding. StringRef boundName(const Decl *D) const { // Adopt the same name as the representative single pattern, if it only // binds a single var. if (auto *P = getSinglePatternFor(D)) { if (P->getSingleVar()) return P->getBoundName().str(); } return StringRef(); } /// Checks whether a closure parameter can be represented by a single pattern /// that binds it. If the param is only bound by a single pattern, that will /// be returned. If there's a pattern with a single var that binds it, that /// will be returned, preferring a 'let' pattern to prefer out of line /// printing of 'var' patterns. const Pattern *getSinglePatternFor(const Decl *D) const { auto Iter = ParamPatternBindings.find(D); if (Iter == ParamPatternBindings.end()) return nullptr; const auto &Patterns = Iter->second; if (Patterns.empty()) return nullptr; if (Patterns.size() == 1) return Patterns[0]; // If we have multiple patterns, search for the best single var pattern to // use, preferring a 'let' binding. const Pattern *FirstSingleVar = nullptr; for (auto *P : Patterns) { if (!P->getSingleVar()) continue; if (!P->hasAnyMutableBindings()) return P; if (!FirstSingleVar) FirstSingleVar = P; } return FirstSingleVar; } /// Retrieve any bound vars that are effectively aliases of a given closure /// parameter. llvm::SmallDenseSet getAliasesFor(const Decl *D) const { auto Iter = ParamPatternBindings.find(D); if (Iter == ParamPatternBindings.end()) return {}; llvm::SmallDenseSet Aliases; // The single pattern that we replace the decl with is always an alias. if (auto *P = getSinglePatternFor(D)) { if (auto *SingleVar = P->getSingleVar()) Aliases.insert(SingleVar); } // Any other let bindings we have are also aliases. for (auto *P : Iter->second) { if (auto *SingleVar = P->getSingleVar()) { if (!P->hasAnyMutableBindings()) Aliases.insert(SingleVar); } } return Aliases; } const ParamPatternBindingsMap ¶mPatternBindings() const { return ParamPatternBindings; } void addNodesInBraceStmt(BraceStmt *Brace) { Nodes.addNodesInBraceStmt(Brace); } void addPossibleCommentLoc(SourceLoc Loc) { Nodes.addPossibleCommentLoc(Loc); } void addAllNodes(NodesToPrint OtherNodes) { Nodes.addNodes(std::move(OtherNodes)); } void addNode(ASTNode Node) { Nodes.addNode(Node); } void addBinding(const ClassifiedCondition &FromCondition) { auto *P = FromCondition.BindPattern; if (!P) return; // Patterns that don't bind anything aren't interesting. SmallVector Vars; P->collectVariables(Vars); if (Vars.empty()) return; ParamPatternBindings[FromCondition.Subject].push_back(P); } void addAllBindings(const ClassifiedCallbackConditions &FromConditions) { for (auto &Entry : FromConditions) addBinding(Entry.second); } }; /// The type of block rewritten code may be placed in. enum class BlockKind { SUCCESS, ERROR, FALLBACK }; /// A completion handler function parameter that is known to be a Bool flag /// indicating success or failure. struct KnownBoolFlagParam { const ParamDecl *Param; bool IsSuccessFlag; }; /// A set of parameters for a completion callback closure. class ClosureCallbackParams final { const AsyncHandlerParamDesc &HandlerDesc; ArrayRef AllParams; llvm::SetVector SuccessParams; const ParamDecl *ErrParam = nullptr; Optional BoolFlagParam; public: ClosureCallbackParams(const AsyncHandlerParamDesc &HandlerDesc, const ClosureExpr *Closure) : HandlerDesc(HandlerDesc), AllParams(Closure->getParameters()->getArray()) { assert(AllParams.size() == HandlerDesc.params().size()); assert(HandlerDesc.Type != HandlerType::RESULT || AllParams.size() == 1); SuccessParams.insert(AllParams.begin(), AllParams.end()); if (HandlerDesc.HasError && HandlerDesc.Type == HandlerType::PARAMS) ErrParam = SuccessParams.pop_back_val(); // Check to see if we have a known bool flag parameter. if (auto *AsyncAlt = HandlerDesc.Func->getAsyncAlternative()) { if (auto Conv = AsyncAlt->getForeignAsyncConvention()) { auto FlagIdx = Conv->completionHandlerFlagParamIndex(); if (FlagIdx && *FlagIdx >= 0 && *FlagIdx < AllParams.size()) { auto IsSuccessFlag = Conv->completionHandlerFlagIsErrorOnZero(); BoolFlagParam = {AllParams[*FlagIdx], IsSuccessFlag}; } } } } /// Whether the closure has a particular parameter. bool hasParam(const ParamDecl *Param) const { return Param == ErrParam || SuccessParams.contains(Param); } /// Whether \p Param is a success param. bool isSuccessParam(const ParamDecl *Param) const { return SuccessParams.contains(Param); } /// Whether \p Param is a closure parameter that may be unwrapped. This /// includes optional parameters as well as \c Result parameters that may be /// unwrapped through e.g 'try? res.get()'. bool isUnwrappableParam(const ParamDecl *Param) const { if (!hasParam(Param)) return false; if (getResultParam() == Param) return true; return HandlerDesc.shouldUnwrap(Param->getType()); } /// Whether \p Param is the known Bool parameter that indicates success or /// failure. bool isKnownBoolFlagParam(const ParamDecl *Param) const { if (auto BoolFlag = getKnownBoolFlagParam()) return BoolFlag->Param == Param; return false; } /// Whether \p Param is a closure parameter that has a binding available in /// the async variant of the call for a particular \p Block. bool hasBinding(const ParamDecl *Param, BlockKind Block) const { switch (Block) { case BlockKind::SUCCESS: // Known bool flags get dropped from the imported async variant. if (isKnownBoolFlagParam(Param)) return false; return isSuccessParam(Param); case BlockKind::ERROR: return Param == ErrParam; case BlockKind::FALLBACK: // We generally want to bind everything in the fallback case. return hasParam(Param); } llvm_unreachable("Unhandled case in switch"); } /// Retrieve the parameters to bind in a given \p Block. TinyPtrVector getParamsToBind(BlockKind Block) { TinyPtrVector Result; for (auto *Param : AllParams) { if (hasBinding(Param, Block)) Result.push_back(Param); } return Result; } /// If there is a known Bool flag parameter indicating success or failure, /// returns it, \c None otherwise. Optional getKnownBoolFlagParam() const { return BoolFlagParam; } /// All the parameters of the closure passed as the completion handler. ArrayRef getAllParams() const { return AllParams; } /// The success parameters of the closure passed as the completion handler. /// Note this includes a \c Result parameter. ArrayRef getSuccessParams() const { return SuccessParams.getArrayRef(); } /// The error parameter of the closure passed as the completion handler, or /// \c nullptr if there is no error parameter. const ParamDecl *getErrParam() const { return ErrParam; } /// If the closure has a single \c Result parameter, returns it, \c nullptr /// otherwise. const ParamDecl *getResultParam() const { return HandlerDesc.Type == HandlerType::RESULT ? SuccessParams[0] : nullptr; } }; /// Whether or not the given statement starts a new scope. Note that most /// statements are handled by the \c BraceStmt check. The others listed are /// a somewhat special case since they can also declare variables in their /// condition. static bool startsNewScope(Stmt *S) { switch (S->getKind()) { case StmtKind::Brace: case StmtKind::If: case StmtKind::While: case StmtKind::ForEach: case StmtKind::Case: return true; default: return false; } } struct ClassifiedBlocks { ClassifiedBlock SuccessBlock; ClassifiedBlock ErrorBlock; }; /// Classifer of callback closure statements that that have either multiple /// non-Result parameters or a single Result parameter and return Void. /// /// It performs a (possibly incorrect) best effort and may give up in certain /// cases. Aims to cover the idiomatic cases of either having no error /// parameter at all, or having success/error code wrapped in ifs/guards/switch /// using either pattern binding or nil checks. /// /// Code outside any clear conditions is assumed to be solely part of the /// success block for now, though some heuristics could be added to classify /// these better in the future. struct CallbackClassifier { /// Updates the success and error block of `Blocks` with nodes and bound /// names from `Body`. Errors are added through `DiagEngine`, possibly /// resulting in partially filled out blocks. static void classifyInto(ClassifiedBlocks &Blocks, const ClosureCallbackParams &Params, llvm::DenseSet &HandledSwitches, DiagnosticEngine &DiagEngine, BraceStmt *Body) { assert(!Body->getElements().empty() && "Cannot classify empty body"); CallbackClassifier Classifier(Blocks, Params, HandledSwitches, DiagEngine); Classifier.classifyNodes(Body->getElements(), Body->getRBraceLoc()); } private: ClassifiedBlocks &Blocks; const ClosureCallbackParams &Params; llvm::DenseSet &HandledSwitches; DiagnosticEngine &DiagEngine; ClassifiedBlock *CurrentBlock; /// This is set to \c true if we're currently classifying on a known condition /// path, where \c CurrentBlock is set to the appropriate block. This lets us /// be more lenient with unhandled conditions as we already know the block /// we're supposed to be in. bool IsKnownConditionPath = false; CallbackClassifier(ClassifiedBlocks &Blocks, const ClosureCallbackParams &Params, llvm::DenseSet &HandledSwitches, DiagnosticEngine &DiagEngine) : Blocks(Blocks), Params(Params), HandledSwitches(HandledSwitches), DiagEngine(DiagEngine), CurrentBlock(&Blocks.SuccessBlock) {} /// Attempt to apply custom classification logic to a given node, returning /// \c true if the node was classified, otherwise \c false. bool tryClassifyNode(ASTNode Node) { auto *Statement = Node.dyn_cast(); if (!Statement) return false; if (auto *IS = dyn_cast(Statement)) { NodesToPrint TempNodes; if (auto *BS = dyn_cast(IS->getThenStmt())) { TempNodes = NodesToPrint::inBraceStmt(BS); } else { TempNodes = NodesToPrint({IS->getThenStmt()}, /*commentLocs*/ {}); } classifyConditional(IS, IS->getCond(), std::move(TempNodes), IS->getElseStmt()); return true; } else if (auto *GS = dyn_cast(Statement)) { classifyConditional(GS, GS->getCond(), NodesToPrint(), GS->getBody()); return true; } else if (auto *SS = dyn_cast(Statement)) { classifySwitch(SS); return true; } else if (auto *RS = dyn_cast(Statement)) { // We can look through an implicit Void return of a SingleValueStmtExpr, // as that's semantically a statement. if (RS->hasResult() && RS->isImplicit()) { auto Ty = RS->getResult()->getType(); if (Ty && Ty->isVoid()) { if (auto *SVE = dyn_cast(RS->getResult())) return tryClassifyNode(SVE->getStmt()); } } } return false; } /// Classify a node, or add the node to the block if it cannot be classified. /// Returns \c true if there was an error. bool classifyNode(ASTNode Node) { auto DidClassify = tryClassifyNode(Node); if (!DidClassify) CurrentBlock->addNode(Node); return DiagEngine.hadAnyError(); } void classifyNodes(ArrayRef Nodes, SourceLoc EndCommentLoc) { for (auto Node : Nodes) { auto HadError = classifyNode(Node); if (HadError) return; } // Make sure to pick up any trailing comments. CurrentBlock->addPossibleCommentLoc(EndCommentLoc); } /// Whether any of the provided ASTNodes have a child expression that force /// unwraps the error parameter. Note that this doesn't walk into new scopes. bool hasForceUnwrappedErrorParam(ArrayRef Nodes) { auto *ErrParam = Params.getErrParam(); if (!ErrParam) return false; class ErrUnwrapFinder : public ASTWalker { const ParamDecl *ErrParam; bool FoundUnwrap = false; public: explicit ErrUnwrapFinder(const ParamDecl *ErrParam) : ErrParam(ErrParam) {} bool foundUnwrap() const { return FoundUnwrap; } PreWalkResult walkToExprPre(Expr *E) override { // Don't walk into ternary conditionals as they may have additional // conditions such as err != nil that make a force unwrap now valid. if (isa(E)) return Action::SkipChildren(E); auto *FVE = dyn_cast(E); if (!FVE) return Action::Continue(E); auto *DRE = dyn_cast(FVE->getSubExpr()); if (!DRE) return Action::Continue(E); if (DRE->getDecl() != ErrParam) return Action::Continue(E); // If we find the node we're looking for, make a note of it, and abort // the walk. FoundUnwrap = true; return Action::Stop(); } PreWalkResult walkToStmtPre(Stmt *S) override { // Don't walk into new explicit scopes, we only want to consider force // unwraps in the immediate conditional body. if (!S->isImplicit() && startsNewScope(S)) return Action::SkipChildren(S); return Action::Continue(S); } PreWalkAction walkToDeclPre(Decl *D) override { // Don't walk into new explicit DeclContexts. return Action::VisitChildrenIf(D->isImplicit() || !isa(D)); } }; for (auto Node : Nodes) { ErrUnwrapFinder walker(ErrParam); Node.walk(walker); if (walker.foundUnwrap()) return true; } return false; } /// Given a callback condition, classify it as a success or failure path. Optional classifyCallbackCondition(const CallbackCondition &Cond, const NodesToPrint &SuccessNodes, Stmt *ElseStmt) { if (!Cond.isValid()) return None; // If the condition involves a refutable pattern, we can't currently handle // it. if (Cond.BindPattern && Cond.BindPattern->isRefutablePattern()) return None; auto *SubjectParam = dyn_cast(Cond.Subject); if (!SubjectParam) return None; // For certain types of condition, they need to be certain kinds of params. auto CondType = *Cond.Type; switch (CondType) { case ConditionType::NOT_NIL: case ConditionType::NIL: if (!Params.isUnwrappableParam(SubjectParam)) return None; break; case ConditionType::IS_TRUE: case ConditionType::IS_FALSE: if (!Params.isSuccessParam(SubjectParam)) return None; break; case ConditionType::SUCCESS_PATTERN: case ConditionType::FAILURE_PATTEN: if (SubjectParam != Params.getResultParam()) return None; break; } // Let's start with a success path, and flip any negative conditions. auto Path = ConditionPath::SUCCESS; // If it's an error param, that's a flip. if (SubjectParam == Params.getErrParam()) Path = flippedConditionPath(Path); // If we have a nil, false, or failure condition, that's a flip. switch (CondType) { case ConditionType::NIL: case ConditionType::IS_FALSE: case ConditionType::FAILURE_PATTEN: Path = flippedConditionPath(Path); break; case ConditionType::IS_TRUE: case ConditionType::NOT_NIL: case ConditionType::SUCCESS_PATTERN: break; } // If we have a bool condition, it could be an Obj-C style flag check, which // we do some extra checking for. Otherwise, we're done. if (CondType != ConditionType::IS_TRUE && CondType != ConditionType::IS_FALSE) { return ClassifiedCondition(Cond, Path, /*ObjCFlagCheck*/ false); } // Check to see if we have a known bool flag parameter that indicates // success or failure. if (auto KnownBoolFlag = Params.getKnownBoolFlagParam()) { if (KnownBoolFlag->Param != SubjectParam) return None; // The path may need to be flipped depending on whether the flag indicates // success. if (!KnownBoolFlag->IsSuccessFlag) Path = flippedConditionPath(Path); return ClassifiedCondition(Cond, Path, /*ObjCStyleFlagCheck*/ true); } // If we've reached here, we have a bool flag check that isn't specified in // the async convention. We apply a heuristic to see if the error param is // force unwrapped in the conditional body. In that case, the user is // expecting it to be the error path, and it's more likely than not that the // flag value conveys no more useful information in the error block. // First check the success block. auto FoundInSuccessBlock = hasForceUnwrappedErrorParam(SuccessNodes.getNodes()); // Then check the else block if we have it. if (ASTNode ElseNode = ElseStmt) { // Unwrap the BraceStmt of the else clause if needed. This is needed as // we won't walk into BraceStmts by default as they introduce new // scopes. ArrayRef Nodes; if (auto *BS = dyn_cast(ElseStmt)) { Nodes = BS->getElements(); } else { Nodes = llvm::makeArrayRef(ElseNode); } if (hasForceUnwrappedErrorParam(Nodes)) { // If we also found an unwrap in the success block, we don't know what's // happening here. if (FoundInSuccessBlock) return None; // Otherwise we can determine this as a success condition. Note this is // flipped as if the error is present in the else block, this condition // is for success. return ClassifiedCondition(Cond, ConditionPath::SUCCESS, /*ObjCStyleFlagCheck*/ true); } } if (FoundInSuccessBlock) { // Note that the path is flipped as if the error is present in the success // block, this condition is for failure. return ClassifiedCondition(Cond, ConditionPath::FAILURE, /*ObjCStyleFlagCheck*/ true); } // Otherwise we can't classify this. return None; } /// Classifies all the conditions present in a given StmtCondition, taking /// into account its success body and failure body. Returns \c true if there /// were any conditions that couldn't be classified, \c false otherwise. bool classifyConditionsOf(StmtCondition Cond, const NodesToPrint &ThenNodesToPrint, Stmt *ElseStmt, ClassifiedCallbackConditions &Conditions) { bool UnhandledConditions = false; Optional ObjCFlagCheck; auto TryAddCond = [&](CallbackCondition CC) { auto Classified = classifyCallbackCondition(CC, ThenNodesToPrint, ElseStmt); // If we couldn't classify this, or if there are multiple Obj-C style flag // checks, this is unhandled. if (!Classified || (ObjCFlagCheck && Classified->IsObjCStyleFlagCheck)) { UnhandledConditions = true; return; } // If we've seen multiple conditions for the same subject, don't handle // this. if (!Conditions.insert({CC.Subject, *Classified}).second) { UnhandledConditions = true; return; } if (Classified->IsObjCStyleFlagCheck) ObjCFlagCheck = Classified; }; for (auto &CondElement : Cond) { if (auto *BoolExpr = CondElement.getBooleanOrNull()) { SmallVector Exprs; Exprs.push_back(BoolExpr); while (!Exprs.empty()) { auto *Next = Exprs.pop_back_val()->getSemanticsProvidingExpr(); if (auto *ACE = dyn_cast(Next)) Next = ACE->getSingleExpressionBody()->getSemanticsProvidingExpr(); if (auto *BE = dyn_cast_or_null(Next)) { auto *Operator = isOperator(BE); if (Operator) { // If we have an && operator, decompose its arguments. if (Operator->getBaseName() == "&&") { Exprs.push_back(BE->getLHS()); Exprs.push_back(BE->getRHS()); } else { // Otherwise check to see if we have an == nil or != nil // condition. TryAddCond(CallbackCondition(BE, Operator)); } continue; } } // Check to see if we have a lone bool condition. TryAddCond(CallbackCondition(Next)); } } else if (auto *P = CondElement.getPatternOrNull()) { TryAddCond(CallbackCondition(P, CondElement.getInitializer())); } } return UnhandledConditions || Conditions.empty(); } /// Classifies the conditions of a conditional statement, and adds the /// necessary nodes to either the success or failure block. void classifyConditional(Stmt *Statement, StmtCondition Condition, NodesToPrint ThenNodesToPrint, Stmt *ElseStmt) { ClassifiedCallbackConditions CallbackConditions; bool UnhandledConditions = classifyConditionsOf( Condition, ThenNodesToPrint, ElseStmt, CallbackConditions); auto ErrCondition = CallbackConditions.lookup(Params.getErrParam()); if (UnhandledConditions) { // Some unknown conditions. If there's an else, assume we can't handle // and use the fallback case. Otherwise add to either the success or // error block depending on some heuristics, known conditions will have // placeholders added (ideally we'd remove them) // TODO: Remove known conditions and split the `if` statement if (IsKnownConditionPath) { // If we're on a known condition path, we can be lenient as we already // know what block we're in and can therefore just add the conditional // straight to it. CurrentBlock->addNode(Statement); } else if (CallbackConditions.empty()) { // Technically this has a similar problem, ie. the else could have // conditions that should be in either success/error CurrentBlock->addNode(Statement); } else if (ElseStmt) { DiagEngine.diagnose(Statement->getStartLoc(), diag::unknown_callback_conditions); } else if (ErrCondition && ErrCondition->Path == ConditionPath::FAILURE) { Blocks.ErrorBlock.addNode(Statement); } else { for (auto &Entry : CallbackConditions) { if (Entry.second.Path == ConditionPath::FAILURE) { Blocks.ErrorBlock.addNode(Statement); return; } } Blocks.SuccessBlock.addNode(Statement); } return; } // If all the conditions were classified, make sure they're all consistently // on the success or failure path. Optional Path; for (auto &Entry : CallbackConditions) { auto &Cond = Entry.second; if (!Path) { Path = Cond.Path; } else if (*Path != Cond.Path) { // Similar to the unknown conditions case. Add the whole if unless // there's an else, in which case use the fallback instead. // TODO: Split the `if` statement if (ElseStmt) { DiagEngine.diagnose(Statement->getStartLoc(), diag::mixed_callback_conditions); } else { CurrentBlock->addNode(Statement); } return; } } assert(Path && "Didn't classify a path?"); auto *ThenBlock = &Blocks.SuccessBlock; auto *ElseBlock = &Blocks.ErrorBlock; // If the condition is for a failure path, the error block is ThenBlock, and // the success block is ElseBlock. if (*Path == ConditionPath::FAILURE) std::swap(ThenBlock, ElseBlock); // We'll be dropping the statement, but make sure to keep any attached // comments. CurrentBlock->addPossibleCommentLoc(Statement->getStartLoc()); ThenBlock->addAllBindings(CallbackConditions); // TODO: Handle nested ifs setNodes(ThenBlock, ElseBlock, std::move(ThenNodesToPrint)); if (ElseStmt) { if (auto *BS = dyn_cast(ElseStmt)) { // If this is a guard statement, we know that we'll always exit, // allowing us to classify any additional nodes into the opposite block. auto AlwaysExits = isa(Statement); setNodes(ElseBlock, ThenBlock, NodesToPrint::inBraceStmt(BS), AlwaysExits); } else { // If we reached here, we should have an else if statement. Given we // know we're in the else of a known condition, temporarily flip the // current block, and set that we know what path we're on. llvm::SaveAndRestore CondScope(IsKnownConditionPath, true); llvm::SaveAndRestore BlockScope(CurrentBlock, ElseBlock); classifyNodes(ArrayRef(ElseStmt), /*endCommentLoc*/ SourceLoc()); } } } /// Adds \p Nodes to \p Block, potentially flipping the current block if we /// can determine that the nodes being added will cause control flow to leave /// the scope. /// /// \param Block The block to add the nodes to. /// \param OtherBlock The block for the opposing condition path. /// \param Nodes The nodes to add. /// \param AlwaysExitsScope Whether the nodes being added always exit the /// scope, and therefore whether the current block should be flipped. void setNodes(ClassifiedBlock *Block, ClassifiedBlock *OtherBlock, NodesToPrint Nodes, bool AlwaysExitsScope = false) { // Drop an explicit trailing 'return' or 'break' if we can. bool HasTrailingReturnOrBreak = Nodes.hasTrailingReturnOrBreak(); if (HasTrailingReturnOrBreak) Nodes.dropTrailingReturnOrBreakIfPossible(); // If we know we're exiting the scope, we can set IsKnownConditionPath, as // we know any future nodes should be classified into the other block. if (HasTrailingReturnOrBreak || AlwaysExitsScope) { CurrentBlock = OtherBlock; IsKnownConditionPath = true; Block->addAllNodes(std::move(Nodes)); } else { Block->addAllNodes(std::move(Nodes)); } } void classifySwitch(SwitchStmt *SS) { auto *ResultParam = Params.getResultParam(); if (singleSwitchSubject(SS) != ResultParam) { CurrentBlock->addNode(SS); return; } // We'll be dropping the switch, but make sure to keep any attached // comments. CurrentBlock->addPossibleCommentLoc(SS->getStartLoc()); // Push the cases into a vector. This is only done to eagerly evaluate the // AsCaseStmtRange sequence so we can know what the last case is. SmallVector Cases; Cases.append(SS->getCases().begin(), SS->getCases().end()); for (auto *CS : Cases) { if (CS->hasFallthroughDest()) { DiagEngine.diagnose(CS->getLoc(), diag::callback_with_fallthrough); return; } if (CS->isDefault()) { DiagEngine.diagnose(CS->getLoc(), diag::callback_with_default); return; } auto Items = CS->getCaseLabelItems(); if (Items.size() > 1) { DiagEngine.diagnose(CS->getLoc(), diag::callback_multiple_case_items); return; } if (Items[0].getWhereLoc().isValid()) { DiagEngine.diagnose(CS->getLoc(), diag::callback_where_case_item); return; } auto *Block = &Blocks.SuccessBlock; auto *OtherBlock = &Blocks.ErrorBlock; auto SuccessNodes = NodesToPrint::inBraceStmt(CS->getBody()); // Classify the case pattern. auto CC = classifyCallbackCondition( CallbackCondition(ResultParam, &Items[0]), SuccessNodes, /*elseStmt*/ nullptr); if (!CC) { DiagEngine.diagnose(CS->getLoc(), diag::unknown_callback_case_item); return; } if (CC->Path == ConditionPath::FAILURE) std::swap(Block, OtherBlock); // We'll be dropping the case, but make sure to keep any attached // comments. Because these comments will effectively be part of the // previous case, add them to CurrentBlock. CurrentBlock->addPossibleCommentLoc(CS->getStartLoc()); // Make sure to grab trailing comments in the last case stmt. if (CS == Cases.back()) Block->addPossibleCommentLoc(SS->getRBraceLoc()); setNodes(Block, OtherBlock, std::move(SuccessNodes)); Block->addBinding(*CC); } // Mark this switch statement as having been transformed. HandledSwitches.insert(SS); } }; /// Base name of a decl if it has one, an empty \c DeclBaseName otherwise. static DeclBaseName getDeclName(const Decl *D) { if (auto *VD = dyn_cast(D)) { if (VD->hasName()) return VD->getBaseName(); } return DeclBaseName(); } class DeclCollector : private SourceEntityWalker { llvm::DenseSet &Decls; public: /// Collect all explicit declarations declared in \p Scope (or \p SF if /// \p Scope is a nullptr) that are not within their own scope. static void collect(BraceStmt *Scope, SourceFile &SF, llvm::DenseSet &Decls) { DeclCollector Collector(Decls); if (Scope) { for (auto Node : Scope->getElements()) { Collector.walk(Node); } } else { Collector.walk(SF); } } private: DeclCollector(llvm::DenseSet &Decls) : Decls(Decls) {} bool walkToDeclPre(Decl *D, CharSourceRange Range) override { // Want to walk through top level code decls (which are implicitly added // for top level non-decl code) and pattern binding decls (which contain // the var decls that we care about). if (isa(D) || isa(D)) return true; if (!D->isImplicit()) Decls.insert(D); return false; } bool walkToExprPre(Expr *E) override { return !isa(E); } bool walkToStmtPre(Stmt *S) override { return S->isImplicit() || !startsNewScope(S); } }; class ReferenceCollector : private SourceEntityWalker { SourceManager *SM; llvm::DenseSet DeclaredDecls; llvm::DenseSet &ReferencedDecls; ASTNode Target; bool AfterTarget; public: /// Collect all explicit references in \p Scope (or \p SF if \p Scope is /// a nullptr) that are after \p Target and not first declared. That is, /// references that we don't want to shadow with hoisted declarations. /// /// Also collect all declarations that are \c DeclContexts, which is an /// over-appoximation but let's us ignore them elsewhere. static void collect(ASTNode Target, BraceStmt *Scope, SourceFile &SF, llvm::DenseSet &Decls) { ReferenceCollector Collector(Target, &SF.getASTContext().SourceMgr, Decls); if (Scope) Collector.walk(Scope); else Collector.walk(SF); } private: ReferenceCollector(ASTNode Target, SourceManager *SM, llvm::DenseSet &Decls) : SM(SM), DeclaredDecls(), ReferencedDecls(Decls), Target(Target), AfterTarget(false) {} bool walkToDeclPre(Decl *D, CharSourceRange Range) override { // Bit of a hack, include all contexts so they're never renamed (seems worse // to rename a class/function than it does a variable). Again, an // over-approximation, but hopefully doesn't come up too often. if (isa(D) && !D->isImplicit()) { ReferencedDecls.insert(D); } if (AfterTarget && !D->isImplicit()) { DeclaredDecls.insert(D); } else if (D == Target.dyn_cast()) { AfterTarget = true; } return shouldWalkInto(D->getSourceRange()); } bool walkToExprPre(Expr *E) override { if (AfterTarget && !E->isImplicit()) { if (auto *DRE = dyn_cast(E)) { if (auto *D = DRE->getDecl()) { // Only care about references that aren't declared, as seen decls will // be renamed (if necessary) during the refactoring. if (!D->isImplicit() && !DeclaredDecls.count(D)) { ReferencedDecls.insert(D); // Also add the async alternative of a function to prevent // collisions if a call is replaced with the alternative. if (auto *AFD = dyn_cast(D)) { if (auto *Alternative = AFD->getAsyncAlternative()) ReferencedDecls.insert(Alternative); } } } } } else if (E == Target.dyn_cast()) { AfterTarget = true; } return shouldWalkInto(E->getSourceRange()); } bool walkToStmtPre(Stmt *S) override { if (S == Target.dyn_cast()) AfterTarget = true; return shouldWalkInto(S->getSourceRange()); } bool walkToPatternPre(Pattern *P) override { if (P == Target.dyn_cast()) AfterTarget = true; return shouldWalkInto(P->getSourceRange()); } bool shouldWalkInto(SourceRange Range) { return AfterTarget || (SM && SM->rangeContainsTokenLoc(Range, Target.getStartLoc())); } }; /// Similar to the \c ReferenceCollector but collects references in all scopes /// without any starting point in each scope. In addition, it tracks the number /// of references to a decl in a given scope. class ScopedDeclCollector : private SourceEntityWalker { public: using DeclsTy = llvm::DenseSet; using RefDeclsTy = llvm::DenseMap; private: using ScopedDeclsTy = llvm::DenseMap; struct Scope { DeclsTy DeclaredDecls; RefDeclsTy *ReferencedDecls; Scope(RefDeclsTy *ReferencedDecls) : DeclaredDecls(), ReferencedDecls(ReferencedDecls) {} }; ScopedDeclsTy ReferencedDecls; llvm::SmallVector ScopeStack; public: /// Starting at \c Scope, collect all explicit references in every scope /// within (including the initial) that are not first declared, ie. those that /// could end up shadowed. Also include all \c DeclContext declarations as /// we'd like to avoid renaming functions and types completely. void collect(ASTNode Node) { walk(Node); } const RefDeclsTy *getReferencedDecls(Stmt *Scope) const { auto Res = ReferencedDecls.find(Scope); if (Res == ReferencedDecls.end()) return nullptr; return &Res->second; } private: bool walkToDeclPre(Decl *D, CharSourceRange Range) override { if (ScopeStack.empty() || D->isImplicit()) return true; ScopeStack.back().DeclaredDecls.insert(D); if (isa(D)) (*ScopeStack.back().ReferencedDecls)[D] += 1; return true; } bool walkToExprPre(Expr *E) override { if (ScopeStack.empty()) return true; if (!E->isImplicit()) { if (auto *DRE = dyn_cast(E)) { if (auto *D = DRE->getDecl()) { // If we have a reference that isn't declared in the same scope, // increment the number of references to that decl. if (!D->isImplicit() && !ScopeStack.back().DeclaredDecls.count(D)) { (*ScopeStack.back().ReferencedDecls)[D] += 1; // Also add the async alternative of a function to prevent // collisions if a call is replaced with the alternative. if (auto *AFD = dyn_cast(D)) { if (auto *Alternative = AFD->getAsyncAlternative()) (*ScopeStack.back().ReferencedDecls)[Alternative] += 1; } } } } } return true; } bool walkToStmtPre(Stmt *S) override { // Purposely check \c BraceStmt here rather than \c startsNewScope. // References in the condition should be applied to the previous scope, not // the scope of that statement. if (isa(S)) ScopeStack.emplace_back(&ReferencedDecls[S]); return true; } bool walkToStmtPost(Stmt *S) override { if (isa(S)) { size_t NumScopes = ScopeStack.size(); if (NumScopes >= 2) { // Add any referenced decls to the parent scope that weren't declared // there. auto &ParentStack = ScopeStack[NumScopes - 2]; for (auto DeclAndNumRefs : *ScopeStack.back().ReferencedDecls) { auto *D = DeclAndNumRefs.first; if (!ParentStack.DeclaredDecls.count(D)) (*ParentStack.ReferencedDecls)[D] += DeclAndNumRefs.second; } } ScopeStack.pop_back(); } return true; } }; /// Checks whether an ASTNode contains a reference to a given declaration. class DeclReferenceFinder : private SourceEntityWalker { bool HasFoundReference = false; const Decl *Search; bool walkToExprPre(Expr *E) override { if (auto DRE = dyn_cast(E)) { if (DRE->getDecl() == Search) { HasFoundReference = true; return false; } } return true; } DeclReferenceFinder(const Decl *Search) : Search(Search) {} public: /// Returns \c true if \p node contains a reference to \p Search, \c false /// otherwise. static bool containsReference(ASTNode Node, const ValueDecl *Search) { DeclReferenceFinder Checker(Search); Checker.walk(Node); return Checker.HasFoundReference; } }; /// Builds up async-converted code for an AST node. /// /// If it is a function, its declaration will have `async` added. If a /// completion handler is present, it will be removed and the return type of /// the function will reflect the parameters of the handler, including an /// added `throws` if necessary. /// /// Calls to the completion handler are replaced with either a `return` or /// `throws` depending on the arguments. /// /// Calls to functions with an async alternative will be replaced with a call /// to the alternative, possibly wrapped in a do/catch. The do/catch is skipped /// if the the closure either: /// 1. Has no error /// 2. Has an error but no error handling (eg. just ignores) /// 3. Has error handling that only calls the containing function's handler /// with an error matching the error argument /// /// (2) is technically not the correct translation, but in practice it's likely /// the code a user would actually want. /// /// If the success vs error handling split inside the closure cannot be /// determined and the closure takes regular parameters (ie. not a Result), a /// fallback translation is used that keeps all the same variable names and /// simply moves the code within the closure out. /// /// The fallback is generally avoided, however, since it's quite unlikely to be /// the code the user intended. In most cases the refactoring will continue, /// with any unhandled decls wrapped in placeholders instead. class AsyncConverter : private SourceEntityWalker { struct Scope { llvm::DenseSet Names; /// If this scope is wrapped in a \c withChecked(Throwing)Continuation, the /// name of the continuation that must be resumed where there previously was /// a call to the function's completion handler. /// Otherwise an empty identifier. Identifier ContinuationName; Scope(Identifier ContinuationName) : Names(), ContinuationName(ContinuationName) {} /// Whether this scope is wrapped in a \c withChecked(Throwing)Continuation. bool isWrappedInContination() const { return !ContinuationName.empty(); } }; SourceFile *SF; SourceManager &SM; DiagnosticEngine &DiagEngine; // Node to convert ASTNode StartNode; // Completion handler of `StartNode` (if it's a function with an async // alternative) AsyncHandlerParamDesc TopHandler; SmallString<0> Buffer; llvm::raw_svector_ostream OS; // Decls where any force unwrap or optional chain of that decl should be // elided, e.g for a previously optional closure parameter that has become a // non-optional local. llvm::DenseSet Unwraps; // Decls whose references should be replaced with, either because they no // longer exist or are a different type. Any replaced code should ideally be // handled by the refactoring properly, but that's not possible in all cases llvm::DenseSet Placeholders; // Mapping from decl -> name, used as the name of possible new local // declarations of old completion handler parametes, as well as the // replacement for other hoisted declarations and their references llvm::DenseMap Names; /// The scopes (containing all name decls and whether the scope is wrapped in /// a continuation) as the AST is being walked. The first element is the /// initial scope and the last is the current scope. llvm::SmallVector Scopes; // Mapping of \c BraceStmt -> declarations referenced in that statement // without first being declared. These are used to fill the \c ScopeNames // map on entering that scope. ScopedDeclCollector ScopedDecls; /// The switch statements that have been re-written by this transform. llvm::DenseSet HandledSwitches; // The last source location that has been output. Used to output the source // between handled nodes SourceLoc LastAddedLoc; // Number of expressions (or pattern binding decl) currently nested in, taking // into account hoisting and the possible removal of ifs/switches int NestedExprCount = 0; // Whether a completion handler body is currently being hoisted out of its // call bool Hoisting = false; /// Whether a pattern is currently being converted. bool ConvertingPattern = false; /// A mapping of inline patterns to print for closure parameters. using InlinePatternsToPrint = llvm::DenseMap; public: /// Convert a function AsyncConverter(SourceFile *SF, SourceManager &SM, DiagnosticEngine &DiagEngine, AbstractFunctionDecl *FD, const AsyncHandlerParamDesc &TopHandler) : SF(SF), SM(SM), DiagEngine(DiagEngine), StartNode(FD), TopHandler(TopHandler), OS(Buffer) { Placeholders.insert(TopHandler.getHandler()); ScopedDecls.collect(FD); // Shouldn't strictly be necessary, but prefer possible shadowing over // crashes caused by a missing scope addNewScope({}); } /// Convert a call AsyncConverter(SourceFile *SF, SourceManager &SM, DiagnosticEngine &DiagEngine, CallExpr *CE, BraceStmt *Scope) : SF(SF), SM(SM), DiagEngine(DiagEngine), StartNode(CE), OS(Buffer) { ScopedDecls.collect(CE); // Create the initial scope, can be more accurate than the general // \c ScopedDeclCollector as there is a starting point. llvm::DenseSet UsedDecls; DeclCollector::collect(Scope, *SF, UsedDecls); ReferenceCollector::collect(StartNode, Scope, *SF, UsedDecls); addNewScope(UsedDecls); } ASTContext &getASTContext() const { return SF->getASTContext(); } bool convert() { assert(Buffer.empty() && "AsyncConverter can only be used once"); if (auto *FD = dyn_cast_or_null(StartNode.dyn_cast())) { addFuncDecl(FD); if (FD->getBody()) { convertNode(FD->getBody()); } } else { convertNode(StartNode, /*StartOverride=*/{}, /*ConvertCalls=*/true, /*IncludeComments=*/false); } return !DiagEngine.hadAnyError(); } /// When adding an async alternative method for the function declaration \c /// FD, this function tries to create a function body for the legacy function /// (the one with a completion handler), which calls the newly converted async /// function. There are certain situations in which we fail to create such a /// body, e.g. if the completion handler has the signature `(String, Error?) /// -> Void` in which case we can't synthesize the result of type \c String in /// the error case. bool createLegacyBody() { assert(Buffer.empty() && "AsyncConverter can only be used once"); if (!canCreateLegacyBody()) return false; FuncDecl *FD = cast(StartNode.get()); OS << tok::l_brace << "\n"; // start function body OS << "Task " << tok::l_brace << "\n"; addHoistedNamedCallback(FD, TopHandler, TopHandler.getNameStr(), [&]() { if (TopHandler.HasError) { OS << tok::kw_try << " "; } OS << "await "; // Since we're *creating* the async alternative here, there shouldn't // already be one. Thus, just assume that the call to the alternative is // the same as the call to the old completion handler function, minus the // completion handler arg. addForwardingCallTo(FD, /*HandlerReplacement=*/""); }); OS << "\n"; OS << tok::r_brace << "\n"; // end 'Task' OS << tok::r_brace << "\n"; // end function body return true; } /// Creates an async alternative function that forwards onto the completion /// handler function through /// withCheckedContinuation/withCheckedThrowingContinuation. bool createAsyncWrapper() { assert(Buffer.empty() && "AsyncConverter can only be used once"); auto *FD = cast(StartNode.get()); // First add the new async function declaration. addFuncDecl(FD); OS << tok::l_brace << "\n"; // Then add the body. OS << tok::kw_return << " "; if (TopHandler.HasError) OS << tok::kw_try << " "; OS << "await "; // withChecked[Throwing]Continuation { continuation in if (TopHandler.HasError) { OS << "withCheckedThrowingContinuation"; } else { OS << "withCheckedContinuation"; } OS << " " << tok::l_brace << " continuation " << tok::kw_in << "\n"; // fnWithHandler(args...) { ... } auto ClosureStr = getAsyncWrapperCompletionClosure("continuation", TopHandler); addForwardingCallTo(FD, /*HandlerReplacement=*/ClosureStr); OS << "\n"; OS << tok::r_brace << "\n"; // end continuation closure OS << tok::r_brace << "\n"; // end function body return true; } void replace(ASTNode Node, SourceEditConsumer &EditConsumer, SourceLoc StartOverride = SourceLoc()) { SourceRange Range = Node.getSourceRange(); if (StartOverride.isValid()) { Range = SourceRange(StartOverride, Range.End); } CharSourceRange CharRange = Lexer::getCharSourceRangeFromSourceRange(SM, Range); EditConsumer.accept(SM, CharRange, Buffer.str()); Buffer.clear(); } void insertAfter(ASTNode Node, SourceEditConsumer &EditConsumer) { EditConsumer.insertAfter(SM, Node.getEndLoc(), "\n\n"); EditConsumer.insertAfter(SM, Node.getEndLoc(), Buffer.str()); Buffer.clear(); } private: bool canCreateLegacyBody() { FuncDecl *FD = dyn_cast(StartNode.dyn_cast()); if (!FD) { return false; } if (FD == nullptr || FD->getBody() == nullptr) { return false; } if (FD->hasThrows()) { assert(!TopHandler.isValid() && "We shouldn't have found a handler desc " "if the original function throws"); return false; } return TopHandler.isValid(); } /// Prints a tuple of elements, or a lone single element if only one is /// present, using the provided printing function. template void addTupleOf(const Container &Elements, llvm::raw_ostream &OS, PrintFn PrintElt) { if (Elements.size() == 1) { PrintElt(Elements[0]); return; } OS << tok::l_paren; llvm::interleave(Elements, PrintElt, [&]() { OS << tok::comma << " "; }); OS << tok::r_paren; } /// Retrieve the completion handler closure argument for an async wrapper /// function. std::string getAsyncWrapperCompletionClosure(StringRef ContName, const AsyncHandlerParamDesc &HandlerDesc) { std::string OutputStr; llvm::raw_string_ostream OS(OutputStr); OS << tok::l_brace; // start closure // Prepare parameter names for the closure. auto SuccessParams = HandlerDesc.getSuccessParams(); SmallVector, 2> SuccessParamNames; for (auto idx : indices(SuccessParams)) { SuccessParamNames.emplace_back("result"); // If we have multiple success params, number them e.g res1, res2... if (SuccessParams.size() > 1) SuccessParamNames.back().append(std::to_string(idx + 1)); } Optional> ErrName; if (HandlerDesc.getErrorParam()) ErrName.emplace("error"); auto HasAnyParams = !SuccessParamNames.empty() || ErrName; if (HasAnyParams) OS << " "; // res1, res2 llvm::interleave( SuccessParamNames, [&](auto Name) { OS << Name; }, [&]() { OS << tok::comma << " "; }); // , err if (ErrName) { if (!SuccessParamNames.empty()) OS << tok::comma << " "; OS << *ErrName; } if (HasAnyParams) OS << " " << tok::kw_in; OS << "\n"; // The closure body. switch (HandlerDesc.Type) { case HandlerType::PARAMS: { // For a (Success?, Error?) -> Void handler, we do an if let on the error. if (ErrName) { // if let err = err { OS << tok::kw_if << " " << tok::kw_let << " "; OS << *ErrName << " " << tok::equal << " " << *ErrName << " "; OS << tok::l_brace << "\n"; for (auto Idx : indices(SuccessParamNames)) { auto ParamTy = SuccessParams[Idx].getParameterType(); if (!HandlerDesc.shouldUnwrap(ParamTy)) continue; } // continuation.resume(throwing: err) OS << ContName << tok::period << "resume" << tok::l_paren; OS << "throwing" << tok::colon << " " << *ErrName; OS << tok::r_paren << "\n"; // return } OS << tok::kw_return << "\n"; OS << tok::r_brace << "\n"; } // If we have any success params that we need to unwrap, insert a guard. for (auto Idx : indices(SuccessParamNames)) { auto &Name = SuccessParamNames[Idx]; auto ParamTy = SuccessParams[Idx].getParameterType(); if (!HandlerDesc.shouldUnwrap(ParamTy)) continue; // guard let res = res else { OS << tok::kw_guard << " " << tok::kw_let << " "; OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else; OS << " " << tok::l_brace << "\n"; // fatalError(...) OS << "fatalError" << tok::l_paren; OS << "\"Expected non-nil result '" << Name << "' for nil error\""; OS << tok::r_paren << "\n"; // End guard. OS << tok::r_brace << "\n"; } // continuation.resume(returning: (res1, res2, ...)) OS << ContName << tok::period << "resume" << tok::l_paren; OS << "returning" << tok::colon << " "; addTupleOf(SuccessParamNames, OS, [&](auto Ref) { OS << Ref; }); OS << tok::r_paren << "\n"; break; } case HandlerType::RESULT: { // continuation.resume(with: res) assert(SuccessParamNames.size() == 1); OS << ContName << tok::period << "resume" << tok::l_paren; OS << "with" << tok::colon << " " << SuccessParamNames[0]; OS << tok::r_paren << "\n"; break; } case HandlerType::INVALID: llvm_unreachable("Should not have an invalid handler here"); } OS << tok::r_brace; // end closure return OutputStr; } /// Retrieves the SourceRange of the preceding comment, or an invalid range if /// there is no preceding comment. CharSourceRange getPrecedingCommentRange(SourceLoc Loc) { auto Tokens = SF->getAllTokens(); auto TokenIter = token_lower_bound(Tokens, Loc); if (TokenIter == Tokens.end() || !TokenIter->hasComment()) return CharSourceRange(); return TokenIter->getCommentRange(); } /// Retrieves the location for the start of a comment attached to the token /// at the provided location, or the location itself if there is no comment. SourceLoc getLocIncludingPrecedingComment(SourceLoc Loc) { auto CommentRange = getPrecedingCommentRange(Loc); if (CommentRange.isInvalid()) return Loc; return CommentRange.getStart(); } /// If the provided SourceLoc has a preceding comment, print it out. void printCommentIfNeeded(SourceLoc Loc) { auto CommentRange = getPrecedingCommentRange(Loc); if (CommentRange.isValid()) OS << "\n" << CommentRange.str(); } void convertNodes(const NodesToPrint &ToPrint) { // Sort the possible comment locs in reverse order so we can pop them as we // go. SmallVector CommentLocs; CommentLocs.append(ToPrint.getPossibleCommentLocs().begin(), ToPrint.getPossibleCommentLocs().end()); llvm::sort(CommentLocs.begin(), CommentLocs.end(), [](auto lhs, auto rhs) { return lhs.getOpaquePointerValue() > rhs.getOpaquePointerValue(); }); // First print the nodes we've been asked to print. for (auto Node : ToPrint.getNodes()) { // If we need to print comments, do so now. while (!CommentLocs.empty()) { auto CommentLoc = CommentLocs.back().getOpaquePointerValue(); auto NodeLoc = Node.getStartLoc().getOpaquePointerValue(); assert(CommentLoc != NodeLoc && "Added node to both comment locs and nodes to print?"); // If the comment occurs after the node, don't print now. Wait until // the right node comes along. if (CommentLoc > NodeLoc) break; printCommentIfNeeded(CommentLocs.pop_back_val()); } OS << "\n"; convertNode(Node); } // We're done printing nodes. Make sure to output the remaining comments. while (!CommentLocs.empty()) printCommentIfNeeded(CommentLocs.pop_back_val()); } void convertNode(ASTNode Node, SourceLoc StartOverride = {}, bool ConvertCalls = true, bool IncludePrecedingComment = true) { if (!StartOverride.isValid()) StartOverride = Node.getStartLoc(); // Make sure to include any preceding comments attached to the loc if (IncludePrecedingComment) StartOverride = getLocIncludingPrecedingComment(StartOverride); llvm::SaveAndRestore RestoreLoc(LastAddedLoc, StartOverride); llvm::SaveAndRestore RestoreCount(NestedExprCount, ConvertCalls ? 0 : 1); walk(Node); addRange(LastAddedLoc, Node.getEndLoc(), /*ToEndOfToken=*/true); } void convertPattern(const Pattern *P) { // Only print semantic patterns. This cleans up the output of the transform // and works around some bogus source locs that can appear with typed // patterns in if let statements. P = P->getSemanticsProvidingPattern(); // Set up the start of the pattern as the last loc printed to make sure we // accurately fill in the gaps as we customize the printing of sub-patterns. llvm::SaveAndRestore RestoreLoc(LastAddedLoc, P->getStartLoc()); llvm::SaveAndRestore RestoreFlag(ConvertingPattern, true); walk(const_cast(P)); addRange(LastAddedLoc, P->getEndLoc(), /*ToEndOfToken*/ true); } /// Check whether \p Node requires the remainder of this scope to be wrapped /// in a \c withChecked(Throwing)Continuation. If it is necessary, add /// a call to \c withChecked(Throwing)Continuation and modify the current /// scope (\c Scopes.back() ) so that it knows it's wrapped in a continuation. /// /// Wrapping a node in a continuation is necessary if the following conditions /// are satisfied: /// - It contains a reference to the \c TopHandler's completion hander, /// because these completion handler calls need to be promoted to \c return /// statements in the refactored method, but /// - We cannot hoist the completion handler of \p Node, because it doesn't /// have an async alternative by our heuristics (e.g. because of a /// completion handler name mismatch or because it also returns a value /// synchronously). void wrapScopeInContinationIfNecessary(ASTNode Node) { if (NestedExprCount != 0) { // We can't start a continuation in the middle of an expression return; } if (Scopes.back().isWrappedInContination()) { // We are already in a continuation. No need to add another one. return; } if (!DeclReferenceFinder::containsReference(Node, TopHandler.getHandler())) { // The node doesn't have a reference to the function's completion handler. // It can stay a call with a completion handler, because we don't need to // promote a completion handler call to a 'return'. return; } // Wrap the current call in a continuation Identifier contName = createUniqueName("continuation"); Scopes.back().Names.insert(contName); Scopes.back().ContinuationName = contName; insertCustom(Node.getStartLoc(), [&]() { OS << tok::kw_return << ' '; if (TopHandler.HasError) { OS << tok::kw_try << ' '; } OS << "await "; if (TopHandler.HasError) { OS << "withCheckedThrowingContinuation "; } else { OS << "withCheckedContinuation "; } OS << tok::l_brace << ' ' << contName << ' ' << tok::kw_in << '\n'; }); } bool walkToPatternPre(Pattern *P) override { // If we're not converting a pattern, there's nothing extra to do. if (!ConvertingPattern) return true; // When converting a pattern, don't print the 'let' or 'var' of binding // subpatterns, as they're illegal when nested in PBDs, and we print a // top-level one. if (auto *BP = dyn_cast(P)) { return addCustom(BP->getSourceRange(), [&]() { convertPattern(BP->getSubPattern()); }); } return true; } bool walkToDeclPre(Decl *D, CharSourceRange Range) override { if (isa(D)) { // We can't hoist a closure inside a PatternBindingDecl. If it contains // a call to the completion handler, wrap it in a continuation. wrapScopeInContinationIfNecessary(D); NestedExprCount++; return true; } // Functions and types already have their names in \c Scopes.Names, only // variables should need to be renamed. if (isa(D)) { // If we don't already have a name for the var, assign it one. Note that // vars in binding patterns may already have assigned names here. if (Names.find(D) == Names.end()) { auto Ident = assignUniqueName(D, StringRef()); Scopes.back().Names.insert(Ident); } addCustom(D->getSourceRange(), [&]() { OS << newNameFor(D); }); } // Note we don't walk into any nested local function decls. If we start // doing so in the future, be sure to update the logic that deals with // converting unhandled returns into placeholders in walkToStmtPre. return false; } bool walkToDeclPost(Decl *D) override { NestedExprCount--; return true; } #define PLACEHOLDER_START "<#" #define PLACEHOLDER_END "#>" bool walkToExprPre(Expr *E) override { // TODO: Handle Result.get as well if (auto *DRE = dyn_cast(E)) { if (auto *D = DRE->getDecl()) { // Look through to the parent var decl if we have one. This ensures we // look at the var in a case stmt's pattern rather than the var that's // implicitly declared in the body. if (auto *VD = dyn_cast(D)) { if (auto *Parent = VD->getParentVarDecl()) D = Parent; } bool AddPlaceholder = Placeholders.count(D); StringRef Name = newNameFor(D, false); if (AddPlaceholder || !Name.empty()) return addCustom(DRE->getSourceRange(), [&]() { if (AddPlaceholder) OS << PLACEHOLDER_START; if (!Name.empty()) OS << Name; else D->getName().print(OS); if (AddPlaceholder) OS << PLACEHOLDER_END; }); } } else if (isa(E) || isa(E)) { // Remove a force unwrap or optional chain of a returned success value, // as it will no longer be optional. For force unwraps, this is always a // valid transform. For optional chains, it is a locally valid transform // within the optional chain e.g foo?.x -> foo.x, but may change the type // of the overall chain, which could cause errors elsewhere in the code. // However this is generally more useful to the user than just leaving // 'foo' as a placeholder. Note this is only the case when no other // optionals are involved in the chain, e.g foo?.x?.y -> foo.x?.y is // completely valid. if (auto *D = E->getReferencedDecl().getDecl()) { if (Unwraps.count(D)) return addCustom(E->getSourceRange(), [&]() { OS << newNameFor(D, true); }); } } else if (CallExpr *CE = TopHandler.getAsHandlerCall(E)) { if (Scopes.back().isWrappedInContination()) { return addCustom(E->getSourceRange(), [&]() { convertHandlerToContinuationResume(CE); }); } else if (NestedExprCount == 0) { return addCustom(E->getSourceRange(), [&]() { convertHandlerToReturnOrThrows(CE); }); } } else if (auto *CE = dyn_cast(E)) { // Try and hoist a call's completion handler. Don't do so if // - the current expression is nested (we can't start hoisting in the // middle of an expression) // - the current scope is wrapped in a continuation (we can't have await // calls in the continuation block) if (NestedExprCount == 0 && !Scopes.back().isWrappedInContination()) { // If the refactoring is on the call itself, do not require the callee // to have the @available attribute or a completion-like name. auto HandlerDesc = AsyncHandlerParamDesc::find( getUnderlyingFunc(CE->getFn()), /*RequireAttributeOrName=*/StartNode.dyn_cast() != CE); if (HandlerDesc.isValid()) { return addCustom(CE->getSourceRange(), [&]() { addHoistedCallback(CE, HandlerDesc); }); } } } // A void SingleValueStmtExpr is semantically more like a statement than // an expression, so recurse without bumping the expr depth or wrapping in // continuation. if (auto *SVE = dyn_cast(E)) { auto ty = SVE->getType(); if (!ty || ty->isVoid()) return true; } // We didn't do any special conversion for this expression. If needed, wrap // it in a continuation. wrapScopeInContinationIfNecessary(E); NestedExprCount++; return true; } bool replaceRangeWithPlaceholder(SourceRange range) { return addCustom(range, [&]() { OS << PLACEHOLDER_START; addRange(range, /*toEndOfToken*/ true); OS << PLACEHOLDER_END; }); } bool walkToExprPost(Expr *E) override { if (auto *SVE = dyn_cast(E)) { auto ty = SVE->getType(); if (!ty || ty->isVoid()) return true; } NestedExprCount--; return true; } #undef PLACEHOLDER_START #undef PLACEHOLDER_END bool walkToStmtPre(Stmt *S) override { // CaseStmt has an implicit BraceStmt inside it, which *should* start a new // scope, so don't check isImplicit here. if (startsNewScope(S)) { // Add all names of decls referenced within this statement that aren't // also declared first, plus any contexts. Note that \c getReferencedDecl // will only return a value for a \c BraceStmt. This means that \c IfStmt // (and other statements with conditions) will have their own empty scope, // which is fine for our purposes - their existing names are always valid. // The body of those statements will include the decls if they've been // referenced, so shadowing is still avoided there. if (auto *ReferencedDecls = ScopedDecls.getReferencedDecls(S)) { llvm::DenseSet Decls; for (auto DeclAndNumRefs : *ReferencedDecls) Decls.insert(DeclAndNumRefs.first); addNewScope(Decls); } else { addNewScope({}); } } else if (Hoisting && !S->isImplicit()) { // Some break and return statements need to be turned into placeholders, // as they may no longer perform the control flow that the user is // expecting. if (auto *BS = dyn_cast(S)) { // For a break, if it's jumping out of a switch statement that we've // re-written as a part of the transform, turn it into a placeholder, as // it would have been lifted out of the switch statement. if (auto *SS = dyn_cast(BS->getTarget())) { if (HandledSwitches.contains(SS)) return replaceRangeWithPlaceholder(S->getSourceRange()); } } else if (isa(S) && NestedExprCount == 0) { // For a return, if it's not nested inside another closure or function, // turn it into a placeholder, as it will be lifted out of the callback. // Note that we only turn the 'return' token into a placeholder as we // still want to be able to apply transforms to the argument. replaceRangeWithPlaceholder(S->getStartLoc()); } } return true; } bool walkToStmtPost(Stmt *S) override { if (startsNewScope(S)) { bool ClosedScopeWasWrappedInContinuation = Scopes.back().isWrappedInContination(); Scopes.pop_back(); if (ClosedScopeWasWrappedInContinuation && !Scopes.back().isWrappedInContination()) { // The nested scope was wrapped in a continuation but the current one // isn't anymore. Add the '}' that corresponds to the the call to // withChecked(Throwing)Continuation. insertCustom(S->getEndLoc(), [&]() { OS << tok::r_brace << '\n'; }); } } return true; } bool addCustom(SourceRange Range, llvm::function_ref Custom = {}) { addRange(LastAddedLoc, Range.Start); Custom(); LastAddedLoc = Lexer::getLocForEndOfToken(SM, Range.End); return false; } /// Insert custom text at the given \p Loc that shouldn't replace any existing /// source code. bool insertCustom(SourceLoc Loc, llvm::function_ref Custom = {}) { addRange(LastAddedLoc, Loc); Custom(); LastAddedLoc = Loc; return false; } void addRange(SourceLoc Start, SourceLoc End, bool ToEndOfToken = false) { if (ToEndOfToken) { OS << Lexer::getCharSourceRangeFromSourceRange(SM, SourceRange(Start, End)) .str(); } else { OS << CharSourceRange(SM, Start, End).str(); } } void addRange(SourceRange Range, bool ToEndOfToken = false) { addRange(Range.Start, Range.End, ToEndOfToken); } void addFuncDecl(const FuncDecl *FD) { auto *Params = FD->getParameters(); auto *HandlerParam = TopHandler.getHandlerParam(); auto ParamPos = TopHandler.handlerParamPosition(); // If the completion handler parameter has a default argument, the async // version is effectively @discardableResult, as not all the callers care // about receiving the completion call. if (HandlerParam && HandlerParam->isDefaultArgument()) OS << tok::at_sign << "discardableResult" << "\n"; // First chunk: start -> the parameter to remove (if any) SourceLoc LeftEndLoc; switch (ParamPos) { case AsyncHandlerParamDesc::Position::None: case AsyncHandlerParamDesc::Position::Only: case AsyncHandlerParamDesc::Position::First: // Handler is the first param (or there is none), so only include the ( LeftEndLoc = Params->getLParenLoc().getAdvancedLoc(1); break; case AsyncHandlerParamDesc::Position::Middle: // Handler is somewhere in the middle of the params, so we need to // include any comments and comma up until the handler LeftEndLoc = Params->get(TopHandler.Index)->getStartLoc(); LeftEndLoc = getLocIncludingPrecedingComment(LeftEndLoc); break; case AsyncHandlerParamDesc::Position::Last: // Handler is the last param, which means we don't want the comma. This // is a little annoying since we *do* want the comments past for the // last parameter LeftEndLoc = Lexer::getLocForEndOfToken( SM, Params->get(TopHandler.Index - 1)->getEndLoc()); // Skip to the end of any comments Token Next = Lexer::getTokenAtLocation(SM, LeftEndLoc, CommentRetentionMode::None); if (Next.getKind() != tok::NUM_TOKENS) LeftEndLoc = Next.getLoc(); break; } addRange(FD->getSourceRangeIncludingAttrs().Start, LeftEndLoc); // Second chunk: end of the parameter to remove -> right parenthesis SourceLoc MidStartLoc; SourceLoc MidEndLoc = Params->getRParenLoc().getAdvancedLoc(1); switch (ParamPos) { case AsyncHandlerParamDesc::Position::None: // No handler param, so make sure to include them all MidStartLoc = LeftEndLoc; break; case AsyncHandlerParamDesc::Position::First: case AsyncHandlerParamDesc::Position::Middle: // Handler param is either the first or one of the middle params. Skip // past it but make sure to include comments preceding the param after // the handler MidStartLoc = Params->get(TopHandler.Index + 1)->getStartLoc(); MidStartLoc = getLocIncludingPrecedingComment(MidStartLoc); break; case AsyncHandlerParamDesc::Position::Only: case AsyncHandlerParamDesc::Position::Last: // Handler param is last, this is easy since there's no other params // to copy over MidStartLoc = Params->getRParenLoc(); break; } addRange(MidStartLoc, MidEndLoc); // Third chunk: add in async and throws if necessary if (!FD->hasAsync()) OS << " async"; if (FD->hasThrows() || TopHandler.HasError) // TODO: Add throws if converting a function and it has a converted call // without a do/catch OS << " " << tok::kw_throws; // Fourth chunk: if no parent handler (ie. not adding an async // alternative), the rest of the decl. Otherwise, add in the new return // type if (!TopHandler.isValid()) { SourceLoc RightStartLoc = MidEndLoc; if (FD->hasThrows()) { RightStartLoc = Lexer::getLocForEndOfToken(SM, FD->getThrowsLoc()); } SourceLoc RightEndLoc = FD->getBody() ? FD->getBody()->getLBraceLoc() : RightStartLoc; addRange(RightStartLoc, RightEndLoc); return; } SmallVector Scratch; auto ReturnTypes = TopHandler.getAsyncReturnTypes(Scratch); if (ReturnTypes.empty()) { OS << " "; return; } // Print the function result type, making sure to omit a '-> Void' return. if (!TopHandler.willAsyncReturnVoid()) { OS << " -> "; addAsyncFuncReturnType(TopHandler); } if (FD->hasBody()) OS << " "; // TODO: Should remove the generic param and where clause for the error // param if it exists (and no other parameter uses that type) TrailingWhereClause *TWC = FD->getTrailingWhereClause(); if (TWC && TWC->getWhereLoc().isValid()) { auto Range = TWC->getSourceRange(); OS << Lexer::getCharSourceRangeFromSourceRange(SM, Range).str(); if (FD->hasBody()) OS << " "; } } void addFallbackVars(ArrayRef FallbackParams, const ClosureCallbackParams &AllParams) { for (auto *Param : FallbackParams) { auto Ty = Param->getType(); auto ParamName = newNameFor(Param); // If this is the known bool success param, we can use 'let' and type it // as non-optional, as it gets bound in both blocks. if (AllParams.isKnownBoolFlagParam(Param)) { OS << tok::kw_let << " " << ParamName << ": "; Ty->print(OS); OS << "\n"; continue; } OS << tok::kw_var << " " << ParamName << ": "; Ty->print(OS); if (!Ty->getOptionalObjectType()) OS << "?"; OS << " = " << tok::kw_nil << "\n"; } } void addDo() { OS << tok::kw_do << " " << tok::l_brace << "\n"; } /// Assuming that \p Result represents an error result to completion handler, /// returns \c true if the error has already been handled through a /// 'try await'. bool isErrorAlreadyHandled(HandlerResult Result) { assert(Result.isError()); assert(Result.args().size() == 1 && "There should only be one error parameter"); // We assume that the error has already been handled if its variable // declaration doesn't exist anymore, which is the case if it's in // Placeholders but not in Unwraps (if it's in Placeholders and Unwraps // an optional Error has simply been promoted to a non-optional Error). if (auto *DRE = dyn_cast(Result.args().back().getExpr())) { if (Placeholders.count(DRE->getDecl()) && !Unwraps.count(DRE->getDecl())) { return true; } } return false; } /// Returns \c true if the source representation of \p E can be interpreted /// as an expression returning an Optional value. bool isExpressionOptional(Expr *E) { if (isa(E)) { // E is downgrading a non-Optional result to an Optional. Its source // representation isn't Optional. return false; } if (auto DRE = dyn_cast(E)) { if (Unwraps.count(DRE->getDecl())) { // E has been promoted to a non-Optional value. It can't be used as an // Optional anymore. return false; } } if (!E->getType().isNull() && E->getType()->isOptional()) { return true; } // We couldn't determine the type. Assume non-Optional. return false; } /// Converts a call \p CE to a completion handler. Depending on the call it /// will be interpreted as a call that's returning a success result, an error /// or, if the call is completely ambiguous, adds an if-let that checks if the /// error is \c nil at runtime and dispatches to the success or error case /// depending on it. /// \p AddConvertedHandlerCall needs to add the converted version of the /// completion handler. Depending on the given \c HandlerResult, it must be /// intepreted as a success or error call. /// \p AddConvertedErrorCall must add the converted equivalent of returning an /// error. The passed \c StringRef contains the name of a variable that is of /// type 'Error'. void convertHandlerCall( const CallExpr *CE, llvm::function_ref AddConvertedHandlerCall, llvm::function_ref AddConvertedErrorCall) { auto Result = TopHandler.extractResultArgs(CE, /*ReturnErrorArgsIfAmbiguous=*/true); if (!TopHandler.isAmbiguousCallToParamHandler(CE)) { if (Result.isError()) { if (!isErrorAlreadyHandled(Result)) { // If the error has already been handled, we don't need to add another // throwing call. AddConvertedHandlerCall(Result); } } else { AddConvertedHandlerCall(Result); } } else { assert(Result.isError() && "If the call was ambiguous, we should have " "retrieved its error representation"); assert(Result.args().size() == 1 && "There should only be one error parameter"); Expr *ErrorExpr = Result.args().back().getExpr(); if (isErrorAlreadyHandled(Result)) { // The error has already been handled, interpret the call as a success // call. auto SuccessExprs = TopHandler.extractResultArgs( CE, /*ReturnErrorArgsIfAmbiguous=*/false); AddConvertedHandlerCall(SuccessExprs); } else if (!isExpressionOptional(ErrorExpr)) { // The error is never nil. No matter what the success param is, we // interpret it as an error call. AddConvertedHandlerCall(Result); } else { // The call was truly ambiguous. Add an // if let error = { // throw error // or equivalent // } else { // // } auto SuccessExprs = TopHandler.extractResultArgs( CE, /*ReturnErrorArgsIfAmbiguous=*/false); // The variable 'error' is only available in the 'if let' scope, so we // don't need to create a new unique one. StringRef ErrorName = "error"; OS << tok::kw_if << ' ' << tok::kw_let << ' ' << ErrorName << ' ' << tok::equal << ' '; convertNode(ErrorExpr, /*StartOverride=*/{}, /*ConvertCalls=*/false); OS << ' ' << tok::l_brace << '\n'; AddConvertedErrorCall(ErrorName); OS << tok::r_brace << ' ' << tok::kw_else << ' ' << tok::l_brace << '\n'; AddConvertedHandlerCall(SuccessExprs); OS << '\n' << tok::r_brace; } } } /// Convert a call \p CE to a completion handler to its 'return' or 'throws' /// equivalent. void convertHandlerToReturnOrThrows(const CallExpr *CE) { return convertHandlerCall( CE, [&](HandlerResult Exprs) { convertHandlerToReturnOrThrowsImpl(CE, Exprs); }, [&](StringRef ErrorName) { OS << tok::kw_throw << ' ' << ErrorName << '\n'; }); } /// Convert the call \p CE to a completion handler to its 'return' or 'throws' /// equivalent, where \p Result determines whether the call should be /// interpreted as an error or success call. void convertHandlerToReturnOrThrowsImpl(const CallExpr *CE, HandlerResult Result) { bool AddedReturnOrThrow = true; if (!Result.isError()) { // It's possible the user has already written an explicit return statement // for the completion handler call, e.g 'return completion(args...)'. In // that case, be sure not to add another return. auto *parent = getWalker().Parent.getAsStmt(); if (isa_and_nonnull(parent) && !cast(parent)->isImplicit()) { // The statement already has a return keyword. Don't add another one. AddedReturnOrThrow = false; } else { OS << tok::kw_return; } } else { OS << tok::kw_throw; } auto Args = Result.args(); if (!Args.empty()) { if (AddedReturnOrThrow) OS << ' '; addTupleOf(Args, OS, [&](Argument Arg) { // Special case: If the completion handler is a params handler that // takes an error, we could pass arguments to it without unwrapping // them. E.g. // simpleWithError { (res: String?, error: Error?) in // completion(res, nil) // } // But after refactoring `simpleWithError` to an async function we have // let res: String = await simple() // and `res` is no longer an `Optional`. Thus it's in `Placeholders` and // `Unwraps` and any reference to it will be replaced by a placeholder // unless it is wrapped in an unwrapping expression. This would cause us // to create `return <#res# >`. // Under our assumption that either the error or the result parameter // are non-nil, the above call to the completion handler is equivalent // to // completion(res!, nil) // which correctly yields // return res // Synthesize the force unwrap so that we get the expected results. auto *E = Arg.getExpr(); if (TopHandler.getHandlerType() == HandlerType::PARAMS && TopHandler.HasError) { if (auto DRE = dyn_cast(E->getSemanticsProvidingExpr())) { auto D = DRE->getDecl(); if (Unwraps.count(D)) { E = new (getASTContext()) ForceValueExpr(E, SourceLoc()); } } } // Can't just add the range as we need to perform replacements convertNode(E, /*StartOverride=*/Arg.getLabelLoc(), /*ConvertCalls=*/false); }); } } /// Convert a call \p CE to a completion handler to resumes of the /// continuation that's currently on top of the stack. void convertHandlerToContinuationResume(const CallExpr *CE) { return convertHandlerCall( CE, [&](HandlerResult Exprs) { convertHandlerToContinuationResumeImpl(CE, Exprs); }, [&](StringRef ErrorName) { Identifier ContinuationName = Scopes.back().ContinuationName; OS << ContinuationName << tok::period << "resume" << tok::l_paren << "throwing" << tok::colon << ' ' << ErrorName; OS << tok::r_paren << '\n'; }); } /// Convert a call \p CE to a completion handler to resumes of the /// continuation that's currently on top of the stack. /// \p Result determines whether the call should be interpreted as a success /// or error call. void convertHandlerToContinuationResumeImpl(const CallExpr *CE, HandlerResult Result) { assert(Scopes.back().isWrappedInContination()); std::vector Args; StringRef ResumeArgumentLabel; switch (TopHandler.getHandlerType()) { case HandlerType::PARAMS: { Args = Result.args(); if (!Result.isError()) { ResumeArgumentLabel = "returning"; } else { ResumeArgumentLabel = "throwing"; } break; } case HandlerType::RESULT: { Args = {CE->getArgs()->begin(), CE->getArgs()->end()}; ResumeArgumentLabel = "with"; break; } case HandlerType::INVALID: llvm_unreachable("Invalid top handler"); } // A vector in which each argument of Result has an entry. If the entry is // not empty then that argument has been unwrapped using 'guard let' into // a variable with that name. SmallVector ArgNames; ArgNames.reserve(Args.size()); /// When unwrapping a result argument \p Arg into a variable using /// 'guard let' return a suitable name for the unwrapped variable. /// \p ArgIndex is the index of \p Arg in the results passed to the /// completion handler. auto GetSuitableNameForGuardUnwrap = [&](Expr *Arg, unsigned ArgIndex) -> Identifier { // If Arg is a DeclRef, use its name for the guard unwrap. // guard let myVar1 = myVar. if (auto DRE = dyn_cast(Arg)) { return createUniqueName(DRE->getDecl()->getBaseIdentifier().str()); } else if (auto IIOE = dyn_cast(Arg)) { if (auto DRE = dyn_cast(IIOE->getSubExpr())) { return createUniqueName(DRE->getDecl()->getBaseIdentifier().str()); } } if (Args.size() == 1) { // We only have a single result. 'result' seems a resonable name. return createUniqueName("result"); } else { // We are returning a tuple. Name the result elements 'result' + // index in tuple. return createUniqueName("result" + std::to_string(ArgIndex)); } }; unsigned ArgIndex = 0; for (auto Arg : Args) { auto *ArgExpr = Arg.getExpr(); Identifier ArgName; if (isExpressionOptional(ArgExpr) && TopHandler.HasError) { ArgName = GetSuitableNameForGuardUnwrap(ArgExpr, ArgIndex); Scopes.back().Names.insert(ArgName); OS << tok::kw_guard << ' ' << tok::kw_let << ' ' << ArgName << ' ' << tok::equal << ' '; // If the argument is a call with a trailing closure, the generated // guard statement will not compile. // e.g. 'guard let result1 = value.map { $0 + 1 } else { ... }' // doesn't compile. Adding parentheses makes the code compile. auto HasTrailingClosure = false; if (auto *CE = dyn_cast(ArgExpr)) { if (CE->getArgs()->hasAnyTrailingClosures()) HasTrailingClosure = true; } if (HasTrailingClosure) OS << tok::l_paren; convertNode(ArgExpr, /*StartOverride=*/Arg.getLabelLoc(), /*ConvertCalls=*/false); if (HasTrailingClosure) OS << tok::r_paren; OS << ' ' << tok::kw_else << ' ' << tok::l_brace << '\n'; OS << "fatalError" << tok::l_paren; OS << "\"Expected non-nil result "; if (ArgName.str() != "result") { OS << "'" << ArgName << "' "; } OS << "in the non-error case\""; OS << tok::r_paren << '\n'; OS << tok::r_brace << '\n'; } ArgNames.push_back(ArgName); ArgIndex++; } Identifier ContName = Scopes.back().ContinuationName; OS << ContName << tok::period << "resume" << tok::l_paren << ResumeArgumentLabel << tok::colon << ' '; ArgIndex = 0; addTupleOf(Args, OS, [&](Argument Arg) { Identifier ArgName = ArgNames[ArgIndex]; if (!ArgName.empty()) { OS << ArgName; } else { // Can't just add the range as we need to perform replacements convertNode(Arg.getExpr(), /*StartOverride=*/Arg.getLabelLoc(), /*ConvertCalls=*/false); } ArgIndex++; }); OS << tok::r_paren; } /// From the given expression \p E, which is an argument to a function call, /// extract the passed closure if there is one. Otherwise return \c nullptr. ClosureExpr *extractCallback(Expr *E) { E = lookThroughFunctionConversionExpr(E); if (auto Closure = dyn_cast(E)) { return Closure; } else if (auto CaptureList = dyn_cast(E)) { return dyn_cast(CaptureList->getClosureBody()); } else { return nullptr; } } /// Callback arguments marked as e.g. `@convention(block)` produce arguments /// that are `FunctionConversionExpr`. /// We don't care about the conversions and want to shave them off. Expr *lookThroughFunctionConversionExpr(Expr *E) { if (auto FunctionConversion = dyn_cast(E)) { return lookThroughFunctionConversionExpr( FunctionConversion->getSubExpr()); } else { return E; } } void addHoistedCallback(const CallExpr *CE, const AsyncHandlerParamDesc &HandlerDesc) { llvm::SaveAndRestore RestoreHoisting(Hoisting, true); auto *ArgList = CE->getArgs(); if (HandlerDesc.Index >= ArgList->size()) { DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg); return; } Expr *CallbackArg = lookThroughFunctionConversionExpr(ArgList->getExpr(HandlerDesc.Index)); if (ClosureExpr *Callback = extractCallback(CallbackArg)) { // The user is using a closure for the completion handler addHoistedClosureCallback(CE, HandlerDesc, Callback); return; } if (auto CallbackDecl = getReferencedDecl(CallbackArg)) { if (CallbackDecl == TopHandler.getHandler()) { // We are refactoring the function that declared the completion handler // that would be called here. We can't call the completion handler // anymore because it will be removed. But since the function that // declared it is being refactored to async, we can just return the // values. if (!HandlerDesc.willAsyncReturnVoid()) { OS << tok::kw_return << " "; } InlinePatternsToPrint InlinePatterns; addAwaitCall(CE, ClassifiedBlock(), {}, InlinePatterns, HandlerDesc, /*AddDeclarations*/ false); return; } // We are not removing the completion handler, so we can call it once the // async function returns. // The completion handler that is called as part of the \p CE call. // This will be called once the async function returns. auto CompletionHandler = AsyncHandlerDesc::get(CallbackDecl, /*RequireAttributeOrName=*/false); if (CompletionHandler.isValid()) { if (auto CalledFunc = getUnderlyingFunc(CE->getFn())) { StringRef HandlerName = Lexer::getCharSourceRangeFromSourceRange( SM, CallbackArg->getSourceRange()).str(); addHoistedNamedCallback( CalledFunc, CompletionHandler, HandlerName, [&] { InlinePatternsToPrint InlinePatterns; addAwaitCall(CE, ClassifiedBlock(), {}, InlinePatterns, HandlerDesc, /*AddDeclarations*/ false); }); return; } } } DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg); } /// Add a binding to a known bool flag that indicates success or failure. void addBoolFlagParamBindingIfNeeded(Optional Flag, BlockKind Block) { if (!Flag) return; // Figure out the polarity of the binding based on the block we're in and // whether the flag indicates success. auto Polarity = true; switch (Block) { case BlockKind::SUCCESS: break; case BlockKind::ERROR: Polarity = !Polarity; break; case BlockKind::FALLBACK: llvm_unreachable("Not a valid place to bind"); } if (!Flag->IsSuccessFlag) Polarity = !Polarity; OS << newNameFor(Flag->Param) << " " << tok::equal << " "; OS << (Polarity ? tok::kw_true : tok::kw_false) << "\n"; } /// Add a call to the async alternative of \p CE and convert the \p Callback /// to be executed after the async call. \p HandlerDesc describes the /// completion handler in the function that's called by \p CE and \p ArgList /// are the arguments being passed in \p CE. void addHoistedClosureCallback(const CallExpr *CE, const AsyncHandlerParamDesc &HandlerDesc, const ClosureExpr *Callback) { if (HandlerDesc.params().size() != Callback->getParameters()->size()) { DiagEngine.diagnose(CE->getStartLoc(), diag::mismatched_callback_args); return; } ClosureCallbackParams CallbackParams(HandlerDesc, Callback); ClassifiedBlocks Blocks; auto *CallbackBody = Callback->getBody(); if (!HandlerDesc.HasError) { Blocks.SuccessBlock.addNodesInBraceStmt(CallbackBody); } else if (!CallbackBody->getElements().empty()) { CallbackClassifier::classifyInto(Blocks, CallbackParams, HandledSwitches, DiagEngine, CallbackBody); } auto SuccessBindings = CallbackParams.getParamsToBind(BlockKind::SUCCESS); auto *ErrParam = CallbackParams.getErrParam(); if (DiagEngine.hadAnyError()) { // For now, only fallback when the results are params with an error param, // in which case only the names are used (defaulted to the names of the // params if none). if (HandlerDesc.Type != HandlerType::PARAMS || !HandlerDesc.HasError) return; DiagEngine.resetHadAnyError(); // Note that we don't print any inline patterns here as we just want // assignments to the names in the outer scope. InlinePatternsToPrint InlinePatterns; auto AllBindings = CallbackParams.getParamsToBind(BlockKind::FALLBACK); prepareNames(ClassifiedBlock(), AllBindings, InlinePatterns); preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams, BlockKind::FALLBACK); addFallbackVars(AllBindings, CallbackParams); addDo(); addAwaitCall(CE, Blocks.SuccessBlock, SuccessBindings, InlinePatterns, HandlerDesc, /*AddDeclarations*/ false); OS << "\n"; // If we have a known Bool success param, we need to bind it. addBoolFlagParamBindingIfNeeded(CallbackParams.getKnownBoolFlagParam(), BlockKind::SUCCESS); addFallbackCatch(CallbackParams); OS << "\n"; convertNodes(NodesToPrint::inBraceStmt(CallbackBody)); clearNames(AllBindings); return; } auto *ErrOrResultParam = ErrParam; if (auto *ResultParam = CallbackParams.getResultParam()) ErrOrResultParam = ResultParam; auto ErrorNodes = Blocks.ErrorBlock.nodesToPrint().getNodes(); bool RequireDo = !ErrorNodes.empty(); // Check if we *actually* need a do/catch (see class comment) if (ErrorNodes.size() == 1) { auto Node = ErrorNodes[0]; if (auto *HandlerCall = TopHandler.getAsHandlerCall(Node)) { auto Res = TopHandler.extractResultArgs( HandlerCall, /*ReturnErrorArgsIfAmbiguous=*/true); if (Res.args().size() == 1) { // Skip if we have the param itself or the name it's bound to auto *ArgExpr = Res.args()[0].getExpr(); auto *SingleDecl = ArgExpr->getReferencedDecl().getDecl(); auto ErrName = Blocks.ErrorBlock.boundName(ErrOrResultParam); RequireDo = SingleDecl != ErrOrResultParam && !(Res.isError() && SingleDecl && SingleDecl->getName().isSimpleName(ErrName)); } } } // If we're not requiring a 'do', we'll be dropping the error block. But // let's make sure we at least preserve the comments in the error block by // transplanting them into the success block. This should make sure they // maintain a sensible ordering. if (!RequireDo) { auto ErrorNodes = Blocks.ErrorBlock.nodesToPrint(); for (auto CommentLoc : ErrorNodes.getPossibleCommentLocs()) Blocks.SuccessBlock.addPossibleCommentLoc(CommentLoc); } if (RequireDo) { addDo(); } auto InlinePatterns = getInlinePatternsToPrint(Blocks.SuccessBlock, SuccessBindings, Callback); prepareNames(Blocks.SuccessBlock, SuccessBindings, InlinePatterns); preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams, BlockKind::SUCCESS); addAwaitCall(CE, Blocks.SuccessBlock, SuccessBindings, InlinePatterns, HandlerDesc, /*AddDeclarations=*/true); printOutOfLineBindingPatterns(Blocks.SuccessBlock, InlinePatterns); convertNodes(Blocks.SuccessBlock.nodesToPrint()); clearNames(SuccessBindings); if (RequireDo) { // We don't use inline patterns for the error path. InlinePatternsToPrint ErrInlinePatterns; // Always use the ErrParam name if none is bound. prepareNames(Blocks.ErrorBlock, llvm::makeArrayRef(ErrOrResultParam), ErrInlinePatterns, /*AddIfMissing=*/HandlerDesc.Type != HandlerType::RESULT); preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams, BlockKind::ERROR); addCatch(ErrOrResultParam); convertNodes(Blocks.ErrorBlock.nodesToPrint()); OS << "\n" << tok::r_brace; clearNames(llvm::makeArrayRef(ErrOrResultParam)); } } /// Add a call to the async alternative of \p FD. Afterwards, pass the results /// of the async call to the completion handler, named \p HandlerName and /// described by \p HandlerDesc. /// \p AddAwaitCall adds the call to the refactored async method to the output /// stream without storing the result to any variables. /// This is used when the user didn't use a closure for the callback, but /// passed in a variable or function name for the completion handler. void addHoistedNamedCallback(const FuncDecl *FD, const AsyncHandlerDesc &HandlerDesc, StringRef HandlerName, std::function AddAwaitCall) { if (HandlerDesc.HasError) { // "result" and "error" always okay to use here since they're added // in their own scope, which only contains new code. addDo(); if (!HandlerDesc.willAsyncReturnVoid()) { OS << tok::kw_let << " result"; addResultTypeAnnotationIfNecessary(FD, HandlerDesc); OS << " " << tok::equal << " "; } AddAwaitCall(); OS << "\n"; addCallToCompletionHandler("result", HandlerDesc, HandlerName); OS << "\n"; OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n"; addCallToCompletionHandler(StringRef(), HandlerDesc, HandlerName); OS << "\n" << tok::r_brace; // end catch } else { // This code may be placed into an existing scope, in that case create // a unique "result" name so that it doesn't cause shadowing or redecls. StringRef ResultName; if (!HandlerDesc.willAsyncReturnVoid()) { Identifier Unique = createUniqueName("result"); Scopes.back().Names.insert(Unique); ResultName = Unique.str(); OS << tok::kw_let << " " << ResultName; addResultTypeAnnotationIfNecessary(FD, HandlerDesc); OS << " " << tok::equal << " "; } else { // The name won't end up being used, just give it a bogus one so that // the result path is taken (versus the error path). ResultName = "result"; } AddAwaitCall(); OS << "\n"; addCallToCompletionHandler(ResultName, HandlerDesc, HandlerName); } } /// Checks whether a binding pattern for a given decl can be printed inline in /// an await call, e.g 'let ((x, y), z) = await foo()', where '(x, y)' is the /// inline pattern. const Pattern * bindingPatternToPrintInline(const Decl *D, const ClassifiedBlock &Block, const ClosureExpr *CallbackClosure) { // Only currently done for callback closures. if (!CallbackClosure) return nullptr; // If we can reduce the pattern bindings down to a single pattern, we may // be able to print it inline. auto *P = Block.getSinglePatternFor(D); if (!P) return nullptr; // Patterns that bind a single var are always printed inline. if (P->getSingleVar()) return P; // If we have a multi-var binding, and the decl being bound is referenced // elsewhere in the block, we cannot print the pattern immediately in the // await call. Instead, we'll print it out of line. auto *Decls = ScopedDecls.getReferencedDecls(CallbackClosure->getBody()); assert(Decls); auto NumRefs = Decls->lookup(D); return NumRefs == 1 ? P : nullptr; } /// Retrieve a map of patterns to print inline for an array of param decls. InlinePatternsToPrint getInlinePatternsToPrint(const ClassifiedBlock &Block, ArrayRef Params, const ClosureExpr *CallbackClosure) { InlinePatternsToPrint Patterns; for (auto *Param : Params) { if (auto *P = bindingPatternToPrintInline(Param, Block, CallbackClosure)) Patterns[Param] = P; } return Patterns; } /// Print any out of line binding patterns that could not be printed as inline /// patterns. These typically appear directly after an await call, e.g: /// \code /// let x = await foo() /// let (y, z) = x /// \endcode void printOutOfLineBindingPatterns(const ClassifiedBlock &Block, const InlinePatternsToPrint &InlinePatterns) { for (auto &Entry : Block.paramPatternBindings()) { auto *D = Entry.first; auto Aliases = Block.getAliasesFor(D); for (auto *P : Entry.second) { // If we already printed this as an inline pattern, there's nothing else // to do. if (InlinePatterns.lookup(D) == P) continue; // If this is an alias binding, it can be elided. if (auto *SingleVar = P->getSingleVar()) { if (Aliases.contains(SingleVar)) continue; } auto HasMutable = P->hasAnyMutableBindings(); OS << "\n" << (HasMutable ? tok::kw_var : tok::kw_let) << " "; convertPattern(P); OS << " = "; OS << newNameFor(D); } } } /// Prints an \c await call to an \c async function, binding any return values /// into variables. /// /// \param CE The call expr to convert. /// \param SuccessBlock The nodes present in the success block following the /// call. /// \param SuccessParams The success parameters, which will be printed as /// return values. /// \param InlinePatterns A map of patterns that can be printed inline for /// a given param. /// \param HandlerDesc A description of the completion handler. /// \param AddDeclarations Whether or not to add \c let or \c var keywords to /// the return value bindings. void addAwaitCall(const CallExpr *CE, const ClassifiedBlock &SuccessBlock, ArrayRef SuccessParams, const InlinePatternsToPrint &InlinePatterns, const AsyncHandlerParamDesc &HandlerDesc, bool AddDeclarations) { auto *Args = CE->getArgs(); // Print the bindings to match the completion handler success parameters, // making sure to omit in the case of a Void return. if (!SuccessParams.empty() && !HandlerDesc.willAsyncReturnVoid()) { auto AllLet = true; // Gather the items to print for the variable bindings. This can either be // a param decl, or a pattern that binds it. using DeclOrPattern = llvm::PointerUnion; SmallVector ToPrint; for (auto *Param : SuccessParams) { // Check if we have an inline pattern to print. if (auto *P = InlinePatterns.lookup(Param)) { if (P->hasAnyMutableBindings()) AllLet = false; ToPrint.push_back(P); continue; } ToPrint.push_back(Param); } if (AddDeclarations) { if (AllLet) { OS << tok::kw_let; } else { OS << tok::kw_var; } OS << " "; } // 'res =' or '(res1, res2, ...) =' addTupleOf(ToPrint, OS, [&](DeclOrPattern Elt) { if (auto *P = Elt.dyn_cast()) { convertPattern(P); return; } OS << newNameFor(Elt.get()); }); OS << " " << tok::equal << " "; } if (HandlerDesc.HasError) { OS << tok::kw_try << " "; } OS << "await "; // Try to replace the name with that of the alternative. Use the existing // name if for some reason that's not possible. bool NameAdded = false; if (HandlerDesc.Alternative) { const ValueDecl *Named = HandlerDesc.Alternative; if (auto *Accessor = dyn_cast(HandlerDesc.Alternative)) Named = Accessor->getStorage(); if (!Named->getBaseName().isSpecial()) { Names.try_emplace(HandlerDesc.Func, Named->getBaseName().getIdentifier()); convertNode(CE->getFn(), /*StartOverride=*/{}, /*ConvertCalls=*/false, /*IncludeComments=*/false); NameAdded = true; } } if (!NameAdded) { addRange(CE->getStartLoc(), CE->getFn()->getEndLoc(), /*ToEndOfToken=*/true); } if (!HandlerDesc.alternativeIsAccessor()) OS << tok::l_paren; size_t ConvertedArgIndex = 0; ArrayRef AlternativeParams; if (HandlerDesc.Alternative) AlternativeParams = HandlerDesc.Alternative->getParameters()->getArray(); for (auto I : indices(*Args)) { auto Arg = Args->get(I); auto *ArgExpr = Arg.getExpr(); if (I == HandlerDesc.Index || isa(ArgExpr)) continue; if (ConvertedArgIndex > 0) OS << tok::comma << " "; if (HandlerDesc.Alternative) { // Skip argument if it's defaulted and has a different name while (ConvertedArgIndex < AlternativeParams.size() && AlternativeParams[ConvertedArgIndex]->isDefaultArgument() && AlternativeParams[ConvertedArgIndex]->getArgumentName() != Arg.getLabel()) { ConvertedArgIndex++; } if (ConvertedArgIndex < AlternativeParams.size()) { // Could have a different argument label (or none), so add it instead auto Name = AlternativeParams[ConvertedArgIndex]->getArgumentName(); if (!Name.empty()) OS << Name << ": "; convertNode(ArgExpr, /*StartOverride=*/{}, /*ConvertCalls=*/false); ConvertedArgIndex++; continue; } // Fallthrough if arguments don't match up for some reason } // Can't just add the range as we need to perform replacements. Also // make sure to include the argument label (if any) convertNode(ArgExpr, /*StartOverride=*/Arg.getLabelLoc(), /*ConvertCalls=*/false); ConvertedArgIndex++; } if (!HandlerDesc.alternativeIsAccessor()) OS << tok::r_paren; } void addFallbackCatch(const ClosureCallbackParams &Params) { auto *ErrParam = Params.getErrParam(); assert(ErrParam); auto ErrName = newNameFor(ErrParam); OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n" << ErrName << " = error\n"; // If we have a known Bool success param, we need to bind it. addBoolFlagParamBindingIfNeeded(Params.getKnownBoolFlagParam(), BlockKind::ERROR); OS << tok::r_brace; } void addCatch(const ParamDecl *ErrParam) { OS << "\n" << tok::r_brace << " " << tok::kw_catch << " "; auto ErrName = newNameFor(ErrParam, false); if (!ErrName.empty() && ErrName != "_") { OS << tok::kw_let << " " << ErrName << " "; } OS << tok::l_brace; } void preparePlaceholdersAndUnwraps(AsyncHandlerDesc HandlerDesc, const ClosureCallbackParams &Params, BlockKind Block) { // Params that have been dropped always need placeholdering. for (auto *Param : Params.getAllParams()) { if (!Params.hasBinding(Param, Block)) Placeholders.insert(Param); } // For the fallback case, no other params need placeholdering, as they are // all freely accessible in the fallback case. if (Block == BlockKind::FALLBACK) return; switch (HandlerDesc.Type) { case HandlerType::PARAMS: { auto *ErrParam = Params.getErrParam(); auto SuccessParams = Params.getSuccessParams(); switch (Block) { case BlockKind::FALLBACK: llvm_unreachable("Already handled"); case BlockKind::ERROR: if (ErrParam) { if (HandlerDesc.shouldUnwrap(ErrParam->getType())) { Placeholders.insert(ErrParam); Unwraps.insert(ErrParam); } // Can't use success params in the error body Placeholders.insert(SuccessParams.begin(), SuccessParams.end()); } break; case BlockKind::SUCCESS: for (auto *SuccessParam : SuccessParams) { auto Ty = SuccessParam->getType(); if (HandlerDesc.shouldUnwrap(Ty)) { // Either unwrap or replace with a placeholder if there's some other // reference Unwraps.insert(SuccessParam); Placeholders.insert(SuccessParam); } // Void parameters get omitted where possible, so turn any reference // into a placeholder, as its usage is unlikely what the user wants. if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid()) Placeholders.insert(SuccessParam); } // Can't use the error param in the success body if (ErrParam) Placeholders.insert(ErrParam); break; } break; } case HandlerType::RESULT: { // Any uses of the result parameter in the current body (that aren't // replaced) are invalid, so replace them with a placeholder. auto *ResultParam = Params.getResultParam(); assert(ResultParam); Placeholders.insert(ResultParam); break; } default: llvm_unreachable("Unhandled handler type"); } } /// Add a mapping from each passed parameter to a new name, possibly /// synthesizing a new one if hoisting it would cause a redeclaration or /// shadowing. If there's no bound name and \c AddIfMissing is false, no /// name will be added. void prepareNames(const ClassifiedBlock &Block, ArrayRef Params, const InlinePatternsToPrint &InlinePatterns, bool AddIfMissing = true) { for (auto *PD : Params) { // If this param is to be replaced by a pattern that binds multiple // separate vars, it's not actually going to be added to the scope, and // therefore doesn't need naming. This avoids needing to rename a var with // the same name later on in the scope, as it's not actually clashing. if (auto *P = InlinePatterns.lookup(PD)) { if (!P->getSingleVar()) continue; } auto Name = Block.boundName(PD); if (Name.empty() && !AddIfMissing) continue; auto Ident = assignUniqueName(PD, Name); // Also propagate the name to any aliases. for (auto *Alias : Block.getAliasesFor(PD)) Names[Alias] = Ident; } } /// Returns a unique name using \c Name as base that doesn't clash with any /// other names in the current scope. Identifier createUniqueName(StringRef Name) { Identifier Ident = getASTContext().getIdentifier(Name); if (Name == "_") return Ident; auto &CurrentNames = Scopes.back().Names; if (CurrentNames.count(Ident)) { // Add a number to the end of the name until it's unique given the current // names in scope. llvm::SmallString<32> UniquedName; unsigned UniqueId = 1; do { UniquedName = Name; UniquedName.append(std::to_string(UniqueId)); Ident = getASTContext().getIdentifier(UniquedName); UniqueId++; } while (CurrentNames.count(Ident)); } return Ident; } /// Create a unique name for the variable declared by \p D that doesn't /// clash with any other names in scope, using \p BoundName as the base name /// if not empty and the name of \p D otherwise. Adds this name to both /// \c Names and the current scope's names (\c Scopes.Names). Identifier assignUniqueName(const Decl *D, StringRef BoundName) { Identifier Ident; if (BoundName.empty()) { BoundName = getDeclName(D).userFacingName(); if (BoundName.empty()) return Ident; } if (BoundName.startswith("$")) { llvm::SmallString<8> NewName; NewName.append("val"); NewName.append(BoundName.drop_front()); Ident = createUniqueName(NewName); } else { Ident = createUniqueName(BoundName); } Names.try_emplace(D, Ident); Scopes.back().Names.insert(Ident); return Ident; } StringRef newNameFor(const Decl *D, bool Required = true) { auto Res = Names.find(D); if (Res == Names.end()) { assert(!Required && "Missing name for decl when one was required"); return StringRef(); } return Res->second.str(); } void addNewScope(const llvm::DenseSet &Decls) { if (Scopes.empty()) { Scopes.emplace_back(/*ContinuationName=*/Identifier()); } else { // If the parent scope is nested in a continuation, the new one is also. // Carry over the continuation name. Identifier PreviousContinuationName = Scopes.back().ContinuationName; Scopes.emplace_back(PreviousContinuationName); } for (auto D : Decls) { auto Name = getDeclName(D); if (!Name.empty()) Scopes.back().Names.insert(Name); } } void clearNames(ArrayRef Params) { for (auto *Param : Params) { Unwraps.erase(Param); Placeholders.erase(Param); Names.erase(Param); } } /// Adds a forwarding call to the old completion handler function, with /// \p HandlerReplacement that allows for a custom replacement or, if empty, /// removal of the completion handler closure. void addForwardingCallTo(const FuncDecl *FD, StringRef HandlerReplacement) { OS << FD->getBaseName() << tok::l_paren; auto *Params = FD->getParameters(); size_t ConvertedArgsIndex = 0; for (size_t I = 0, E = Params->size(); I < E; ++I) { if (I == TopHandler.Index) { /// If we're not replacing the handler with anything, drop it. if (HandlerReplacement.empty()) continue; // Use a trailing closure if the handler is the last param if (I == E - 1) { OS << tok::r_paren << " "; OS << HandlerReplacement; return; } // Otherwise fall through to do the replacement. } if (ConvertedArgsIndex > 0) OS << tok::comma << " "; const auto *Param = Params->get(I); if (!Param->getArgumentName().empty()) OS << Param->getArgumentName() << tok::colon << " "; if (I == TopHandler.Index) { OS << HandlerReplacement; } else { OS << Param->getParameterName(); } ConvertedArgsIndex++; } OS << tok::r_paren; } /// Adds a forwarded error argument to a completion handler call. If the error /// type of \p HandlerDesc is more specialized than \c Error, an /// 'as! CustomError' cast to the more specialized error type will be added to /// the output stream. void addForwardedErrorArgument(StringRef ErrorName, const AsyncHandlerDesc &HandlerDesc) { // If the error type is already Error, we can pass it as-is. auto ErrorType = *HandlerDesc.getErrorType(); if (ErrorType->getCanonicalType() == getASTContext().getErrorExistentialType()) { OS << ErrorName; return; } // Otherwise we need to add a force cast to the destination custom error // type. If this is for an Error? parameter, we'll need to add parens around // the cast to silence a compiler warning about force casting never // producing nil. auto RequiresParens = HandlerDesc.getErrorParam().has_value(); if (RequiresParens) OS << tok::l_paren; OS << ErrorName << " " << tok::kw_as << tok::exclaim_postfix << " "; ErrorType->lookThroughSingleOptionalType()->print(OS); if (RequiresParens) OS << tok::r_paren; } /// If \p T has a natural default value like \c nil for \c Optional or \c () /// for \c Void, add that default value to the output. Otherwise, add a /// placeholder that contains \p T's name as the hint. void addDefaultValueOrPlaceholder(Type T) { if (T->isOptional()) { OS << tok::kw_nil; } else if (T->isVoid()) { OS << "()"; } else { OS << "<#"; T.print(OS); OS << "#>"; } } /// Adds the \c Index -th parameter to the completion handler described by \p /// HanderDesc. /// If \p ResultName is not empty, it is assumed that a variable with that /// name contains the result returned from the async alternative. If the /// callback also takes an error parameter, \c nil passed to the completion /// handler for the error. If \p ResultName is empty, it is a assumed that a /// variable named 'error' contains the error thrown from the async method and /// 'nil' will be passed to the completion handler for all result parameters. void addCompletionHandlerArgument(size_t Index, StringRef ResultName, const AsyncHandlerDesc &HandlerDesc) { if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) { // The error parameter is the last argument of the completion handler. if (ResultName.empty()) { addForwardedErrorArgument("error", HandlerDesc); } else { addDefaultValueOrPlaceholder(HandlerDesc.params()[Index].getPlainType()); } } else { if (ResultName.empty()) { addDefaultValueOrPlaceholder(HandlerDesc.params()[Index].getPlainType()); } else if (HandlerDesc .getSuccessParamAsyncReturnType( HandlerDesc.params()[Index].getPlainType()) ->isVoid()) { // Void return types are not returned by the async function, synthesize // a Void instance. OS << tok::l_paren << tok::r_paren; } else if (HandlerDesc.getSuccessParams().size() > 1) { // If the async method returns a tuple, we need to pass its elements to // the completion handler separately. For example: // // func foo() async -> (String, Int) {} // // causes the following legacy body to be created: // // func foo(completion: (String, Int) -> Void) { // Task { // let result = await foo() // completion(result.0, result.1) // } // } OS << ResultName << tok::period; auto Label = HandlerDesc.getAsyncReturnTypeLabel(Index); if (!Label.empty()) { OS << Label; } else { OS << Index; } } else { OS << ResultName; } } } /// Add a call to the completion handler named \p HandlerName and described by /// \p HandlerDesc, passing all the required arguments. See \c /// getCompletionHandlerArgument for how the arguments are synthesized. void addCallToCompletionHandler(StringRef ResultName, const AsyncHandlerDesc &HandlerDesc, StringRef HandlerName) { OS << HandlerName << tok::l_paren; // Construct arguments to pass to the completion handler switch (HandlerDesc.Type) { case HandlerType::INVALID: llvm_unreachable("Cannot be rewritten"); break; case HandlerType::PARAMS: { for (size_t I = 0; I < HandlerDesc.params().size(); ++I) { if (I > 0) { OS << tok::comma << " "; } addCompletionHandlerArgument(I, ResultName, HandlerDesc); } break; } case HandlerType::RESULT: { if (!ResultName.empty()) { OS << tok::period_prefix << "success" << tok::l_paren; if (!HandlerDesc.willAsyncReturnVoid()) { OS << ResultName; } else { OS << tok::l_paren << tok::r_paren; } OS << tok::r_paren; } else { OS << tok::period_prefix << "failure" << tok::l_paren; addForwardedErrorArgument("error", HandlerDesc); OS << tok::r_paren; } break; } } OS << tok::r_paren; // Close the call to the completion handler } /// Adds the result type of a refactored async function that previously /// returned results via a completion handler described by \p HandlerDesc. void addAsyncFuncReturnType(const AsyncHandlerDesc &HandlerDesc) { // Type or (Type1, Type2, ...) SmallVector Scratch; auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch); if (ReturnTypes.empty()) { OS << "Void"; } else { addTupleOf(ReturnTypes, OS, [&](LabeledReturnType LabelAndType) { if (!LabelAndType.Label.empty()) { OS << LabelAndType.Label << tok::colon << " "; } LabelAndType.Ty->print(OS); }); } } /// If \p FD is generic, adds a type annotation with the return type of the /// converted async function. This is used when creating a legacy function, /// calling the converted 'async' function so that the generic parameters of /// the legacy function are passed to the generic function. For example for /// \code /// func foo() async -> GenericParam {} /// \endcode /// we generate /// \code /// func foo(completion: (GenericParam) -> Void) { /// Task { /// let result: GenericParam = await foo() /// <------------> /// completion(result) /// } /// } /// \endcode /// This function adds the range marked by \c <-----> void addResultTypeAnnotationIfNecessary(const FuncDecl *FD, const AsyncHandlerDesc &HandlerDesc) { if (FD->isGeneric()) { OS << tok::colon << " "; addAsyncFuncReturnType(HandlerDesc); } } }; } // namespace asyncrefactorings bool RefactoringActionConvertCallToAsyncAlternative::isApplicable( const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) { using namespace asyncrefactorings; // Currently doesn't check that the call is in an async context. This seems // possibly useful in some situations, so we'll see what the feedback is. // May need to change in the future auto *CE = findOuterCall(CursorInfo); if (!CE) return false; auto HandlerDesc = AsyncHandlerParamDesc::find( getUnderlyingFunc(CE->getFn()), /*RequireAttributeOrName=*/false); return HandlerDesc.isValid(); } /// Converts a call of a function with a possible async alternative, to use it /// instead. Currently this is any function that /// 1. has a void return type, /// 2. has a void returning closure as its last parameter, and /// 3. is not already async /// /// For now the call need not be in an async context, though this may change /// depending on feedback. bool RefactoringActionConvertCallToAsyncAlternative::performChange() { using namespace asyncrefactorings; auto *CE = findOuterCall(CursorInfo); assert(CE && "Should not run performChange when refactoring is not applicable"); // Find the scope this call is in ContextFinder Finder( *CursorInfo.getSourceFile(), CursorInfo.getLoc(), [](ASTNode N) { return N.isStmt(StmtKind::Brace) && !N.isImplicit(); }); Finder.resolve(); auto Scopes = Finder.getContexts(); BraceStmt *Scope = nullptr; if (!Scopes.empty()) Scope = cast(Scopes.back().get()); AsyncConverter Converter(TheFile, SM, DiagEngine, CE, Scope); if (!Converter.convert()) return true; Converter.replace(CE, EditConsumer); return false; } bool RefactoringActionConvertToAsync::isApplicable( const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) { using namespace asyncrefactorings; // As with the call refactoring, should possibly only apply if there's // actually calls to async alternatives. At the moment this will just add // `async` if there are no calls, which is probably fine. return findFunction(CursorInfo); } /// Converts a whole function to async, converting any calls to functions with /// async alternatives as above. bool RefactoringActionConvertToAsync::performChange() { using namespace asyncrefactorings; auto *FD = findFunction(CursorInfo); assert(FD && "Should not run performChange when refactoring is not applicable"); auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false); AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc); if (!Converter.convert()) return true; Converter.replace(FD, EditConsumer, FD->getSourceRangeIncludingAttrs().Start); return false; } bool RefactoringActionAddAsyncAlternative::isApplicable( const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) { using namespace asyncrefactorings; auto *FD = findFunction(CursorInfo); if (!FD) return false; auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false); return HandlerDesc.isValid(); } /// Adds an async alternative and marks the current function as deprecated. /// Equivalent to the conversion but /// 1. only works on functions that themselves are a possible async /// alternative, and /// 2. has extra handling to convert the completion/handler/callback closure /// parameter to either `return`/`throws` bool RefactoringActionAddAsyncAlternative::performChange() { using namespace asyncrefactorings; auto *FD = findFunction(CursorInfo); assert(FD && "Should not run performChange when refactoring is not applicable"); auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false); assert(HandlerDesc.isValid() && "Should not run performChange when refactoring is not applicable"); AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc); if (!Converter.convert()) return true; // Add a reference to the async function so that warnings appear when the // synchronous function is used in an async context SmallString<128> AvailabilityAttr = HandlerDesc.buildRenamedAttribute(); EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false), AvailabilityAttr); AsyncConverter LegacyBodyCreator(TheFile, SM, DiagEngine, FD, HandlerDesc); if (LegacyBodyCreator.createLegacyBody()) { LegacyBodyCreator.replace(FD->getBody(), EditConsumer); } // Add the async alternative Converter.insertAfter(FD, EditConsumer); return false; } bool RefactoringActionAddAsyncWrapper::isApplicable( const ResolvedCursorInfo &CursorInfo, DiagnosticEngine &Diag) { using namespace asyncrefactorings; auto *FD = findFunction(CursorInfo); if (!FD) return false; auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false); return HandlerDesc.isValid(); } bool RefactoringActionAddAsyncWrapper::performChange() { using namespace asyncrefactorings; auto *FD = findFunction(CursorInfo); assert(FD && "Should not run performChange when refactoring is not applicable"); auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*RequireAttributeOrName=*/false); assert(HandlerDesc.isValid() && "Should not run performChange when refactoring is not applicable"); AsyncConverter Converter(TheFile, SM, DiagEngine, FD, HandlerDesc); if (!Converter.createAsyncWrapper()) return true; // Add a reference to the async function so that warnings appear when the // synchronous function is used in an async context SmallString<128> AvailabilityAttr = HandlerDesc.buildRenamedAttribute(); EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false), AvailabilityAttr); // Add the async wrapper. Converter.insertAfter(FD, EditConsumer); return false; } static MacroExpansionExpr * findMacroExpansionTargetExpr(const ResolvedCursorInfo &Info) { // Handle '#' position in '#macroName(...)'. if (auto exprInfo = dyn_cast(&Info)) { if (auto target = dyn_cast_or_null(exprInfo->getTrailingExpr())) if (target->getRewritten()) return target; return nullptr; } // Handle 'macroName' position in '#macroName(...)'. if (auto refInfo = dyn_cast(&Info)) { if (refInfo->isRef() && isa_and_nonnull(refInfo->getValueD())) { ContextFinder Finder( *Info.getSourceFile(), Info.getLoc(), [&](ASTNode N) { auto *expr = dyn_cast_or_null(N.dyn_cast()); return expr && (expr->getMacroNameLoc().getBaseNameLoc() == Info.getLoc()); }); Finder.resolve(); if (!Finder.getContexts().empty()) { auto *target = dyn_cast(Finder.getContexts()[0].get()); if (target->getRewritten()) return target; } } return nullptr; } // TODO: handle MacroExpansionDecl. return nullptr; } bool RefactoringActionExpandMacro::isApplicable(const ResolvedCursorInfo &Info, DiagnosticEngine &Diag) { return findMacroExpansionTargetExpr(Info) != nullptr; } bool RefactoringActionExpandMacro::performChange() { auto target = findMacroExpansionTargetExpr(CursorInfo); if (!target) return true; auto exprRange = Lexer::getCharSourceRangeFromSourceRange(SM, target->getSourceRange()); auto rewrittenRange = Lexer::getCharSourceRangeFromSourceRange( SM, target->getRewritten()->getSourceRange()); auto rewrittenBuffer = SM.extractText(rewrittenRange); EditConsumer.accept(SM, exprRange, rewrittenBuffer); return false; } } // end of anonymous namespace StringRef swift::ide:: getDescriptiveRefactoringKindName(RefactoringKind Kind) { switch(Kind) { case RefactoringKind::None: llvm_unreachable("Should be a valid refactoring kind"); #define REFACTORING(KIND, NAME, ID) case RefactoringKind::KIND: return NAME; #include "swift/Refactoring/RefactoringKinds.def" } llvm_unreachable("unhandled kind"); } StringRef swift::ide:: getDescriptiveRenameUnavailableReason(RenameAvailableKind Kind) { switch(Kind) { case RenameAvailableKind::Available: return ""; case RenameAvailableKind::Unavailable_system_symbol: return "symbol from system module cannot be renamed"; case RenameAvailableKind::Unavailable_has_no_location: return "symbol without a declaration location cannot be renamed"; case RenameAvailableKind::Unavailable_has_no_name: return "cannot find the name of the symbol"; case RenameAvailableKind::Unavailable_has_no_accessibility: return "cannot decide the accessibility of the symbol"; case RenameAvailableKind::Unavailable_decl_from_clang: return "cannot rename a Clang symbol from its Swift reference"; } llvm_unreachable("unhandled kind"); } SourceLoc swift::ide::RangeConfig::getStart(SourceManager &SM) { return SM.getLocForLineCol(BufferId, Line, Column); } SourceLoc swift::ide::RangeConfig::getEnd(SourceManager &SM) { return getStart(SM).getAdvancedLoc(Length); } struct swift::ide::FindRenameRangesAnnotatingConsumer::Implementation { std::unique_ptr pRewriter; Implementation(SourceManager &SM, unsigned BufferId, raw_ostream &OS) : pRewriter(new SourceEditOutputConsumer(SM, BufferId, OS)) {} static StringRef tag(RefactoringRangeKind Kind) { switch (Kind) { case RefactoringRangeKind::BaseName: return "base"; case RefactoringRangeKind::KeywordBaseName: return "keywordBase"; case RefactoringRangeKind::ParameterName: return "param"; case RefactoringRangeKind::NoncollapsibleParameterName: return "noncollapsibleparam"; case RefactoringRangeKind::DeclArgumentLabel: return "arglabel"; case RefactoringRangeKind::CallArgumentLabel: return "callarg"; case RefactoringRangeKind::CallArgumentColon: return "callcolon"; case RefactoringRangeKind::CallArgumentCombined: return "callcombo"; case RefactoringRangeKind::SelectorArgumentLabel: return "sel"; } llvm_unreachable("unhandled kind"); } void accept(SourceManager &SM, const RenameRangeDetail &Range) { std::string NewText; llvm::raw_string_ostream OS(NewText); StringRef Tag = tag(Range.RangeKind); OS << "<" << Tag; if (Range.Index.has_value()) OS << " index=" << *Range.Index; OS << ">" << Range.Range.str() << ""; pRewriter->accept(SM, {Range.Range, OS.str(), {}}); } }; swift::ide::FindRenameRangesAnnotatingConsumer:: FindRenameRangesAnnotatingConsumer(SourceManager &SM, unsigned BufferId, raw_ostream &OS) : Impl(*new Implementation(SM, BufferId, OS)) {} swift::ide::FindRenameRangesAnnotatingConsumer::~FindRenameRangesAnnotatingConsumer() { delete &Impl; } void swift::ide::FindRenameRangesAnnotatingConsumer:: accept(SourceManager &SM, RegionType RegionType, ArrayRef Ranges) { if (RegionType == RegionType::Mismatch || RegionType == RegionType::Unmatched) return; for (const auto &Range : Ranges) { Impl.accept(SM, Range); } } void swift::ide::collectRenameAvailabilityInfo( const ValueDecl *VD, Optional RefInfo, SmallVectorImpl &Infos) { RenameAvailableKind AvailKind = RenameAvailableKind::Available; if (getRelatedSystemDecl(VD)){ AvailKind = RenameAvailableKind::Unavailable_system_symbol; } else if (VD->getClangDecl()) { AvailKind = RenameAvailableKind::Unavailable_decl_from_clang; } else if (VD->getStartLoc().isInvalid()) { AvailKind = RenameAvailableKind::Unavailable_has_no_location; } else if (!VD->hasName()) { AvailKind = RenameAvailableKind::Unavailable_has_no_name; } if (isa(VD)) { // Disallow renaming accessors. if (isa(VD)) return; // Disallow renaming deinit. if (isa(VD)) return; // Disallow renaming init with no arguments. if (auto CD = dyn_cast(VD)) { if (!CD->getParameters()->size()) return; if (RefInfo && !RefInfo->IsArgLabel) { NameMatcher Matcher(*(RefInfo->SF)); auto Resolved = Matcher.resolve({RefInfo->Loc, /*ResolveArgs*/true}); if (Resolved.LabelRanges.empty()) return; } } // Disallow renaming 'callAsFunction' method with no arguments. if (auto FD = dyn_cast(VD)) { // FIXME: syntactic rename can only decide by checking the spelling, not // whether it's an instance method, so we do the same here for now. if (FD->getBaseIdentifier() == FD->getASTContext().Id_callAsFunction) { if (!FD->getParameters()->size()) return; if (RefInfo && !RefInfo->IsArgLabel) { NameMatcher Matcher(*(RefInfo->SF)); auto Resolved = Matcher.resolve({RefInfo->Loc, /*ResolveArgs*/true}); if (Resolved.LabelRanges.empty()) return; } } } } // Always return local rename for parameters. // FIXME: if the cursor is on the argument, we should return global rename. if (isa(VD)) { Infos.emplace_back(RefactoringKind::LocalRename, AvailKind); return; } // If the indexer considers VD a global symbol, then we apply global rename. if (index::isLocalSymbol(VD)) Infos.emplace_back(RefactoringKind::LocalRename, AvailKind); else Infos.emplace_back(RefactoringKind::GlobalRename, AvailKind); } void swift::ide::collectAvailableRefactorings( const ResolvedCursorInfo &CursorInfo, SmallVectorImpl &Kinds, bool ExcludeRename) { DiagnosticEngine DiagEngine( CursorInfo.getSourceFile()->getASTContext().SourceMgr); if (!ExcludeRename) { if (RefactoringActionLocalRename::isApplicable(CursorInfo, DiagEngine)) Kinds.push_back(RefactoringKind::LocalRename); switch (CursorInfo.getKind()) { case CursorInfoKind::ModuleRef: case CursorInfoKind::Invalid: case CursorInfoKind::StmtStart: case CursorInfoKind::ExprStart: break; case CursorInfoKind::ValueRef: { auto ValueRefInfo = cast(CursorInfo); Optional RefInfo; if (ValueRefInfo.isRef()) RefInfo = {CursorInfo.getSourceFile(), CursorInfo.getLoc(), ValueRefInfo.isKeywordArgument()}; auto RenameOp = getAvailableRenameForDecl(ValueRefInfo.getValueD(), RefInfo); if (RenameOp.has_value() && RenameOp.value() == RefactoringKind::GlobalRename) Kinds.push_back(RenameOp.value()); } } } #define CURSOR_REFACTORING(KIND, NAME, ID) \ if (RefactoringKind::KIND != RefactoringKind::LocalRename && \ RefactoringAction##KIND::isApplicable(CursorInfo, DiagEngine)) \ Kinds.push_back(RefactoringKind::KIND); #include "swift/Refactoring/RefactoringKinds.def" } void swift::ide::collectAvailableRefactorings( SourceFile *SF, RangeConfig Range, bool &CollectRangeStartRefactorings, SmallVectorImpl &Kinds, ArrayRef DiagConsumers) { if (Range.Length == 0) { return collectAvailableRefactoringsAtCursor(SF, Range.Line, Range.Column, Kinds, DiagConsumers); } // Prepare the tool box. ASTContext &Ctx = SF->getASTContext(); SourceManager &SM = Ctx.SourceMgr; DiagnosticEngine DiagEngine(SM); std::for_each(DiagConsumers.begin(), DiagConsumers.end(), [&](DiagnosticConsumer *Con) { DiagEngine.addConsumer(*Con); }); ResolvedRangeInfo Result = evaluateOrDefault(SF->getASTContext().evaluator, RangeInfoRequest(RangeInfoOwner({SF, Range.getStart(SF->getASTContext().SourceMgr), Range.getEnd(SF->getASTContext().SourceMgr)})), ResolvedRangeInfo()); bool enableInternalRefactoring = getenv("SWIFT_ENABLE_INTERNAL_REFACTORING_ACTIONS"); #define RANGE_REFACTORING(KIND, NAME, ID) \ if (RefactoringAction##KIND::isApplicable(Result, DiagEngine)) \ Kinds.push_back(RefactoringKind::KIND); #define INTERNAL_RANGE_REFACTORING(KIND, NAME, ID) \ if (enableInternalRefactoring) \ RANGE_REFACTORING(KIND, NAME, ID) #include "swift/Refactoring/RefactoringKinds.def" CollectRangeStartRefactorings = collectRangeStartRefactorings(Result); } bool swift::ide:: refactorSwiftModule(ModuleDecl *M, RefactoringOptions Opts, SourceEditConsumer &EditConsumer, DiagnosticConsumer &DiagConsumer) { assert(Opts.Kind != RefactoringKind::None && "should have a refactoring kind."); // Use the default name if not specified. if (Opts.PreferredName.empty()) { Opts.PreferredName = getDefaultPreferredName(Opts.Kind).str(); } switch (Opts.Kind) { #define SEMANTIC_REFACTORING(KIND, NAME, ID) \ case RefactoringKind::KIND: { \ RefactoringAction##KIND Action(M, Opts, EditConsumer, DiagConsumer); \ if (RefactoringKind::KIND == RefactoringKind::LocalRename || \ Action.isApplicable()) \ return Action.performChange(); \ return true; \ } #include "swift/Refactoring/RefactoringKinds.def" case RefactoringKind::GlobalRename: case RefactoringKind::FindGlobalRenameRanges: case RefactoringKind::FindLocalRenameRanges: llvm_unreachable("not a valid refactoring kind"); case RefactoringKind::None: llvm_unreachable("should not enter here."); } llvm_unreachable("unhandled kind"); } static std::vector resolveRenameLocations(ArrayRef RenameLocs, SourceFile &SF, DiagnosticEngine &Diags) { SourceManager &SM = SF.getASTContext().SourceMgr; unsigned BufferID = SF.getBufferID().value(); std::vector UnresolvedLocs; for (const RenameLoc &RenameLoc : RenameLocs) { DeclNameViewer OldName(RenameLoc.OldName); SourceLoc Location = SM.getLocForLineCol(BufferID, RenameLoc.Line, RenameLoc.Column); if (!OldName.isValid()) { Diags.diagnose(Location, diag::invalid_name, RenameLoc.OldName); return {}; } if (!RenameLoc.NewName.empty()) { DeclNameViewer NewName(RenameLoc.NewName); ArrayRef ParamNames = NewName.args(); bool newOperator = Lexer::isOperator(NewName.base()); bool NewNameIsValid = NewName.isValid() && (Lexer::isIdentifier(NewName.base()) || newOperator) && std::all_of(ParamNames.begin(), ParamNames.end(), [](StringRef Label) { return Label.empty() || Lexer::isIdentifier(Label); }); if (!NewNameIsValid) { Diags.diagnose(Location, diag::invalid_name, RenameLoc.NewName); return {}; } if (NewName.partsCount() != OldName.partsCount()) { Diags.diagnose(Location, diag::arity_mismatch, RenameLoc.NewName, RenameLoc.OldName); return {}; } if (RenameLoc.Usage == NameUsage::Call && !RenameLoc.IsFunctionLike) { Diags.diagnose(Location, diag::name_not_functionlike, RenameLoc.NewName); return {}; } } bool isOperator = Lexer::isOperator(OldName.base()); UnresolvedLocs.push_back({ Location, (RenameLoc.Usage == NameUsage::Unknown || (RenameLoc.Usage == NameUsage::Call && !isOperator)) }); } NameMatcher Resolver(SF); return Resolver.resolve(UnresolvedLocs, SF.getAllTokens()); } int swift::ide::syntacticRename(SourceFile *SF, ArrayRef RenameLocs, SourceEditConsumer &EditConsumer, DiagnosticConsumer &DiagConsumer) { assert(SF && "null source file"); SourceManager &SM = SF->getASTContext().SourceMgr; DiagnosticEngine DiagEngine(SM); DiagEngine.addConsumer(DiagConsumer); auto ResolvedLocs = resolveRenameLocations(RenameLocs, *SF, DiagEngine); if (ResolvedLocs.size() != RenameLocs.size()) return true; // Already diagnosed. size_t index = 0; llvm::StringSet<> ReplaceTextContext; for(const RenameLoc &Rename: RenameLocs) { ResolvedLoc &Resolved = ResolvedLocs[index++]; TextReplacementsRenamer Renamer(SM, Rename.OldName, Rename.NewName, ReplaceTextContext); RegionType Type = Renamer.addSyntacticRenameRanges(Resolved, Rename); if (Type == RegionType::Mismatch) { DiagEngine.diagnose(Resolved.Range.getStart(), diag::mismatched_rename, Rename.NewName); EditConsumer.accept(SM, Type, None); } else { EditConsumer.accept(SM, Type, Renamer.getReplacements()); } } return false; } int swift::ide::findSyntacticRenameRanges( SourceFile *SF, ArrayRef RenameLocs, FindRenameRangesConsumer &RenameConsumer, DiagnosticConsumer &DiagConsumer) { assert(SF && "null source file"); SourceManager &SM = SF->getASTContext().SourceMgr; DiagnosticEngine DiagEngine(SM); DiagEngine.addConsumer(DiagConsumer); auto ResolvedLocs = resolveRenameLocations(RenameLocs, *SF, DiagEngine); if (ResolvedLocs.size() != RenameLocs.size()) return true; // Already diagnosed. size_t index = 0; for (const RenameLoc &Rename : RenameLocs) { ResolvedLoc &Resolved = ResolvedLocs[index++]; RenameRangeDetailCollector Renamer(SM, Rename.OldName); RegionType Type = Renamer.addSyntacticRenameRanges(Resolved, Rename); if (Type == RegionType::Mismatch) { DiagEngine.diagnose(Resolved.Range.getStart(), diag::mismatched_rename, Rename.NewName); RenameConsumer.accept(SM, Type, None); } else { RenameConsumer.accept(SM, Type, Renamer.Ranges); } } return false; } int swift::ide::findLocalRenameRanges( SourceFile *SF, RangeConfig Range, FindRenameRangesConsumer &RenameConsumer, DiagnosticConsumer &DiagConsumer) { assert(SF && "null source file"); SourceManager &SM = SF->getASTContext().SourceMgr; DiagnosticEngine Diags(SM); Diags.addConsumer(DiagConsumer); auto StartLoc = Lexer::getLocForStartOfToken(SM, Range.getStart(SM)); ResolvedCursorInfo CursorInfo = evaluateOrDefault(SF->getASTContext().evaluator, CursorInfoRequest{CursorInfoOwner(SF, StartLoc)}, ResolvedCursorInfo()); auto ValueRefCursorInfo = dyn_cast(&CursorInfo); if (!ValueRefCursorInfo || !ValueRefCursorInfo->getValueD()) { Diags.diagnose(StartLoc, diag::unresolved_location); return true; } ValueDecl *VD = ValueRefCursorInfo->typeOrValue(); Optional RefInfo; if (ValueRefCursorInfo->isRef()) RefInfo = {CursorInfo.getSourceFile(), CursorInfo.getLoc(), ValueRefCursorInfo->isKeywordArgument()}; llvm::SmallVector Scopes; analyzeRenameScope(VD, RefInfo, Diags, Scopes); if (Scopes.empty()) return true; RenameRangeCollector RangeCollector(VD, StringRef()); for (DeclContext *DC : Scopes) indexDeclContext(DC, RangeCollector); return findSyntacticRenameRanges(SF, RangeCollector.results(), RenameConsumer, DiagConsumer); }