[Refactoring] Support refactoring calls to async if a variable or function is used as completion handler

Previously, we only supported  refactoring a function to call the async alternative if a closure was used for the callback parameter. With this change, we also support calling a arbitrary function (or variable with function type) that is passed to the completion handler argument.

The implementation basically re-uses the code we already have to create the legacy function’s body (which calls the newly created async version and then forwards the arguments to the legacy completion handler).

To describe the completion handler that the result is being forwarded to, I’m also using `AsyncHandlerDesc`, but since the completion handler may be a variable, it doesn’t necessarily have an `Index` within a function decl that declares it. Because of this, I split the `AsyncHandlerDesc` up into a context-free `AsyncHandlerDesc` (without an `Index`) and `AsyncHandlerParamDesc` (which includes the `Index`). It turns out that `AsyncHandlerDesc` is sufficient in most places.

Resolves rdar://77460524
This commit is contained in:
Alex Hoppen
2021-05-07 18:11:01 +02:00
parent b8ec77892c
commit dd978cca0b
2 changed files with 458 additions and 95 deletions

View File

@@ -25,14 +25,15 @@
#include "swift/AST/USRGeneration.h" #include "swift/AST/USRGeneration.h"
#include "swift/Basic/Edit.h" #include "swift/Basic/Edit.h"
#include "swift/Basic/StringExtras.h" #include "swift/Basic/StringExtras.h"
#include "swift/ClangImporter/ClangImporter.h"
#include "swift/Frontend/Frontend.h" #include "swift/Frontend/Frontend.h"
#include "swift/IDE/IDERequests.h" #include "swift/IDE/IDERequests.h"
#include "swift/Index/Index.h" #include "swift/Index/Index.h"
#include "swift/ClangImporter/ClangImporter.h"
#include "swift/Parse/Lexer.h" #include "swift/Parse/Lexer.h"
#include "swift/Sema/IDETypeChecking.h" #include "swift/Sema/IDETypeChecking.h"
#include "swift/Subsystems.h" #include "swift/Subsystems.h"
#include "clang/Rewrite/Core/RewriteBuffer.h" #include "clang/Rewrite/Core/RewriteBuffer.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSet.h" #include "llvm/ADT/StringSet.h"
@@ -3946,19 +3947,25 @@ PtrArrayRef<Expr *> callArgs(const ApplyExpr *AE) {
return PtrArrayRef<Expr *>(); return PtrArrayRef<Expr *>();
} }
FuncDecl *getUnderlyingFunc(const Expr *Fn) { /// A more aggressive variant of \c Expr::getReferencedDecl that also looks
/// through autoclosures created to pass the \c self parameter to a member funcs
ValueDecl *getReferencedDecl(const Expr *Fn) {
Fn = Fn->getSemanticsProvidingExpr(); Fn = Fn->getSemanticsProvidingExpr();
if (auto *DRE = dyn_cast<DeclRefExpr>(Fn)) if (auto *DRE = dyn_cast<DeclRefExpr>(Fn))
return dyn_cast_or_null<FuncDecl>(DRE->getDecl()); return DRE->getDecl();
if (auto ApplyE = dyn_cast<SelfApplyExpr>(Fn)) if (auto ApplyE = dyn_cast<SelfApplyExpr>(Fn))
return getUnderlyingFunc(ApplyE->getFn()); return getReferencedDecl(ApplyE->getFn());
if (auto *ACE = dyn_cast<AutoClosureExpr>(Fn)) { if (auto *ACE = dyn_cast<AutoClosureExpr>(Fn)) {
if (auto *Unwrapped = ACE->getUnwrappedCurryThunkExpr()) if (auto *Unwrapped = ACE->getUnwrappedCurryThunkExpr())
return getUnderlyingFunc(Unwrapped); return getReferencedDecl(Unwrapped);
} }
return nullptr; return nullptr;
} }
FuncDecl *getUnderlyingFunc(const Expr *Fn) {
return dyn_cast_or_null<FuncDecl>(getReferencedDecl(Fn));
}
/// Find the outermost call of the given location /// Find the outermost call of the given location
CallExpr *findOuterCall(const ResolvedCursorInfo &CursorInfo) { CallExpr *findOuterCall(const ResolvedCursorInfo &CursorInfo) {
auto IncludeInContext = [](ASTNode N) { auto IncludeInContext = [](ASTNode N) {
@@ -4064,40 +4071,32 @@ public:
enum class HandlerType { INVALID, PARAMS, RESULT }; enum class HandlerType { INVALID, PARAMS, RESULT };
/// Given a function with an async alternative (or one that *could* have an /// Given a function with an async alternative (or one that *could* have an
/// async alternative), stores information about the handler parameter. /// async alternative), stores information about the completion handler.
/// The completion handler can be either a variable (which includes a parameter)
/// or a function
struct AsyncHandlerDesc { struct AsyncHandlerDesc {
const ParamDecl *Handler = nullptr; PointerUnion<const VarDecl *, const AbstractFunctionDecl *> Handler = nullptr;
int Index = -1;
HandlerType Type = HandlerType::INVALID; HandlerType Type = HandlerType::INVALID;
bool HasError = false; bool HasError = false;
static AsyncHandlerDesc find(const FuncDecl *FD, bool ignoreName) { static AsyncHandlerDesc get(const ValueDecl *Handler, bool ignoreName) {
if (!FD || FD->hasAsync() || FD->hasThrows())
return AsyncHandlerDesc();
// Require at least one parameter and void return type
auto *Params = FD->getParameters();
if (Params->size() == 0 || !FD->getResultInterfaceType()->isVoid())
return AsyncHandlerDesc();
AsyncHandlerDesc HandlerDesc; AsyncHandlerDesc HandlerDesc;
if (auto Var = dyn_cast<VarDecl>(Handler)) {
// Assume the handler is the last parameter for now HandlerDesc.Handler = Var;
HandlerDesc.Index = Params->size() - 1; } else if (auto Func = dyn_cast<AbstractFunctionDecl>(Handler)) {
HandlerDesc.Handler = Params->get(HandlerDesc.Index); HandlerDesc.Handler = Func;
} else {
// Callback must not be attributed with @autoclosure // The handler must be a variable or function
if (HandlerDesc.Handler->isAutoClosure())
return AsyncHandlerDesc(); return AsyncHandlerDesc();
}
// Callback must have a completion-like name (if we're not ignoring it) // Callback must have a completion-like name (if we're not ignoring it)
if (!ignoreName && if (!ignoreName && !isCompletionHandlerParamName(HandlerDesc.getNameStr()))
!isCompletionHandlerParamName(HandlerDesc.Handler->getNameStr()))
return AsyncHandlerDesc(); return AsyncHandlerDesc();
// Callback must be a function type and return void. Doesn't need to have // Callback must be a function type and return void. Doesn't need to have
// any parameters - may just be a "I'm done" callback // any parameters - may just be a "I'm done" callback
auto *HandlerTy = HandlerDesc.Handler->getType()->getAs<AnyFunctionType>(); auto *HandlerTy = HandlerDesc.getType()->getAs<AnyFunctionType>();
if (!HandlerTy || !HandlerTy->getResult()->isVoid()) if (!HandlerTy || !HandlerTy->getResult()->isVoid())
return AsyncHandlerDesc(); return AsyncHandlerDesc();
@@ -4126,7 +4125,7 @@ struct AsyncHandlerDesc {
if (!HandlerParams.empty()) { if (!HandlerParams.empty()) {
auto LastParamTy = HandlerParams.back().getParameterType(); auto LastParamTy = HandlerParams.back().getParameterType();
HandlerDesc.HasError = isErrorType(LastParamTy->getOptionalObjectType(), HandlerDesc.HasError = isErrorType(LastParamTy->getOptionalObjectType(),
FD->getModuleContext()); Handler->getModuleContext());
} }
} }
@@ -4135,8 +4134,55 @@ struct AsyncHandlerDesc {
bool isValid() const { return Type != HandlerType::INVALID; } bool isValid() const { return Type != HandlerType::INVALID; }
/// Return the declaration of the completion handler as a \c ValueDecl.
/// In practice, the handler will always be a \c VarDecl or \c
/// AbstractFunctionDecl.
/// \c getNameStr and \c getType provide access functions that are available
/// for both variables and functions, but not on \c ValueDecls.
const ValueDecl *getHandler() const {
if (!Handler) {
return nullptr;
}
if (auto Var = Handler.dyn_cast<const VarDecl *>()) {
return Var;
} else if (auto Func = Handler.dyn_cast<const AbstractFunctionDecl *>()) {
return Func;
} else {
llvm_unreachable("Unknown handler type");
}
}
/// Return the name of the completion handler. If it is a variable, the
/// variable name, if it's a function, the function base name.
StringRef getNameStr() const {
if (auto Var = Handler.dyn_cast<const VarDecl *>()) {
return Var->getNameStr();
} else if (auto Func = Handler.dyn_cast<const AbstractFunctionDecl *>()) {
return Func->getNameStr();
} else {
llvm_unreachable("Unknown handler type");
}
}
/// Get the type of the completion handler.
swift::Type getType() const {
if (auto Var = Handler.dyn_cast<const VarDecl *>()) {
return Var->getType();
} else if (auto Func = Handler.dyn_cast<const AbstractFunctionDecl *>()) {
auto Type = Func->getInterfaceType();
// Undo the self curry thunk if we are referencing a member function.
if (Func->hasImplicitSelfDecl()) {
assert(Type->is<AnyFunctionType>());
Type = Type->getAs<AnyFunctionType>()->getResult();
}
return Type;
} else {
llvm_unreachable("Unknown handler type");
}
}
ArrayRef<AnyFunctionType::Param> params() const { ArrayRef<AnyFunctionType::Param> params() const {
auto Ty = Handler->getType()->getAs<AnyFunctionType>(); auto Ty = getType()->getAs<AnyFunctionType>();
assert(Ty && "Type must be a function type"); assert(Ty && "Type must be a function type");
return Ty->getParams(); return Ty->getParams();
} }
@@ -4186,7 +4232,7 @@ struct AsyncHandlerDesc {
if (Node.isExpr(swift::ExprKind::Call)) { if (Node.isExpr(swift::ExprKind::Call)) {
CallExpr *CE = cast<CallExpr>(Node.dyn_cast<Expr *>()); CallExpr *CE = cast<CallExpr>(Node.dyn_cast<Expr *>());
if (CE->getFn()->getReferencedDecl().getDecl() == Handler) if (CE->getFn()->getReferencedDecl().getDecl() == getHandler())
return CE; return CE;
} }
return nullptr; return nullptr;
@@ -4290,6 +4336,39 @@ struct AsyncHandlerDesc {
} }
}; };
/// Given a completion handler that is part of a function signature, stores
/// information about that completion handler and its index within the function
/// declaration.
struct AsyncHandlerParamDesc : public AsyncHandlerDesc {
/// The index of the completion handler in the function that declares it.
int Index = -1;
AsyncHandlerParamDesc() : AsyncHandlerDesc() {}
AsyncHandlerParamDesc(const AsyncHandlerDesc &Handler, int Index)
: AsyncHandlerDesc(Handler), Index(Index) {}
static AsyncHandlerParamDesc find(const FuncDecl *FD, bool ignoreName) {
if (!FD || FD->hasAsync() || FD->hasThrows())
return AsyncHandlerParamDesc();
// Require at least one parameter and void return type
auto *Params = FD->getParameters();
if (Params->size() == 0 || !FD->getResultInterfaceType()->isVoid())
return AsyncHandlerParamDesc();
// Assume the handler is the last parameter for now
int Index = Params->size() - 1;
const ParamDecl *Param = Params->get(Index);
// Callback must not be attributed with @autoclosure
if (Param->isAutoClosure())
return AsyncHandlerParamDesc();
return AsyncHandlerParamDesc(AsyncHandlerDesc::get(Param, ignoreName),
Index);
}
};
enum class ConditionType { INVALID, NIL, NOT_NIL }; enum class ConditionType { INVALID, NIL, NOT_NIL };
/// Finds the `Subject` being compared to in various conditions. Also finds any /// Finds the `Subject` being compared to in various conditions. Also finds any
@@ -4795,7 +4874,7 @@ class AsyncConverter : private SourceEntityWalker {
// Completion handler of `StartNode` (if it's a function with an async // Completion handler of `StartNode` (if it's a function with an async
// alternative) // alternative)
const AsyncHandlerDesc &TopHandler; const AsyncHandlerParamDesc &TopHandler;
SmallString<0> Buffer; SmallString<0> Buffer;
llvm::raw_svector_ostream OS; llvm::raw_svector_ostream OS;
@@ -4821,10 +4900,10 @@ class AsyncConverter : private SourceEntityWalker {
public: public:
AsyncConverter(SourceManager &SM, DiagnosticEngine &DiagEngine, AsyncConverter(SourceManager &SM, DiagnosticEngine &DiagEngine,
ASTNode StartNode, const AsyncHandlerDesc &TopHandler) ASTNode StartNode, const AsyncHandlerParamDesc &TopHandler)
: SM(SM), DiagEngine(DiagEngine), StartNode(StartNode), : SM(SM), DiagEngine(DiagEngine), StartNode(StartNode),
TopHandler(TopHandler), Buffer(), OS(Buffer) { TopHandler(TopHandler), Buffer(), OS(Buffer) {
Placeholders.insert(TopHandler.Handler); Placeholders.insert(TopHandler.getHandler());
} }
bool convert() { bool convert() {
@@ -4855,40 +4934,16 @@ public:
return false; return false;
} }
FuncDecl *FD = cast<FuncDecl>(StartNode.get<Decl *>()); FuncDecl *FD = cast<FuncDecl>(StartNode.get<Decl *>());
Identifier CompletionHandlerName = TopHandler.Handler->getParameterName();
OS << tok::l_brace << "\n"; // start function body OS << tok::l_brace << "\n"; // start function body
OS << "async " << tok::l_brace << "\n"; OS << "async " << tok::l_brace << "\n";
if (TopHandler.HasError) { addHoistedNamedCallback(FD, TopHandler, TopHandler.getNameStr(), [&]() {
addDo(); if (TopHandler.HasError) {
if (!TopHandler.willAsyncReturnVoid()) { OS << tok::kw_try << " ";
OS << tok::kw_let << " result";
addResultTypeAnnotationIfNecessary(FD, TopHandler);
OS << " " << tok::equal << " ";
}
OS << tok::kw_try << " await ";
addCallToAsyncMethod(FD, TopHandler);
OS << "\n";
addCallToCompletionHandler(/*HasResult=*/true, CompletionHandlerName, FD,
TopHandler);
OS << "\n"
<< tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n";
addCallToCompletionHandler(/*HasResult=*/false, CompletionHandlerName, FD,
TopHandler);
OS << "\n" << tok::r_brace << "\n"; // end catch
} else {
if (!TopHandler.willAsyncReturnVoid()) {
OS << tok::kw_let << " result";
addResultTypeAnnotationIfNecessary(FD, TopHandler);
OS << " " << tok::equal << " ";
} }
OS << "await "; OS << "await ";
addCallToAsyncMethod(FD, TopHandler); addCallToAsyncMethod(FD, TopHandler);
OS << "\n"; });
addCallToCompletionHandler(/*HasResult=*/true, CompletionHandlerName, FD,
TopHandler);
OS << "\n";
}
OS << tok::r_brace << "\n"; // end 'async' OS << tok::r_brace << "\n"; // end 'async'
OS << tok::r_brace << "\n"; // end function body OS << tok::r_brace << "\n"; // end function body
return true; return true;
@@ -5026,7 +5081,7 @@ private:
return addCustom(CE->getSourceRange(), [&]() { addHandlerCall(CE); }); return addCustom(CE->getSourceRange(), [&]() { addHandlerCall(CE); });
if (auto *CE = dyn_cast<CallExpr>(E)) { if (auto *CE = dyn_cast<CallExpr>(E)) {
auto HandlerDesc = AsyncHandlerDesc::find( auto HandlerDesc = AsyncHandlerParamDesc::find(
getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE); getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE);
if (HandlerDesc.isValid()) if (HandlerDesc.isValid())
return addCustom(CE->getSourceRange(), return addCustom(CE->getSourceRange(),
@@ -5220,33 +5275,73 @@ private:
/// From the given expression \p E, which is an argument to a function call, /// From the given expression \p E, which is an argument to a function call,
/// extract the passed closure if there is one. Otherwise return \c nullptr. /// extract the passed closure if there is one. Otherwise return \c nullptr.
ClosureExpr *extractCallback(Expr *E) { ClosureExpr *extractCallback(Expr *E) {
E = lookThroughFunctionConversionExpr(E);
if (auto Closure = dyn_cast<ClosureExpr>(E)) { if (auto Closure = dyn_cast<ClosureExpr>(E)) {
return Closure; return Closure;
} else if (auto CaptureList = dyn_cast<CaptureListExpr>(E)) { } else if (auto CaptureList = dyn_cast<CaptureListExpr>(E)) {
return CaptureList->getClosureBody(); return CaptureList->getClosureBody();
} else if (auto FunctionConversion = dyn_cast<FunctionConversionExpr>(E)) {
// Closure arguments marked as e.g. `@convention(block)` produce arguments
// that are `FunctionConversionExpr`.
return extractCallback(FunctionConversion->getSubExpr());
} else { } else {
return nullptr; return nullptr;
} }
} }
/// Callback arguments marked as e.g. `@convention(block)` produce arguments
/// that are `FunctionConversionExpr`.
/// We don't care about the conversions and want to shave them off.
Expr *lookThroughFunctionConversionExpr(Expr *E) {
if (auto FunctionConversion = dyn_cast<FunctionConversionExpr>(E)) {
return lookThroughFunctionConversionExpr(
FunctionConversion->getSubExpr());
} else {
return E;
}
}
void addHoistedCallback(const CallExpr *CE, void addHoistedCallback(const CallExpr *CE,
const AsyncHandlerDesc &HandlerDesc) { const AsyncHandlerParamDesc &HandlerDesc) {
auto ArgList = callArgs(CE); auto ArgList = callArgs(CE);
if ((size_t)HandlerDesc.Index >= ArgList.ref().size()) { if ((size_t)HandlerDesc.Index >= ArgList.ref().size()) {
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg); DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
return; return;
} }
ClosureExpr *Callback = extractCallback(ArgList.ref()[HandlerDesc.Index]); Expr *CallbackArg =
if (!Callback) { lookThroughFunctionConversionExpr(ArgList.ref()[HandlerDesc.Index]);
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg); if (ClosureExpr *Callback = extractCallback(CallbackArg)) {
// The user is using a closure for the completion handler
addHoistedClosureCallback(CE, HandlerDesc, Callback, ArgList);
return; return;
} }
if (auto CallbackDecl = getReferencedDecl(CallbackArg)) {
// The completion handler that is called as part of the \p CE call.
// This will be called once the async function returns.
auto CompletionHandler = AsyncHandlerDesc::get(CallbackDecl,
/*ignoreName=*/true);
if (CompletionHandler.isValid()) {
if (auto CalledFunc = getUnderlyingFunc(CE->getFn())) {
StringRef HandlerName = Lexer::getCharSourceRangeFromSourceRange(
SM, CallbackArg->getSourceRange())
.str();
addHoistedNamedCallback(
CalledFunc, CompletionHandler, HandlerName, [&] {
addAwaitCall(CE, ArgList.ref(), ClassifiedBlock(), {},
HandlerDesc, /*AddDeclarations=*/false);
});
return;
}
}
}
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
}
/// Add a call to the async alternative of \p CE and convert the \p Callback
/// to be executed after the async call. \p HandlerDesc describes the
/// completion handler in the function that's called by \p CE and \p ArgList
/// are the arguments being passed in \p CE.
void addHoistedClosureCallback(const CallExpr *CE,
const AsyncHandlerDesc &HandlerDesc,
const ClosureExpr *Callback,
PtrArrayRef<Expr *> ArgList) {
ArrayRef<const ParamDecl *> CallbackParams = ArrayRef<const ParamDecl *> CallbackParams =
Callback->getParameters()->getArray(); Callback->getParameters()->getArray();
ArrayRef<ASTNode> CallbackBody = Callback->getBody()->getElements(); ArrayRef<ASTNode> CallbackBody = Callback->getBody()->getElements();
@@ -5350,6 +5445,44 @@ private:
} }
} }
/// Add a call to the async alternative of \p FD. Afterwards, pass the results
/// of the async call to the completion handler, named \p HandlerName and
/// described by \p HandlerDesc.
/// \p AddAwaitCall adds the call to the refactored async method to the output
/// stream without storing the result to any variables.
/// This is used when the user didn't use a closure for the callback, but
/// passed in a variable or function name for the completion handler.
void addHoistedNamedCallback(const FuncDecl *FD,
const AsyncHandlerDesc &HandlerDesc,
StringRef HandlerName,
std::function<void(void)> AddAwaitCall) {
if (HandlerDesc.HasError) {
addDo();
if (!HandlerDesc.willAsyncReturnVoid()) {
OS << tok::kw_let << " result";
addResultTypeAnnotationIfNecessary(FD, HandlerDesc);
OS << " " << tok::equal << " ";
}
AddAwaitCall();
OS << "\n";
addCallToCompletionHandler(/*HasResult=*/true, HandlerDesc, HandlerName);
OS << "\n";
OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n";
addCallToCompletionHandler(/*HasResult=*/false, HandlerDesc, HandlerName);
OS << "\n" << tok::r_brace << "\n"; // end catch
} else {
if (!HandlerDesc.willAsyncReturnVoid()) {
OS << tok::kw_let << " result";
addResultTypeAnnotationIfNecessary(FD, HandlerDesc);
OS << " " << tok::equal << " ";
}
AddAwaitCall();
OS << "\n";
addCallToCompletionHandler(/*HasResult=*/true, HandlerDesc, HandlerName);
OS << "\n";
}
}
void addAwaitCall(const CallExpr *CE, ArrayRef<Expr *> Args, void addAwaitCall(const CallExpr *CE, ArrayRef<Expr *> Args,
const ClassifiedBlock &SuccessBlock, const ClassifiedBlock &SuccessBlock,
ArrayRef<const ParamDecl *> SuccessParams, ArrayRef<const ParamDecl *> SuccessParams,
@@ -5512,7 +5645,7 @@ private:
OS << FD->getBaseName() << tok::l_paren; OS << FD->getBaseName() << tok::l_paren;
bool FirstParam = true; bool FirstParam = true;
for (auto Param : *FD->getParameters()) { for (auto Param : *FD->getParameters()) {
if (Param == HandlerDesc.Handler) { if (Param == HandlerDesc.getHandler()) {
/// We don't need to pass the completion handler to the async method. /// We don't need to pass the completion handler to the async method.
continue; continue;
} }
@@ -5532,8 +5665,9 @@ private:
/// If the error type of \p HandlerDesc is more specialized than \c Error, /// If the error type of \p HandlerDesc is more specialized than \c Error,
/// adds an 'as! CustomError' cast to the more specialized error type to the /// adds an 'as! CustomError' cast to the more specialized error type to the
/// output stream. /// output stream.
void addCastToCustomErrorTypeIfNecessary(const AsyncHandlerDesc &HandlerDesc, void
const ASTContext &Ctx) { addCastToCustomErrorTypeIfNecessary(const AsyncHandlerDesc &HandlerDesc) {
const ASTContext &Ctx = HandlerDesc.getHandler()->getASTContext();
auto ErrorType = *HandlerDesc.getErrorType(); auto ErrorType = *HandlerDesc.getErrorType();
if (ErrorType->getCanonicalType() != Ctx.getExceptionType()) { if (ErrorType->getCanonicalType() != Ctx.getExceptionType()) {
OS << " " << tok::kw_as << tok::exclaim_postfix << " "; OS << " " << tok::kw_as << tok::exclaim_postfix << " ";
@@ -5541,22 +5675,21 @@ private:
} }
} }
/// Adds the \c Index -th parameter to the completion handler of \p FD. /// Adds the \c Index -th parameter to the completion handler described by \p
/// \p HanderDesc describes which of \p FD's parameters is the completion /// HanderDesc.
/// handler. If \p HasResult is \c true, it is assumed that a variable named /// If \p HasResult is \c true, it is assumed that a variable named
/// 'result' contains the result returned from the async alternative. If the /// 'result' contains the result returned from the async alternative. If the
/// callback also takes an error parameter, \c nil passed to the completion /// callback also takes an error parameter, \c nil passed to the completion
/// handler for the error. If \p HasResult is \c false, it is a assumed that a /// handler for the error. If \p HasResult is \c false, it is a assumed that a
/// variable named 'error' contains the error thrown from the async method and /// variable named 'error' contains the error thrown from the async method and
/// 'nil' will be passed to the completion handler for all result parameters. /// 'nil' will be passed to the completion handler for all result parameters.
void addCompletionHandlerArgument(size_t Index, bool HasResult, void addCompletionHandlerArgument(size_t Index, bool HasResult,
const FuncDecl *FD,
const AsyncHandlerDesc &HandlerDesc) { const AsyncHandlerDesc &HandlerDesc) {
if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) { if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) {
// The error parameter is the last argument of the completion handler. // The error parameter is the last argument of the completion handler.
if (!HasResult) { if (!HasResult) {
OS << "error"; OS << "error";
addCastToCustomErrorTypeIfNecessary(HandlerDesc, FD->getASTContext()); addCastToCustomErrorTypeIfNecessary(HandlerDesc);
} else { } else {
OS << tok::kw_nil; OS << tok::kw_nil;
} }
@@ -5591,15 +5724,12 @@ private:
} }
} }
/// If the completion handler of a call to \p FD is named \p HandlerName, /// Add a call to the completion handler named \p HandlerName and described by
/// add a call to \p HandlerName passing all the required arguments. \p /// \p HandlerDesc, passing all the required arguments. See \c
/// HandlerDesc describes which of \p FD's parameters is the completion /// getCompletionHandlerArgument for how the arguments are synthesized.
/// handler hat is being called. See \c getCompletionHandlerArgument for how void addCallToCompletionHandler(bool HasResult,
/// the arguments are synthesized if the completion handler takes arguments, const AsyncHandlerDesc &HandlerDesc,
/// not a \c Result type. StringRef HandlerName) {
void addCallToCompletionHandler(bool HasResult, Identifier HandlerName,
const FuncDecl *FD,
const AsyncHandlerDesc &HandlerDesc) {
OS << HandlerName << tok::l_paren; OS << HandlerName << tok::l_paren;
// Construct arguments to pass to the completion handler // Construct arguments to pass to the completion handler
@@ -5612,7 +5742,7 @@ private:
if (I > 0) { if (I > 0) {
OS << tok::comma << " "; OS << tok::comma << " ";
} }
addCompletionHandlerArgument(I, HasResult, FD, HandlerDesc); addCompletionHandlerArgument(I, HasResult, HandlerDesc);
} }
break; break;
} }
@@ -5622,7 +5752,7 @@ private:
<< tok::r_paren; << tok::r_paren;
} else { } else {
OS << tok::period_prefix << "failure" << tok::l_paren << "error"; OS << tok::period_prefix << "failure" << tok::l_paren << "error";
addCastToCustomErrorTypeIfNecessary(HandlerDesc, FD->getASTContext()); addCastToCustomErrorTypeIfNecessary(HandlerDesc);
OS << tok::r_paren; OS << tok::r_paren;
} }
break; break;
@@ -5689,8 +5819,8 @@ bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
if (!CE) if (!CE)
return false; return false;
auto HandlerDesc = AsyncHandlerDesc::find(getUnderlyingFunc(CE->getFn()), auto HandlerDesc = AsyncHandlerParamDesc::find(getUnderlyingFunc(CE->getFn()),
/*ignoreName=*/true); /*ignoreName=*/true);
return HandlerDesc.isValid(); return HandlerDesc.isValid();
} }
@@ -5709,7 +5839,7 @@ bool RefactoringActionConvertCallToAsyncAlternative::performChange() {
assert(CE && assert(CE &&
"Should not run performChange when refactoring is not applicable"); "Should not run performChange when refactoring is not applicable");
AsyncHandlerDesc TempDesc; AsyncHandlerParamDesc TempDesc;
AsyncConverter Converter(SM, DiagEngine, CE, TempDesc); AsyncConverter Converter(SM, DiagEngine, CE, TempDesc);
if (!Converter.convert()) if (!Converter.convert())
return true; return true;
@@ -5737,7 +5867,7 @@ bool RefactoringActionConvertToAsync::performChange() {
assert(FD && assert(FD &&
"Should not run performChange when refactoring is not applicable"); "Should not run performChange when refactoring is not applicable");
auto HandlerDesc = AsyncHandlerDesc::find(FD, /*ignoreName=*/true); auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
AsyncConverter Converter(SM, DiagEngine, FD, HandlerDesc); AsyncConverter Converter(SM, DiagEngine, FD, HandlerDesc);
if (!Converter.convert()) if (!Converter.convert())
return true; return true;
@@ -5754,7 +5884,7 @@ bool RefactoringActionAddAsyncAlternative::isApplicable(
if (!FD) if (!FD)
return false; return false;
auto HandlerDesc = AsyncHandlerDesc::find(FD, /*ignoreName=*/true); auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
return HandlerDesc.isValid(); return HandlerDesc.isValid();
} }
@@ -5771,7 +5901,7 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
assert(FD && assert(FD &&
"Should not run performChange when refactoring is not applicable"); "Should not run performChange when refactoring is not applicable");
auto HandlerDesc = AsyncHandlerDesc::find(FD, /*ignoreName=*/true); auto HandlerDesc = AsyncHandlerParamDesc::find(FD, /*ignoreName=*/true);
assert(HandlerDesc.isValid() && assert(HandlerDesc.isValid() &&
"Should not run performChange when refactoring is not applicable"); "Should not run performChange when refactoring is not applicable");

View File

@@ -0,0 +1,233 @@
enum CustomError: Error {
case invalid
case insecure
}
typealias SomeCallback = (String) -> Void
func simple(completion: (String) -> Void) { }
func simpleWithArg(a: Int, completion: (String) -> Void) { }
func multipleResults(completion: (String, Int) -> Void) { }
func nonOptionalError(completion: (String, Error) -> Void) { }
func noParams(completion: () -> Void) { }
func error(completion: (String?, Error?) -> Void) { }
func errorOnly(completion: (Error?) -> Void) { }
func errorNonOptionalResult(completion: (String, Error?) -> Void) { }
func alias(completion: SomeCallback) { }
func simpleResult(completion: (Result<String, Never>) -> Void) { }
func errorResult(completion: (Result<String, Error>) -> Void) { }
func customErrorResult(completion: (Result<String, CustomError>) -> Void) { }
func optionalSingle(completion: (String?) -> Void) { }
func manyOptional(_ completion: (String?, Int?) -> Void) { }
func generic<T, R>(completion: (T, R) -> Void) { }
func genericResult<T>(completion: (T?, Error?) -> Void) where T: Numeric { }
func genericError<E>(completion: (String?, E?) -> Void) where E: Error { }
func defaultArgs(a: Int, b: Int = 10, completion: (String) -> Void) { }
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=SIMPLE-WITH-VARIABLE-COMPLETION-HANDLER %s
func testSimpleWithVariableCompletionHandler(completionHandler: (String) -> Void) {
simple(completion: completionHandler)
}
// SIMPLE-WITH-VARIABLE-COMPLETION-HANDLER: let result = await simple()
// SIMPLE-WITH-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=SIMPLE-WITH-ARG-VARIABLE-COMPLETION-HANDLER %s
func testSimpleWithArgVariableCompletionHandler(b: Int, completionHandler: (String) -> Void) {
simpleWithArg(a: b, completion: completionHandler)
}
// SIMPLE-WITH-ARG-VARIABLE-COMPLETION-HANDLER: let result = await simpleWithArg(a: b)
// SIMPLE-WITH-ARG-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=SIMPLE-WITH-CONSTANT-ARG-VARIABLE-COMPLETION-HANDLER %s
func testSimpleWithConstantArgVariableCompletionHandler(completionHandler: (String) -> Void) {
simpleWithArg(a: 1, completion: completionHandler)
}
// SIMPLE-WITH-CONSTANT-ARG-VARIABLE-COMPLETION-HANDLER: let result = await simpleWithArg(a: 1)
// SIMPLE-WITH-CONSTANT-ARG-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=MULTIPLE-RESULTS-VARIABLE-COMPLETION-HANDLER %s
func testMultipleResultsVariableCompletionHandler(completionHandler: (String, Int) -> Void) {
multipleResults(completion: completionHandler)
}
// MULTIPLE-RESULTS-VARIABLE-COMPLETION-HANDLER: let result = await multipleResults()
// MULTIPLE-RESULTS-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result.0, result.1)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=NON-OPTIONAL-ERROR-VARIABLE-COMPLETION-HANDLER %s
func testNonOptionalErrorVariableCompletionHandler(completionHandler: (String, Error) -> Void) {
nonOptionalError(completion: completionHandler)
}
// NON-OPTIONAL-ERROR-VARIABLE-COMPLETION-HANDLER: let result = await nonOptionalError()
// NON-OPTIONAL-ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result.0, result.1)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=NO-PARAMS-VARIABLE-COMPLETION-HANDLER %s
func testNoParamsVariableCompletionHandler(completionHandler: () -> Void) {
noParams(completion: completionHandler)
}
// NO-PARAMS-VARIABLE-COMPLETION-HANDLER: await noParams()
// NO-PARAMS-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler()
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=ERROR-VARIABLE-COMPLETION-HANDLER %s
func testErrorWithVariableCompletionHandler(completionHandler: (String?, Error?) -> Void) {
error(completion: completionHandler)
}
// ERROR-VARIABLE-COMPLETION-HANDLER: do {
// ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: let result = try await error()
// ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result, nil)
// ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: } catch {
// ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(nil, error)
// ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: }
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=ERROR-ONLY-VARIABLE-COMPLETION-HANDLER %s
func testErrorOnlyWithVariableCompletionHandler(completionHandler: (Error?) -> Void) {
errorOnly(completion: completionHandler)
}
// ERROR-ONLY-VARIABLE-COMPLETION-HANDLER: do {
// ERROR-ONLY-VARIABLE-COMPLETION-HANDLER-NEXT: try await errorOnly()
// ERROR-ONLY-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(nil)
// ERROR-ONLY-VARIABLE-COMPLETION-HANDLER-NEXT: } catch {
// ERROR-ONLY-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(error)
// ERROR-ONLY-VARIABLE-COMPLETION-HANDLER-NEXT: }
// FIXME: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3
func testErrorNonOptionalResultWithVariableCompletionHandler(completionHandler: (String, Error?) -> Void) {
errorNonOptionalResult(completion: completionHandler)
}
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=ALIAS-VARIABLE-COMPLETION-HANDLER %s
func testAliasWithVariableCompletionHandler(completionHandler: SomeCallback) {
alias(completion: completionHandler)
}
// ALIAS-VARIABLE-COMPLETION-HANDLER: let result = await alias()
// ALIAS-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=SIMPLE-RESULT-VARIABLE-COMPLETION-HANDLER %s
func testSimpleResultVariableCompletionHandler(completionHandler: (Result<String, Never>) -> Void) {
simpleResult(completion: completionHandler)
}
// SIMPLE-RESULT-VARIABLE-COMPLETION-HANDLER: let result = await simpleResult()
// SIMPLE-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(.success(result))
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=ERROR-RESULT-VARIABLE-COMPLETION-HANDLER %s
func testErrorResultVariableCompletionHandler(completionHandler: (Result<String, Error>) -> Void) {
errorResult(completion: completionHandler)
}
// ERROR-RESULT-VARIABLE-COMPLETION-HANDLER: do {
// ERROR-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: let result = try await errorResult()
// ERROR-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(.success(result))
// ERROR-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: } catch {
// ERROR-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(.failure(error))
// ERROR-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: }
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=CUSTOM-ERROR-RESULT-VARIABLE-COMPLETION-HANDLER %s
func testErrorResultVariableCompletionHandler(completionHandler: (Result<String, CustomError>) -> Void) {
customErrorResult(completion: completionHandler)
}
// CUSTOM-ERROR-RESULT-VARIABLE-COMPLETION-HANDLER: do {
// CUSTOM-ERROR-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: let result = try await customErrorResult()
// CUSTOM-ERROR-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(.success(result))
// CUSTOM-ERROR-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: } catch {
// CUSTOM-ERROR-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(.failure(error as! CustomError))
// CUSTOM-ERROR-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: }
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=OPTIONAL-SINGLE-VARIABLE-COMPLETION-HANDLER %s
func testOptionalSingleVariableCompletionHandler(completionHandler: (String?) -> Void) {
optionalSingle(completion: completionHandler)
}
// OPTIONAL-SINGLE-VARIABLE-COMPLETION-HANDLER: let result = await optionalSingle()
// OPTIONAL-SINGLE-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=MANY-OPTIONAL-VARIABLE-COMPLETION-HANDLER %s
func testManyOptionalVariableCompletionHandler(completionHandler: (String?, Int?) -> Void) {
manyOptional(completionHandler)
}
// MANY-OPTIONAL-VARIABLE-COMPLETION-HANDLER: let result = await manyOptional()
// MANY-OPTIONAL-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result.0, result.1)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=GENERIC-VARIABLE-COMPLETION-HANDLER %s
func testGenericVariableCompletionHandler<T, R>(completionHandler: (T, R) -> Void) {
generic(completion: completionHandler)
}
// GENERIC-VARIABLE-COMPLETION-HANDLER: let result: (T, R) = await generic()
// GENERIC-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result.0, result.1)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=SPECIALIZE-GENERIC-VARIABLE-COMPLETION-HANDLER %s
func testSpecializeGenericsVariableCompletionHandler(completionHandler: (String, Int) -> Void) {
generic(completion: completionHandler)
}
// SPECIALIZE-GENERIC-VARIABLE-COMPLETION-HANDLER: let result: (String, Int) = await generic()
// SPECIALIZE-GENERIC-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result.0, result.1)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=GENERIC-RESULT-VARIABLE-COMPLETION-HANDLER %s
func testGenericResultVariableCompletionHandler<T>(completionHandler: (T?, Error?) -> Void) where T: Numeric {
genericResult(completion: completionHandler)
}
// GENERIC-RESULT-VARIABLE-COMPLETION-HANDLER: do {
// GENERIC-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: let result: T = try await genericResult()
// GENERIC-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result, nil)
// GENERIC-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: } catch {
// GENERIC-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(nil, error)
// GENERIC-RESULT-VARIABLE-COMPLETION-HANDLER-NEXT: }
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=GENERIC-ERROR-VARIABLE-COMPLETION-HANDLER %s
func testGenericErrorVariableCompletionHandler<MyGenericError>(completionHandler: (String?, MyGenericError?) -> Void) where MyGenericError: Error {
genericError(completion: completionHandler)
}
// GENERIC-ERROR-VARIABLE-COMPLETION-HANDLER: do {
// GENERIC-ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: let result: String = try await genericError()
// GENERIC-ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result, nil)
// GENERIC-ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: } catch {
// GENERIC-ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(nil, error as! MyGenericError)
// GENERIC-ERROR-VARIABLE-COMPLETION-HANDLER-NEXT: }
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=DEFAULT-ARGS-VARIABLE-COMPLETION-HANDLER %s
func testDefaultArgsVariableCompletionHandler(completionHandler: (String) -> Void) {
defaultArgs(a: 5, completion: completionHandler)
}
// DEFAULT-ARGS-VARIABLE-COMPLETION-HANDLER: let result = await defaultArgs(a: 5)
// DEFAULT-ARGS-VARIABLE-COMPLETION-HANDLER-NEXT: completionHandler(result)
func myPrint(_ message: String) {
print(message)
}
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):3 | %FileCheck -check-prefix=GLOBAL-FUNC-AS-COMPLETION-HANDLER %s
func testGlobalFuncAsCompletionHandler() {
simple(completion: myPrint)
}
// GLOBAL-FUNC-AS-COMPLETION-HANDLER: let result = await simple()
// GLOBAL-FUNC-AS-COMPLETION-HANDLER-NEXT: myPrint(result)
class Foo {
var foo: Foo
init(foo: Foo) {
self.foo = foo
}
func myFooPrint(_ message: String) {
print("FOO: \(message)")
}
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):5 | %FileCheck -check-prefix=MEMBER-FUNC-AS-COMPLETION-HANDLER %s
func testMethodAsCompletionHandler() {
simple(completion: myFooPrint)
}
// MEMBER-FUNC-AS-COMPLETION-HANDLER: let result = await simple()
// MEMBER-FUNC-AS-COMPLETION-HANDLER-NEXT: myFooPrint(result)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):5 | %FileCheck -check-prefix=MEMBER-FUNC-ON-OTHER-OBJECT-AS-COMPLETION-HANDLER %s
func testMethodOnOtherObjectAsCompletionHandler(foo: Foo) {
simple(completion: foo.myFooPrint)
}
// MEMBER-FUNC-ON-OTHER-OBJECT-AS-COMPLETION-HANDLER: let result = await simple()
// MEMBER-FUNC-ON-OTHER-OBJECT-AS-COMPLETION-HANDLER-NEXT: foo.myFooPrint(result)
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+2):5 | %FileCheck -check-prefix=MEMBER-FUNC-ON-NESTED-OTHER-OBJECT-AS-COMPLETION-HANDLER %s
func testMethodOnNestedOtherObjectAsCompletionHandler(foo: Foo) {
simple(completion: foo.foo.myFooPrint)
}
// MEMBER-FUNC-ON-NESTED-OTHER-OBJECT-AS-COMPLETION-HANDLER: let result = await simple()
// MEMBER-FUNC-ON-NESTED-OTHER-OBJECT-AS-COMPLETION-HANDLER-NEXT: foo.foo.myFooPrint(result)
}