diff --git a/include/swift/IDE/Utils.h b/include/swift/IDE/Utils.h index 0538f9ad228..b27506a8575 100644 --- a/include/swift/IDE/Utils.h +++ b/include/swift/IDE/Utils.h @@ -17,6 +17,7 @@ #include "swift/Basic/LLVM.h" #include "swift/AST/ASTNode.h" #include "swift/AST/DeclNameLoc.h" +#include "swift/AST/Effects.h" #include "swift/AST/Module.h" #include "swift/AST/ASTPrinter.h" #include "swift/IDE/SourceEntityWalker.h" @@ -345,7 +346,7 @@ struct ResolvedRangeInfo { ArrayRef TokensInRange; CharSourceRange ContentRange; bool HasSingleEntry; - bool ThrowingUnhandledError; + PossibleEffects UnhandledEffects; OrphanKind Orphan; // The topmost ast nodes contained in the given range. @@ -359,7 +360,7 @@ struct ResolvedRangeInfo { ArrayRef TokensInRange, DeclContext* RangeContext, Expr *CommonExprParent, bool HasSingleEntry, - bool ThrowingUnhandledError, + PossibleEffects UnhandledEffects, OrphanKind Orphan, ArrayRef ContainedNodes, ArrayRef DeclaredDecls, ArrayRef ReferencedDecls): Kind(Kind), @@ -367,7 +368,7 @@ struct ResolvedRangeInfo { TokensInRange(TokensInRange), ContentRange(calculateContentRange(TokensInRange)), HasSingleEntry(HasSingleEntry), - ThrowingUnhandledError(ThrowingUnhandledError), + UnhandledEffects(UnhandledEffects), Orphan(Orphan), ContainedNodes(ContainedNodes), DeclaredDecls(DeclaredDecls), ReferencedDecls(ReferencedDecls), @@ -376,7 +377,7 @@ struct ResolvedRangeInfo { ResolvedRangeInfo(ArrayRef TokensInRange) : ResolvedRangeInfo(RangeKind::Invalid, {nullptr, ExitState::Unsure}, TokensInRange, nullptr, /*Commom Expr Parent*/nullptr, - /*Single entry*/true, /*unhandled error*/false, + /*Single entry*/true, /*UnhandledEffects*/{}, OrphanKind::None, {}, {}, {}) {} ResolvedRangeInfo(): ResolvedRangeInfo(ArrayRef()) {} void print(llvm::raw_ostream &OS) const; diff --git a/lib/IDE/IDERequests.cpp b/lib/IDE/IDERequests.cpp index 8d949fdfa28..5530b30f221 100644 --- a/lib/IDE/IDERequests.cpp +++ b/lib/IDE/IDERequests.cpp @@ -12,6 +12,7 @@ #include "swift/AST/ASTPrinter.h" #include "swift/AST/Decl.h" +#include "swift/AST/Effects.h" #include "swift/AST/NameLookup.h" #include "swift/AST/ASTDemangler.h" #include "swift/Basic/SourceManager.h" @@ -377,45 +378,45 @@ public: ResolvedRangeInfo resolve(); }; -static bool hasUnhandledError(ArrayRef Nodes) { - class ThrowingEntityAnalyzer : public SourceEntityWalker { - bool Throwing; +static PossibleEffects getUnhandledEffects(ArrayRef Nodes) { + class EffectsAnalyzer : public SourceEntityWalker { + PossibleEffects Effects; public: - ThrowingEntityAnalyzer(): Throwing(false) {} bool walkToStmtPre(Stmt *S) override { if (auto DCS = dyn_cast(S)) { if (DCS->isSyntacticallyExhaustive()) return false; - Throwing = true; + Effects |= EffectKind::Throws; } else if (isa(S)) { - Throwing = true; + Effects |= EffectKind::Throws; } - return !Throwing; + return true; } bool walkToExprPre(Expr *E) override { // Don't walk into closures, they only produce effects when called. if (isa(E)) return false; - - if (isa(E)) { - Throwing = true; - } - return !Throwing; + + if (isa(E)) + Effects |= EffectKind::Throws; + if (isa(E)) + Effects |= EffectKind::Async; + + return true; } bool walkToDeclPre(Decl *D, CharSourceRange Range) override { return false; } - bool walkToDeclPost(Decl *D) override { return !Throwing; } - bool walkToStmtPost(Stmt *S) override { return !Throwing; } - bool walkToExprPost(Expr *E) override { return !Throwing; } - bool isThrowing() { return Throwing; } + PossibleEffects getEffects() const { return Effects; } }; - return Nodes.end() != std::find_if(Nodes.begin(), Nodes.end(), [](ASTNode N) { - ThrowingEntityAnalyzer Analyzer; + PossibleEffects Effects; + for (auto N : Nodes) { + EffectsAnalyzer Analyzer; Analyzer.walk(N); - return Analyzer.isThrowing(); - }); + Effects |= Analyzer.getEffects(); + } + return Effects; } struct RangeResolver::Implementation { @@ -553,7 +554,7 @@ private: assert(ContainedASTNodes.size() == 1); // Single node implies single entry point, or is it? bool SingleEntry = true; - bool UnhandledError = hasUnhandledError({Node}); + auto UnhandledEffects = getUnhandledEffects({Node}); OrphanKind Kind = getOrphanKind(ContainedASTNodes); if (Node.is()) return ResolvedRangeInfo(RangeKind::SingleExpression, @@ -562,7 +563,7 @@ private: getImmediateContext(), /*Common Parent Expr*/nullptr, SingleEntry, - UnhandledError, Kind, + UnhandledEffects, Kind, llvm::makeArrayRef(ContainedASTNodes), llvm::makeArrayRef(DeclaredDecls), llvm::makeArrayRef(ReferencedDecls)); @@ -573,7 +574,7 @@ private: getImmediateContext(), /*Common Parent Expr*/nullptr, SingleEntry, - UnhandledError, Kind, + UnhandledEffects, Kind, llvm::makeArrayRef(ContainedASTNodes), llvm::makeArrayRef(DeclaredDecls), llvm::makeArrayRef(ReferencedDecls)); @@ -585,7 +586,7 @@ private: getImmediateContext(), /*Common Parent Expr*/nullptr, SingleEntry, - UnhandledError, Kind, + UnhandledEffects, Kind, llvm::makeArrayRef(ContainedASTNodes), llvm::makeArrayRef(DeclaredDecls), llvm::makeArrayRef(ReferencedDecls)); @@ -646,7 +647,7 @@ public: getImmediateContext(), Parent, hasSingleEntryPoint(ContainedASTNodes), - hasUnhandledError(ContainedASTNodes), + getUnhandledEffects(ContainedASTNodes), getOrphanKind(ContainedASTNodes), llvm::makeArrayRef(ContainedASTNodes), llvm::makeArrayRef(DeclaredDecls), @@ -893,7 +894,7 @@ public: TokensInRange, getImmediateContext(), nullptr, hasSingleEntryPoint(ContainedASTNodes), - hasUnhandledError(ContainedASTNodes), + getUnhandledEffects(ContainedASTNodes), getOrphanKind(ContainedASTNodes), llvm::makeArrayRef(ContainedASTNodes), llvm::makeArrayRef(DeclaredDecls), @@ -908,7 +909,7 @@ public: getImmediateContext(), /*Common Parent Expr*/ nullptr, /*SinleEntry*/ true, - hasUnhandledError(ContainedASTNodes), + getUnhandledEffects(ContainedASTNodes), getOrphanKind(ContainedASTNodes), llvm::makeArrayRef(ContainedASTNodes), llvm::makeArrayRef(DeclaredDecls), diff --git a/lib/IDE/Refactoring.cpp b/lib/IDE/Refactoring.cpp index 07da5a8afb9..5a863e1da35 100644 --- a/lib/IDE/Refactoring.cpp +++ b/lib/IDE/Refactoring.cpp @@ -1304,7 +1304,9 @@ bool RefactoringActionExtractFunction::performChange() { } OS << ")"; - if (RangeInfo.ThrowingUnhandledError) + if (RangeInfo.UnhandledEffects.contains(EffectKind::Async)) + OS << " async"; + if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws)) OS << " " << tok::kw_throws; bool InsertedReturnType = false; @@ -1335,6 +1337,8 @@ bool RefactoringActionExtractFunction::performChange() { if (RangeInfo.UnhandledEffects.contains(EffectKind::Throws)) OS << tok::kw_try << " "; + if (RangeInfo.UnhandledEffects.contains(EffectKind::Async)) + OS << "await "; CallNameOffset = Buffer.size() - ReplaceBegin; OS << PreferredName << "("; diff --git a/lib/IDE/SwiftSourceDocInfo.cpp b/lib/IDE/SwiftSourceDocInfo.cpp index 8828c8c98a0..1fc5a139864 100644 --- a/lib/IDE/SwiftSourceDocInfo.cpp +++ b/lib/IDE/SwiftSourceDocInfo.cpp @@ -724,9 +724,12 @@ void ResolvedRangeInfo::print(llvm::raw_ostream &OS) const { OS << "Multi\n"; } - if (ThrowingUnhandledError) { + if (UnhandledEffects.contains(EffectKind::Throws)) { OS << "Throwing\n"; } + if (UnhandledEffects.contains(EffectKind::Async)) { + OS << "Async\n"; + } if (Orphan != OrphanKind::None) { OS << ""; diff --git a/test/refactoring/ExtractFunction/Outputs/await/async1.swift.expected b/test/refactoring/ExtractFunction/Outputs/await/async1.swift.expected new file mode 100644 index 00000000000..83dc90eeab2 --- /dev/null +++ b/test/refactoring/ExtractFunction/Outputs/await/async1.swift.expected @@ -0,0 +1,15 @@ +func longLongLongJourney() async -> Int { 0 } +func longLongLongAwryJourney() async throws -> Int { 0 } +func consumesAsync(_ fn: () async throws -> Void) rethrows {} + +fileprivate func new_name() async -> Int { +return await longLongLongJourney() +} + +func testThrowingClosure() async throws -> Int { + let x = await new_name() + let y = try await longLongLongAwryJourney() + 1 + try consumesAsync { try await longLongLongAwryJourney() } + return x + y +} + diff --git a/test/refactoring/ExtractFunction/Outputs/await/async2.swift.expected b/test/refactoring/ExtractFunction/Outputs/await/async2.swift.expected new file mode 100644 index 00000000000..a9814f953ba --- /dev/null +++ b/test/refactoring/ExtractFunction/Outputs/await/async2.swift.expected @@ -0,0 +1,15 @@ +func longLongLongJourney() async -> Int { 0 } +func longLongLongAwryJourney() async throws -> Int { 0 } +func consumesAsync(_ fn: () async throws -> Void) rethrows {} + +fileprivate func new_name() async throws -> Int { +return try await longLongLongAwryJourney() + 1 +} + +func testThrowingClosure() async throws -> Int { + let x = await longLongLongJourney() + let y = try await new_name() + try consumesAsync { try await longLongLongAwryJourney() } + return x + y +} + diff --git a/test/refactoring/ExtractFunction/Outputs/await/consumes_async.swift.expected b/test/refactoring/ExtractFunction/Outputs/await/consumes_async.swift.expected new file mode 100644 index 00000000000..807ec91e712 --- /dev/null +++ b/test/refactoring/ExtractFunction/Outputs/await/consumes_async.swift.expected @@ -0,0 +1,15 @@ +func longLongLongJourney() async -> Int { 0 } +func longLongLongAwryJourney() async throws -> Int { 0 } +func consumesAsync(_ fn: () async throws -> Void) rethrows {} + +fileprivate func new_name() throws { +try consumesAsync { try await longLongLongAwryJourney() } +} + +func testThrowingClosure() async throws -> Int { + let x = await longLongLongJourney() + let y = try await longLongLongAwryJourney() + 1 + try new_name() + return x + y +} + diff --git a/test/refactoring/ExtractFunction/await.swift b/test/refactoring/ExtractFunction/await.swift new file mode 100644 index 00000000000..760938e81c0 --- /dev/null +++ b/test/refactoring/ExtractFunction/await.swift @@ -0,0 +1,18 @@ +func longLongLongJourney() async -> Int { 0 } +func longLongLongAwryJourney() async throws -> Int { 0 } +func consumesAsync(_ fn: () async throws -> Void) rethrows {} + +func testThrowingClosure() async throws -> Int { + let x = await longLongLongJourney() + let y = try await longLongLongAwryJourney() + 1 + try consumesAsync { try await longLongLongAwryJourney() } + return x + y +} + +// RUN: %empty-directory(%t.result) +// RUN: %refactor -extract-function -source-filename %s -pos=6:11 -end-pos=6:38 >> %t.result/async1.swift +// RUN: diff -u %S/Outputs/await/async1.swift.expected %t.result/async1.swift +// RUN: %refactor -extract-function -source-filename %s -pos=7:11 -end-pos=7:50 >> %t.result/async2.swift +// RUN: diff -u %S/Outputs/await/async2.swift.expected %t.result/async2.swift +// RUN: %refactor -extract-function -source-filename %s -pos=8:1 -end-pos=8:60 >> %t.result/consumes_async.swift +// RUN: diff -u %S/Outputs/await/consumes_async.swift.expected %t.result/consumes_async.swift