[Async Refactoring] Wrap code in a continuation if conversion doesn't yield reasonable results

If we are requested to convert a function to async, but the call in the function’s body that eventually calls the completion handler doesn’t have an async alternative, we are currently copying the call as-is, replacing any calls to the completion handler by placeholders.

For example,
```swift
func testDispatch(completionHandler: @escaping (Int) -> Void) {
  DispatchQueue.global.async {
     completionHandler(longSyncFunc())
  }
}
```
becomes
```swift
func testDispatch() async -> Int  {
  DispatchQueue.global.async {
     <#completionHandler#>(longSyncFunc())
  }
}
```

and

```swift
func testUrlSession(completionHandler: @escaping (Data) -> Void) {
  let task = URLSession.shared.dataTask(with: request) { data, response, error in
    completion(data!)
  }
  task.resume()
}
```
becomes
```swift
func testUrlSession() async -> Data {
  let task = URLSession.shared.dataTask(with: request) { data, response, error in
    <#completion#>(data!)
  }
  task.resume()
}
```

Both of these are better modelled using continuations. Thus, if we find an expression that contains a call to the completion handler and can’t be hoisted to an await statement, we are wrapping the rest of the current scope in a `withChecked(Throwing)Continuation`, producing the following results:

```swift
func testDispatch() async -> Int {
  return await withCheckedContinuation { (continuation: CheckedContinuation<Int, Never>) in
    DispatchQueue.global.async {
      continuation.resume(returning: syncComputation())
    }
  }
}
```

and

```swift
func testDataTask() async -> Int?
  return await withCheckedContinuation { (continuation: CheckedContinuation<Data, Never>) in
    let task = URLSession.shared.dataTask { data, response, error in
      continuation.resume(returning: data!)
    }
    task.resume()
  }
}
```

I think both are much closer to what the developer is actually expecting.

Resolves rdar://79304583
This commit is contained in:
Alex Hoppen
2021-07-01 15:57:49 +02:00
parent d0472e1b21
commit 54fcc90841
3 changed files with 685 additions and 73 deletions

View File

@@ -4183,6 +4183,8 @@ struct AsyncHandlerDesc {
}
}
HandlerType getHandlerType() const { return Type; }
/// Get the type of the completion handler.
swift::Type getType() const {
if (auto Var = Handler.dyn_cast<const VarDecl *>()) {
@@ -5697,6 +5699,33 @@ private:
}
};
/// Checks whether an ASTNode contains a reference to a given declaration.
class DeclReferenceFinder : private SourceEntityWalker {
bool HasFoundReference = false;
const Decl *Search;
bool walkToExprPre(Expr *E) override {
if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
if (DRE->getDecl() == Search) {
HasFoundReference = true;
return false;
}
}
return true;
}
DeclReferenceFinder(const Decl *Search) : Search(Search) {}
public:
/// Returns \c true if \p node contains a reference to \p Search, \c false
/// otherwise.
static bool containsReference(ASTNode Node, const ValueDecl *Search) {
DeclReferenceFinder Checker(Search);
Checker.walk(Node);
return Checker.HasFoundReference;
}
};
/// Builds up async-converted code for an AST node.
///
/// If it is a function, its declaration will have `async` added. If a
@@ -5727,6 +5756,21 @@ private:
/// the code the user intended. In most cases the refactoring will continue,
/// with any unhandled decls wrapped in placeholders instead.
class AsyncConverter : private SourceEntityWalker {
struct Scope {
llvm::DenseSet<DeclBaseName> Names;
/// If this scope is wrapped in a \c withChecked(Throwing)Continuation, the
/// name of the continuation that must be resumed where there previously was
/// a call to the function's completion handler.
/// Otherwise an empty identifier.
Identifier ContinuationName;
Scope(Identifier ContinuationName)
: Names(), ContinuationName(ContinuationName) {}
/// Whether this scope is wrapped in a \c withChecked(Throwing)Continuation.
bool isWrappedInContination() const { return !ContinuationName.empty(); }
};
SourceFile *SF;
SourceManager &SM;
DiagnosticEngine &DiagEngine;
@@ -5755,9 +5799,12 @@ class AsyncConverter : private SourceEntityWalker {
// declarations of old completion handler parametes, as well as the
// replacement for other hoisted declarations and their references
llvm::DenseMap<const Decl *, Identifier> Names;
// Names of decls in each scope, where the first element is the initial scope
// and the last is the current scope.
llvm::SmallVector<llvm::DenseSet<DeclBaseName>, 4> ScopedNames;
/// The scopes (containing all name decls and whether the scope is wrapped in
/// a continuation) as the AST is being walked. The first element is the
/// initial scope and the last is the current scope.
llvm::SmallVector<Scope, 4> Scopes;
// Mapping of \c BraceStmt -> declarations referenced in that statement
// without first being declared. These are used to fill the \c ScopeNames
// map on entering that scope.
@@ -6149,6 +6196,58 @@ private:
addRange(LastAddedLoc, P->getEndLoc(), /*ToEndOfToken*/ true);
}
/// Check whether \p Node requires the remainder of this scope to be wrapped
/// in a \c withChecked(Throwing)Continuation. If it is necessary, add
/// a call to \c withChecked(Throwing)Continuation and modify the current
/// scope (\c Scopes.back() ) so that it knows it's wrapped in a continuation.
///
/// Wrapping a node in a continuation is necessary if the following conditions
/// are satisfied:
/// - It contains a reference to the \c TopHandler's completion hander,
/// because these completion handler calls need to be promoted to \c return
/// statements in the refactored method, but
/// - We cannot hoist the completion handler of \p Node, because it doesn't
/// have an async alternative by our heuristics (e.g. because of a
/// completion handler name mismatch or because it also returns a value
/// synchronously).
void wrapScopeInContinationIfNecessary(ASTNode Node) {
if (NestedExprCount != 0) {
// We can't start a continuation in the middle of an expression
return;
}
if (Scopes.back().isWrappedInContination()) {
// We are already in a continuation. No need to add another one.
return;
}
if (!DeclReferenceFinder::containsReference(Node,
TopHandler.getHandler())) {
// The node doesn't have a reference to the function's completion handler.
// It can stay a call with a completion handler, because we don't need to
// promote a completion handler call to a 'return'.
return;
}
// Wrap the current call in a continuation
Identifier contName = createUniqueName("continuation");
Scopes.back().Names.insert(contName);
Scopes.back().ContinuationName = contName;
insertCustom(Node.getStartLoc(), [&]() {
OS << tok::kw_return << ' ';
if (TopHandler.HasError) {
OS << tok::kw_try << ' ';
}
OS << "await ";
if (TopHandler.HasError) {
OS << "withCheckedThrowingContinuation ";
} else {
OS << "withCheckedContinuation ";
}
OS << tok::l_brace << ' ' << contName << ' ' << tok::kw_in << '\n';
});
}
bool walkToPatternPre(Pattern *P) override {
// If we're not converting a pattern, there's nothing extra to do.
if (!ConvertingPattern)
@@ -6167,18 +6266,21 @@ private:
bool walkToDeclPre(Decl *D, CharSourceRange Range) override {
if (isa<PatternBindingDecl>(D)) {
// We can't hoist a closure inside a PatternBindingDecl. If it contains
// a call to the completion handler, wrap it in a continuation.
wrapScopeInContinationIfNecessary(D);
NestedExprCount++;
return true;
}
// Functions and types already have their names in \c ScopedNames, only
// Functions and types already have their names in \c Scopes.Names, only
// variables should need to be renamed.
if (isa<VarDecl>(D)) {
// If we don't already have a name for the var, assign it one. Note that
// vars in binding patterns may already have assigned names here.
if (Names.find(D) == Names.end()) {
auto Ident = assignUniqueName(D, StringRef());
ScopedNames.back().insert(Ident);
Scopes.back().Names.insert(Ident);
}
addCustom(D->getSourceRange(), [&]() {
OS << newNameFor(D);
@@ -6239,11 +6341,20 @@ private:
return addCustom(E->getSourceRange(),
[&]() { OS << newNameFor(D, true); });
}
} else if (NestedExprCount == 0) {
if (CallExpr *CE = TopHandler.getAsHandlerCall(E))
} else if (CallExpr *CE = TopHandler.getAsHandlerCall(E)) {
if (Scopes.back().isWrappedInContination()) {
return addCustom(CE->getSourceRange(),
[&]() { addHandlerCallToContinuation(CE); });
} else if (NestedExprCount == 0) {
return addCustom(CE->getSourceRange(), [&]() { addHandlerCall(CE); });
if (auto *CE = dyn_cast<CallExpr>(E)) {
}
} else if (auto *CE = dyn_cast<CallExpr>(E)) {
// Try and hoist a call's completion handler. Don't do so if
// - the current expression is nested (we can't start hoisting in the
// middle of an expression)
// - the current scope is wrapped in a continuation (we can't have await
// calls in the continuation block)
if (NestedExprCount == 0 && !Scopes.back().isWrappedInContination()) {
// If the refactoring is on the call itself, do not require the callee
// to have the @completionHandlerAsync attribute or a completion-like
// name.
@@ -6256,6 +6367,10 @@ private:
}
}
// We didn't do any special conversion for this expression. If needed, wrap
// it in a continuation.
wrapScopeInContinationIfNecessary(E);
NestedExprCount++;
return true;
}
@@ -6319,18 +6434,37 @@ private:
}
bool walkToStmtPost(Stmt *S) override {
if (startsNewScope(S))
ScopedNames.pop_back();
if (startsNewScope(S)) {
bool ClosedScopeWasWrappedInContinuation =
Scopes.back().isWrappedInContination();
Scopes.pop_back();
if (ClosedScopeWasWrappedInContinuation &&
!Scopes.back().isWrappedInContination()) {
// The nested scope was wrapped in a continuation but the current one
// isn't anymore. Add the '}' that corresponds to the the call to
// withChecked(Throwing)Continuation.
insertCustom(S->getEndLoc(), [&]() { OS << tok::r_brace << '\n'; });
}
}
return true;
}
bool addCustom(SourceRange Range, std::function<void()> Custom = {}) {
bool addCustom(SourceRange Range, llvm::function_ref<void()> Custom = {}) {
addRange(LastAddedLoc, Range.Start);
Custom();
LastAddedLoc = Lexer::getLocForEndOfToken(SM, Range.End);
return false;
}
/// Insert custom text at the given \p Loc that shouldn't replace any existing
/// source code.
bool insertCustom(SourceLoc Loc, llvm::function_ref<void()> Custom = {}) {
addRange(LastAddedLoc, Loc);
Custom();
LastAddedLoc = Loc;
return false;
}
void addRange(SourceLoc Start, SourceLoc End, bool ToEndOfToken = false) {
if (ToEndOfToken) {
OS << Lexer::getCharSourceRangeFromSourceRange(SM,
@@ -6439,6 +6573,9 @@ private:
void addDo() { OS << tok::kw_do << " " << tok::l_brace << "\n"; }
void addHandlerCall(const CallExpr *CE) {
assert(TopHandler.getAsHandlerCall(const_cast<CallExpr *>(CE)) == CE &&
"addHandlerCall must be used with a call to the TopHandler's "
"completion handler");
auto Exprs = TopHandler.extractResultArgs(CE);
bool AddedReturnOrThrow = true;
@@ -6462,20 +6599,61 @@ private:
if (!Args.empty()) {
if (AddedReturnOrThrow)
OS << " ";
if (Args.size() > 1)
OS << tok::l_paren;
for (size_t I = 0, E = Args.size(); I < E; ++I) {
if (I > 0)
OS << tok::comma << " ";
unsigned I = 0;
addTupleOf(Args, OS, [&](Expr *Elt) {
// Can't just add the range as we need to perform replacements
convertNode(Args[I], /*StartOverride=*/CE->getArgumentLabelLoc(I),
convertNode(Elt, /*StartOverride=*/CE->getArgumentLabelLoc(I),
/*ConvertCalls=*/false);
}
if (Args.size() > 1)
OS << tok::r_paren;
I++;
});
}
}
/// Assuming that \p CE is a call to \c TopHandler's completion handler and
/// that the current scope is wrapped in a continuation, replace it with a
/// call to the continuation.
void addHandlerCallToContinuation(const CallExpr *CE) {
assert(TopHandler.getAsHandlerCall(const_cast<CallExpr *>(CE)) == CE &&
"addHandlerCallToContinuation must be used with a call to the "
"TopHandler's completion handler");
assert(Scopes.back().isWrappedInContination());
ArrayRef<Expr *> Args;
StringRef ResumeArgumentLabel;
switch (TopHandler.getHandlerType()) {
case HandlerType::PARAMS: {
auto Exprs = TopHandler.extractResultArgs(CE);
Args = Exprs.args();
if (!Exprs.isError()) {
ResumeArgumentLabel = "returning";
} else {
ResumeArgumentLabel = "throwing";
}
break;
}
case HandlerType::RESULT: {
Args = callArgs(CE).ref();
ResumeArgumentLabel = "with";
break;
}
case HandlerType::INVALID:
llvm_unreachable("Invalid top handler");
}
Identifier ContName = Scopes.back().ContinuationName;
OS << ContName << tok::period << "resume" << tok::l_paren
<< ResumeArgumentLabel << tok::colon << ' ';
unsigned I = 0;
addTupleOf(Args, OS, [&](Expr *Elt) {
// Can't just add the range as we need to perform replacements
convertNode(Elt, /*StartOverride=*/CE->getArgumentLabelLoc(I),
/*ConvertCalls=*/false);
I++;
});
OS << tok::r_paren;
}
/// From the given expression \p E, which is an argument to a function call,
/// extract the passed closure if there is one. Otherwise return \c nullptr.
ClosureExpr *extractCallback(Expr *E) {
@@ -6729,7 +6907,7 @@ private:
StringRef ResultName;
if (!HandlerDesc.willAsyncReturnVoid()) {
Identifier Unique = createUniqueName("result");
ScopedNames.back().insert(Unique);
Scopes.back().Names.insert(Unique);
ResultName = Unique.str();
OS << tok::kw_let << " " << ResultName;
@@ -6999,7 +7177,7 @@ private:
Identifier createUniqueName(StringRef Name) {
Identifier Ident = getASTContext().getIdentifier(Name);
auto &CurrentNames = ScopedNames.back();
auto &CurrentNames = Scopes.back().Names;
if (CurrentNames.count(Ident)) {
// Add a number to the end of the name until it's unique given the current
// names in scope.
@@ -7018,7 +7196,7 @@ private:
/// Create a unique name for the variable declared by \p D that doesn't
/// clash with any other names in scope, using \p BoundName as the base name
/// if not empty and the name of \p D otherwise. Adds this name to both
/// \c Names and the current scope's names (\c ScopedNames).
/// \c Names and the current scope's names (\c Scopes.Names).
Identifier assignUniqueName(const Decl *D, StringRef BoundName) {
Identifier Ident;
if (BoundName.empty()) {
@@ -7037,7 +7215,7 @@ private:
}
Names.try_emplace(D, Ident);
ScopedNames.back().insert(Ident);
Scopes.back().Names.insert(Ident);
return Ident;
}
@@ -7051,11 +7229,18 @@ private:
}
void addNewScope(const llvm::DenseSet<const Decl *> &Decls) {
ScopedNames.push_back({});
for (auto DeclAndNumRefs : Decls) {
auto Name = getDeclName(DeclAndNumRefs);
if (Scopes.empty()) {
Scopes.emplace_back(/*ContinuationName=*/Identifier());
} else {
// If the parent scope is nested in a continuation, the new one is also.
// Carry over the continuation name.
Identifier PreviousContinuationName = Scopes.back().ContinuationName;
Scopes.emplace_back(PreviousContinuationName);
}
for (auto D : Decls) {
auto Name = getDeclName(D);
if (!Name.empty())
ScopedNames.back().insert(Name);
Scopes.back().Names.insert(Name);
}
}
@@ -7242,13 +7427,17 @@ private:
void addAsyncFuncReturnType(const AsyncHandlerDesc &HandlerDesc) {
// Type or (Type1, Type2, ...)
SmallVector<LabeledReturnType, 2> Scratch;
addTupleOf(HandlerDesc.getAsyncReturnTypes(Scratch), OS,
[&](LabeledReturnType LabelAndType) {
if (!LabelAndType.Label.empty()) {
OS << LabelAndType.Label << tok::colon << " ";
}
LabelAndType.Ty->print(OS);
});
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
if (ReturnTypes.empty()) {
OS << "Void";
} else {
addTupleOf(ReturnTypes, OS, [&](LabeledReturnType LabelAndType) {
if (!LabelAndType.Label.empty()) {
OS << LabelAndType.Label << tok::colon << " ";
}
LabelAndType.Ty->print(OS);
});
}
}
/// If \p FD is generic, adds a type annotation with the return type of the