//===----------------------------------------------------------------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// #include "ExtractExprBase.h" #include "RefactoringActions.h" #include "Renamer.h" #include "Utils.h" #include "swift/AST/DiagnosticsRefactoring.h" #include "swift/IDETool/CompilerInvocation.h" using namespace swift::refactoring; static Type sanitizeType(Type Ty) { // Transform lvalue type to inout type so that we can print it properly. return Ty.transform([](Type Ty) { if (Ty->is()) { return Type(InOutType::get(Ty->getRValueType()->getCanonicalType())); } return Ty; }); } static SourceLoc getNewFuncInsertLoc(DeclContext *DC, DeclContext *&InsertToContext) { if (auto D = DC->getInnermostDeclarationDeclContext()) { // If extracting from a getter/setter, we should skip both the immediate // getter/setter function and the individual var decl. The pattern binding // decl is the position before which we should insert the newly extracted // function. if (auto *FD = dyn_cast(D)) { ValueDecl *SD = FD->getStorage(); switch (SD->getKind()) { case DeclKind::Var: if (auto *PBD = cast(SD)->getParentPatternBinding()) D = PBD; break; case DeclKind::Subscript: D = SD; break; default: break; } } auto Result = D->getStartLoc(); assert(Result.isValid()); // The insert loc should be before every decl attributes. for (auto Attr : D->getAttrs()) { auto Loc = Attr->getRangeWithAt().Start; if (Loc.isValid() && Loc.getOpaquePointerValue() < Result.getOpaquePointerValue()) Result = Loc; } // The insert loc should be before the doc comments associated with this // decl. if (!D->getRawComment().Comments.empty()) { auto Loc = D->getRawComment().Comments.front().Range.getStart(); if (Loc.isValid() && Loc.getOpaquePointerValue() < Result.getOpaquePointerValue()) { Result = Loc; } } InsertToContext = D->getDeclContext(); return Result; } return SourceLoc(); } static std::vector getNotableRegions(StringRef SourceText, unsigned NameOffset, StringRef Name, bool IsFunctionLike = false) { auto InputBuffer = llvm::MemoryBuffer::getMemBufferCopy(SourceText, ""); CompilerInvocation Invocation{}; Invocation.getFrontendOptions().InputsAndOutputs.addInput( InputFile("", true, InputBuffer.get(), file_types::TY_Swift)); Invocation.getFrontendOptions().ModuleName = "extract"; Invocation.getLangOptions().DisablePoundIfEvaluation = true; auto Instance = std::make_unique(); std::string InstanceSetupError; if (Instance->setup(Invocation, InstanceSetupError)) llvm_unreachable(InstanceSetupError.c_str()); unsigned BufferId = Instance->getPrimarySourceFile()->getBufferID().value(); SourceManager &SM = Instance->getSourceMgr(); SourceLoc NameLoc = SM.getLocForOffset(BufferId, NameOffset); auto LineAndCol = SM.getLineAndColumnInBuffer(NameLoc); NameMatcher Matcher(*Instance->getPrimarySourceFile()); auto Resolved = Matcher.resolve(llvm::makeArrayRef(NameLoc), llvm::None); assert(!Resolved.empty() && "Failed to resolve generated func name loc"); RenameLoc RenameConfig = {LineAndCol.first, LineAndCol.second, NameUsage::Definition, /*OldName=*/Name, IsFunctionLike}; RenameRangeDetailCollector Renamer(SM, Name); Renamer.addSyntacticRenameRanges(Resolved.back(), RenameConfig); auto Ranges = Renamer.Ranges; std::vector NoteRegions(Renamer.Ranges.size()); llvm::transform(Ranges, NoteRegions.begin(), [&SM](RenameRangeDetail &Detail) -> NoteRegion { auto Start = SM.getLineAndColumnInBuffer(Detail.Range.getStart()); auto End = SM.getLineAndColumnInBuffer(Detail.Range.getEnd()); return {Detail.RangeKind, Start.first, Start.second, End.first, End.second, Detail.Index}; }); return NoteRegions; } bool RefactoringActionExtractFunction::isApplicable( const ResolvedRangeInfo &Info, DiagnosticEngine &Diag) { switch (Info.Kind) { case RangeKind::PartOfExpression: case RangeKind::SingleDecl: case RangeKind::MultiTypeMemberDecl: case RangeKind::Invalid: return false; case RangeKind::SingleExpression: case RangeKind::SingleStatement: case RangeKind::MultiStatement: { return checkExtractConditions(Info, Diag) .success({CannotExtractReason::VoidType}); } } llvm_unreachable("unhandled kind"); } bool RefactoringActionExtractFunction::performChange() { // Check if the new name is ok. if (!Lexer::isIdentifier(PreferredName)) { DiagEngine.diagnose(SourceLoc(), diag::invalid_name, PreferredName); return true; } DeclContext *DC = RangeInfo.RangeContext; DeclContext *InsertToDC = nullptr; SourceLoc InsertLoc = getNewFuncInsertLoc(DC, InsertToDC); // Complain about no inserting position. if (InsertLoc.isInvalid()) { DiagEngine.diagnose(SourceLoc(), diag::no_insert_position); return true; } // Correct the given name if collision happens. PreferredName = correctNewDeclName(InsertToDC, PreferredName); // Collect the parameters to pass down to the new function. std::vector Parameters; for (auto &RD : RangeInfo.ReferencedDecls) { // If the referenced decl is declared elsewhere, no need to pass as // parameter if (RD.VD->getDeclContext() != DC) continue; // We don't need to pass down implicitly declared variables, e.g. error in // a catch block. if (RD.VD->isImplicit()) { SourceLoc Loc = RD.VD->getStartLoc(); if (Loc.isValid() && SM.isBeforeInBuffer(RangeInfo.ContentRange.getStart(), Loc) && SM.isBeforeInBuffer(Loc, RangeInfo.ContentRange.getEnd())) continue; } // If the referenced decl is declared inside the range, no need to pass // as parameter. if (RangeInfo.DeclaredDecls.end() != std::find_if(RangeInfo.DeclaredDecls.begin(), RangeInfo.DeclaredDecls.end(), [RD](DeclaredDecl DD) { return RD.VD == DD.VD; })) continue; // We don't need to pass down self. if (auto PD = dyn_cast(RD.VD)) { if (PD->isSelfParameter()) { continue; } } Parameters.emplace_back(RD.VD, sanitizeType(RD.Ty)); } SmallString<64> Buffer; unsigned FuncBegin = Buffer.size(); unsigned FuncNameOffset; { llvm::raw_svector_ostream OS(Buffer); if (!InsertToDC->isLocalContext()) { // Default to be file private. OS << tok::kw_fileprivate << " "; } // Inherit static if the containing function is. if (DC->getContextKind() == DeclContextKind::AbstractFunctionDecl) { if (auto FD = dyn_cast(static_cast(DC))) { if (FD->isStatic()) { OS << tok::kw_static << " "; } } } OS << tok::kw_func << " "; FuncNameOffset = Buffer.size() - FuncBegin; OS << PreferredName; OS << "("; for (auto &RD : Parameters) { OS << "_ " << RD.VD->getBaseName().userFacingName() << ": "; RD.Ty->reconstituteSugar(/*Recursive*/ true)->print(OS); if (&RD != &Parameters.back()) OS << ", "; } OS << ")"; if (RangeInfo.UnhandledEffects.contains(EffectKind::Async)) OS << " async"; if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws)) OS << " " << tok::kw_throws; bool InsertedReturnType = false; if (auto Ty = RangeInfo.getType()) { // If the type of the range is not void, specify the return type. if (!Ty->isVoid()) { OS << " " << tok::arrow << " "; sanitizeType(Ty)->reconstituteSugar(/*Recursive*/ true)->print(OS); InsertedReturnType = true; } } OS << " {\n"; // Add "return" if the extracted entity is an expression. if (RangeInfo.Kind == RangeKind::SingleExpression && InsertedReturnType) OS << tok::kw_return << " "; OS << RangeInfo.ContentRange.str() << "\n}\n\n"; } unsigned FuncEnd = Buffer.size(); unsigned ReplaceBegin = Buffer.size(); unsigned CallNameOffset; { llvm::raw_svector_ostream OS(Buffer); if (RangeInfo.exit() == ExitState::Positive) OS << tok::kw_return << " "; if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws)) OS << tok::kw_try << " "; if (RangeInfo.UnhandledEffects.contains(EffectKind::Async)) OS << "await "; CallNameOffset = Buffer.size() - ReplaceBegin; OS << PreferredName << "("; for (auto &RD : Parameters) { // Inout argument needs "&". if (RD.Ty->is()) OS << "&"; OS << RD.VD->getBaseName().userFacingName(); if (&RD != &Parameters.back()) OS << ", "; } OS << ")"; } unsigned ReplaceEnd = Buffer.size(); std::string ExtractedFuncName = PreferredName.str() + "("; for (size_t i = 0; i < Parameters.size(); ++i) { ExtractedFuncName += "_:"; } ExtractedFuncName += ")"; StringRef DeclStr(Buffer.begin() + FuncBegin, FuncEnd - FuncBegin); auto NotableFuncRegions = getNotableRegions(DeclStr, FuncNameOffset, ExtractedFuncName, /*IsFunctionLike=*/true); StringRef CallStr(Buffer.begin() + ReplaceBegin, ReplaceEnd - ReplaceBegin); auto NotableCallRegions = getNotableRegions(CallStr, CallNameOffset, ExtractedFuncName, /*IsFunctionLike=*/true); // Insert the new function's declaration. EditConsumer.accept(SM, InsertLoc, DeclStr, NotableFuncRegions); // Replace the code to extract with the function call. EditConsumer.accept(SM, RangeInfo.ContentRange, CallStr, NotableCallRegions); return false; }