[refactoring] Implement "Convert to Trailing Closure" refactoring action

This commit is contained in:
Rintaro Ishizaki
2017-10-16 00:24:57 -07:00
parent 53e514518b
commit a57199cf68
12 changed files with 336 additions and 9 deletions

View File

@@ -2712,6 +2712,12 @@ public:
Expr *getSubExpr() const { return SubExpr; } Expr *getSubExpr() const { return SubExpr; }
void setSubExpr(Expr *e) { SubExpr = e; } void setSubExpr(Expr *e) { SubExpr = e; }
Expr *getSyntacticSubExpr() const {
if (auto *ICE = dyn_cast<ImplicitConversionExpr>(SubExpr))
return ICE->getSyntacticSubExpr();
return SubExpr;
}
static bool classof(const Expr *E) { static bool classof(const Expr *E) {
return E->getKind() >= ExprKind::First_ImplicitConversionExpr && return E->getKind() >= ExprKind::First_ImplicitConversionExpr &&
E->getKind() <= ExprKind::Last_ImplicitConversionExpr; E->getKind() <= ExprKind::Last_ImplicitConversionExpr;

View File

@@ -44,6 +44,8 @@ CURSOR_REFACTORING(CollapseNestedIfExpr, "Collapse Nested If Expression", collap
CURSOR_REFACTORING(ConvertToDoCatch, "Convert To Do/Catch", convert.do.catch) CURSOR_REFACTORING(ConvertToDoCatch, "Convert To Do/Catch", convert.do.catch)
CURSOR_REFACTORING(TrailingClosure, "Convert To Trailing Closure", trailingclosure)
RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr) RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)
RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function) RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)

View File

@@ -40,22 +40,28 @@ class ContextFinder : public SourceEntityWalker {
SourceFile &SF; SourceFile &SF;
ASTContext &Ctx; ASTContext &Ctx;
SourceManager &SM; SourceManager &SM;
ASTNode Target; SourceRange Target;
llvm::function_ref<bool(ASTNode)> IsContext; llvm::function_ref<bool(ASTNode)> IsContext;
SmallVector<ASTNode, 4> AllContexts; SmallVector<ASTNode, 4> AllContexts;
bool contains(ASTNode Enclosing) { bool contains(ASTNode Enclosing) {
auto Result = SM.rangeContains(Enclosing.getSourceRange(), auto Result = SM.rangeContains(Enclosing.getSourceRange(), Target);
Target.getSourceRange());
if (Result && IsContext(Enclosing)) if (Result && IsContext(Enclosing))
AllContexts.push_back(Enclosing); AllContexts.push_back(Enclosing);
return Result; return Result;
} }
public: public:
ContextFinder(SourceFile &SF, ASTNode Target, ContextFinder(SourceFile &SF, ASTNode TargetNode,
llvm::function_ref<bool(ASTNode)> IsContext = llvm::function_ref<bool(ASTNode)> IsContext =
[](ASTNode N) { return true; }) : SF(SF), [](ASTNode N) { return true; }) :
Ctx(SF.getASTContext()), SM(Ctx.SourceMgr), Target(Target), SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr),
IsContext(IsContext) {} Target(TargetNode.getSourceRange()), IsContext(IsContext) {}
ContextFinder(SourceFile &SF, SourceLoc TargetLoc,
llvm::function_ref<bool(ASTNode)> IsContext =
[](ASTNode N) { return true; }) :
SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr),
Target(TargetLoc), IsContext(IsContext) {
assert(TargetLoc.isValid() && "Invalid loc to find");
}
bool walkToDeclPre(Decl *D, CharSourceRange Range) override { return contains(D); } bool walkToDeclPre(Decl *D, CharSourceRange Range) override { return contains(D); }
bool walkToStmtPre(Stmt *S) override { return contains(S); } bool walkToStmtPre(Stmt *S) override { return contains(S); }
bool walkToExprPre(Expr *E) override { return contains(E); } bool walkToExprPre(Expr *E) override { return contains(E); }
@@ -1722,7 +1728,8 @@ public:
FillProtocolStubContext FillProtocolStubContext:: FillProtocolStubContext FillProtocolStubContext::
getContextFromCursorInfo(ResolvedCursorInfo CursorInfo) { getContextFromCursorInfo(ResolvedCursorInfo CursorInfo) {
assert(CursorInfo.isValid()); if(!CursorInfo.isValid())
return FillProtocolStubContext();
if (!CursorInfo.IsRef) { if (!CursorInfo.IsRef) {
// If the type name is on the declared nominal, e.g. "class A {}" // If the type name is on the declared nominal, e.g. "class A {}"
if (auto ND = dyn_cast<NominalTypeDecl>(CursorInfo.ValueD)) { if (auto ND = dyn_cast<NominalTypeDecl>(CursorInfo.ValueD)) {
@@ -2127,6 +2134,110 @@ bool RefactoringActionSimplifyNumberLiteral::performChange() {
return true; return true;
} }
static CallExpr *findTrailingClosureTarget(SourceManager &SM,
ResolvedCursorInfo CursorInfo) {
if (CursorInfo.Kind == CursorInfoKind::StmtStart)
// StmtStart postion can't be a part of CallExpr.
return nullptr;
// Find inner most CallExpr
ContextFinder
Finder(*CursorInfo.SF, CursorInfo.Loc,
[](ASTNode N) {
return N.isStmt(StmtKind::Brace) || N.isExpr(ExprKind::Call);
});
Finder.resolve();
if (Finder.getContexts().empty()
|| !Finder.getContexts().back().is<Expr*>())
return nullptr;
CallExpr *CE = cast<CallExpr>(Finder.getContexts().back().get<Expr*>());
// The last arugment is a closure?
Expr *Args = CE->getArg();
if (!Args)
return nullptr;
Expr *LastArg;
if (auto *TSE = dyn_cast<TupleShuffleExpr>(Args))
Args = TSE->getSubExpr();
if (auto *PE = dyn_cast<ParenExpr>(Args)) {
LastArg = PE->getSubExpr();
} else {
auto *TE = cast<TupleExpr>(Args);
if (TE->getNumElements() == 0)
return nullptr;
LastArg = TE->getElements().back();
}
if (auto *ICE = dyn_cast<ImplicitConversionExpr>(LastArg))
LastArg = ICE->getSyntacticSubExpr();
if (isa<ClosureExpr>(LastArg) || isa<CaptureListExpr>(LastArg))
return CE;
return nullptr;
}
bool RefactoringActionTrailingClosure::
isApplicable(ResolvedCursorInfo CursorInfo, DiagnosticEngine &Diag) {
SourceManager &SM = CursorInfo.SF->getASTContext().SourceMgr;
return findTrailingClosureTarget(SM, CursorInfo);
}
bool RefactoringActionTrailingClosure::performChange() {
auto *CE = findTrailingClosureTarget(SM, CursorInfo);
if (!CE)
return true;
Expr *Args = CE->getArg();
if (auto *TSE = dyn_cast<TupleShuffleExpr>(Args))
Args = TSE;
Expr *ClosureArg = nullptr;
Expr *PrevArg = nullptr;
SourceLoc LPLoc, RPLoc;
if (auto *PE = dyn_cast<ParenExpr>(Args)) {
ClosureArg = PE->getSubExpr();
LPLoc = PE->getLParenLoc();
RPLoc = PE->getRParenLoc();
} else {
auto *TE = cast<TupleExpr>(Args);
auto NumArgs = TE->getNumElements();
if (NumArgs == 0)
return true;
LPLoc = TE->getLParenLoc();
RPLoc = TE->getRParenLoc();
ClosureArg = TE->getElement(NumArgs - 1);
if (NumArgs > 1)
PrevArg = TE->getElement(NumArgs - 2);
}
if (auto *ICE = dyn_cast<ImplicitConversionExpr>(ClosureArg))
ClosureArg = ICE->getSyntacticSubExpr();
if (LPLoc.isInvalid() || RPLoc.isInvalid())
return true;
// Replace:
// * Open paren with ' ' if the closure is sole argument.
// * Comma with ') ' otherwise.
if (PrevArg) {
CharSourceRange PreRange(
SM,
Lexer::getLocForEndOfToken(SM, PrevArg->getEndLoc()),
ClosureArg->getStartLoc());
EditConsumer.accept(SM, PreRange, ") ");
} else {
CharSourceRange PreRange(
SM, LPLoc, ClosureArg->getStartLoc());
EditConsumer.accept(SM, PreRange, " ");
}
// Remove original closing paren.
CharSourceRange PostRange(
SM,
Lexer::getLocForEndOfToken(SM, ClosureArg->getEndLoc()),
Lexer::getLocForEndOfToken(SM, RPLoc));
EditConsumer.remove(SM, PostRange);
return false;
}
static bool rangeStartMayNeedRename(ResolvedRangeInfo Info) { static bool rangeStartMayNeedRename(ResolvedRangeInfo Info) {
switch(Info.Kind) { switch(Info.Kind) {
case RangeKind::SingleExpression: { case RangeKind::SingleExpression: {

View File

@@ -0,0 +1,46 @@
struct Foo {
static func foo(a: () -> Int) {}
func qux(x: Int, y: () -> Int ) {}
}
func testTrailingClosure() -> String {
Foo.foo(a: { 1 })
Foo.bar(a: { print(3); return 1 })
Foo().qux(x: 1, y: { 1 })
let _ = Foo().quux(x: 1, y: { 1 })
[1,2,3]
.filter({ $0 % 2 == 0 })
.map({ $0 + 1 })
}
// RUN: %refactor -source-filename %s -pos=7:3 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=7:6 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=7:7 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=7:10 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=7:11 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=7:12 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=7:14 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=7:16 | %FileCheck %s -check-prefix=CHECK-NO-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=7:18 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=7:19 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=8:3 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=8:11 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=9:3 | %FileCheck %s -check-prefix=CHECK-NO-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=9:8 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=10:3 | %FileCheck %s -check-prefix=CHECK-NO-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=10:9 | %FileCheck %s -check-prefix=CHECK-NO-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=10:17 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=12:4 | %FileCheck %s -check-prefix=CHECK-NO-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=13:5 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// RUN: %refactor -source-filename %s -pos=14:5 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
// CHECK-TRAILING-CLOSURE: Convert To Trailing Closure
// CHECK-NO-TRAILING-CLOSURE: Action begins
// CHECK-NO-TRAILING-CLOSURE-NOT: Convert To Trailing Closure
// CHECK-NO-TRAILING-CLOSURE: Action ends

View File

@@ -0,0 +1,21 @@
struct Foo {
static func foo(a: () -> Int) {}
func qux(x: Int, y: () -> Int ) {}
}
func testTrailingClosure() -> String {
Foo.foo(a: { 1 })
Foo.bar(a: { print(3); return 1 })
Foo().qux(x: 1, y: { 1 })
let _ = Foo().quux(x: 1) { 1 }
[1,2,3]
.filter({ $0 % 2 == 0 })
.map({ $0 + 1 })
}

View File

@@ -0,0 +1,21 @@
struct Foo {
static func foo(a: () -> Int) {}
func qux(x: Int, y: () -> Int ) {}
}
func testTrailingClosure() -> String {
Foo.foo(a: { 1 })
Foo.bar(a: { print(3); return 1 })
Foo().qux(x: 1, y: { 1 })
let _ = Foo().quux(x: 1, y: { 1 })
[1,2,3]
.filter { $0 % 2 == 0 }
.map({ $0 + 1 })
}

View File

@@ -0,0 +1,21 @@
struct Foo {
static func foo(a: () -> Int) {}
func qux(x: Int, y: () -> Int ) {}
}
func testTrailingClosure() -> String {
Foo.foo(a: { 1 })
Foo.bar(a: { print(3); return 1 })
Foo().qux(x: 1, y: { 1 })
let _ = Foo().quux(x: 1, y: { 1 })
[1,2,3]
.filter({ $0 % 2 == 0 })
.map { $0 + 1 }
}

View File

@@ -0,0 +1,21 @@
struct Foo {
static func foo(a: () -> Int) {}
func qux(x: Int, y: () -> Int ) {}
}
func testTrailingClosure() -> String {
Foo.foo { 1 }
Foo.bar(a: { print(3); return 1 })
Foo().qux(x: 1, y: { 1 })
let _ = Foo().quux(x: 1, y: { 1 })
[1,2,3]
.filter({ $0 % 2 == 0 })
.map({ $0 + 1 })
}

View File

@@ -0,0 +1,21 @@
struct Foo {
static func foo(a: () -> Int) {}
func qux(x: Int, y: () -> Int ) {}
}
func testTrailingClosure() -> String {
Foo.foo(a: { 1 })
Foo.bar { print(3); return 1 }
Foo().qux(x: 1, y: { 1 })
let _ = Foo().quux(x: 1, y: { 1 })
[1,2,3]
.filter({ $0 % 2 == 0 })
.map({ $0 + 1 })
}

View File

@@ -0,0 +1,21 @@
struct Foo {
static func foo(a: () -> Int) {}
func qux(x: Int, y: () -> Int ) {}
}
func testTrailingClosure() -> String {
Foo.foo(a: { 1 })
Foo.bar(a: { print(3); return 1 })
Foo().qux(x: 1) { 1 }
let _ = Foo().quux(x: 1, y: { 1 })
[1,2,3]
.filter({ $0 % 2 == 0 })
.map({ $0 + 1 })
}

View File

@@ -0,0 +1,34 @@
struct Foo {
static func foo(a: () -> Int) {}
func qux(x: Int, y: () -> Int ) {}
}
func testTrailingClosure() -> String {
Foo.foo(a: { 1 })
Foo.bar(a: { print(3); return 1 })
Foo().qux(x: 1, y: { 1 })
let _ = Foo().quux(x: 1, y: { 1 })
[1,2,3]
.filter({ $0 % 2 == 0 })
.map({ $0 + 1 })
}
// RUN: rm -rf %t.result && mkdir -p %t.result
// RUN: %refactor -trailingclosure -source-filename %s -pos=7:3 > %t.result/L7.swift
// RUN: diff -u %S/Outputs/basic/L7.swift.expected %t.result/L7.swift
// RUN: %refactor -trailingclosure -source-filename %s -pos=8:11 > %t.result/L8.swift
// RUN: diff -u %S/Outputs/basic/L8.swift.expected %t.result/L8.swift
// RUN: %refactor -trailingclosure -source-filename %s -pos=9:8 > %t.result/L9.swift
// RUN: diff -u %S/Outputs/basic/L9.swift.expected %t.result/L9.swift
// RUN: %refactor -trailingclosure -source-filename %s -pos=10:17 > %t.result/L10.swift
// RUN: diff -u %S/Outputs/basic/L10.swift.expected %t.result/L10.swift
// RUN: %refactor -trailingclosure -source-filename %s -pos=13:5 > %t.result/L13.swift
// RUN: diff -u %S/Outputs/basic/L13.swift.expected %t.result/L13.swift
// RUN: %refactor -trailingclosure -source-filename %s -pos=14:5 > %t.result/L14.swift
// RUN: diff -u %S/Outputs/basic/L14.swift.expected %t.result/L14.swift

View File

@@ -55,7 +55,9 @@ Action(llvm::cl::desc("kind:"), llvm::cl::init(RefactoringKind::None),
clEnumValN(RefactoringKind::FindGlobalRenameRanges, clEnumValN(RefactoringKind::FindGlobalRenameRanges,
"find-rename-ranges", "Find detailed ranges for syntactic rename"), "find-rename-ranges", "Find detailed ranges for syntactic rename"),
clEnumValN(RefactoringKind::FindLocalRenameRanges, clEnumValN(RefactoringKind::FindLocalRenameRanges,
"find-local-rename-ranges", "Find detailed ranges for local rename"))); "find-local-rename-ranges", "Find detailed ranges for local rename"),
clEnumValN(RefactoringKind::TrailingClosure,
"trailingclosure", "Perform trailing closure refactoring")));
static llvm::cl::opt<std::string> static llvm::cl::opt<std::string>