Merge pull request #37185 from ahoppen/pr/legacy-async-method-refactor

[Refactoring] When adding an async alternative refactor the old method to call the async method using `async`
This commit is contained in:
Alex Hoppen
2021-05-04 18:33:25 +02:00
committed by GitHub
5 changed files with 666 additions and 40 deletions

View File

@@ -4149,6 +4149,36 @@ struct AsyncHandlerDesc {
return params();
}
/// Get the type of the error that will be thrown by the \c async method or \c
/// None if the completion handler doesn't accept an error parameter.
/// This may be more specialized than the generic 'Error' type if the
/// completion handler of the converted function takes a more specialized
/// error type.
Optional<swift::Type> getErrorType() const {
if (HasError) {
switch (Type) {
case HandlerType::INVALID:
return None;
case HandlerType::PARAMS:
// The last parameter of the completion handler is the error param
return params().back().getPlainType()->lookThroughSingleOptionalType();
case HandlerType::RESULT:
assert(
params().size() == 1 &&
"Result handler should have the Result type as the only parameter");
auto ResultType =
params().back().getPlainType()->getAs<BoundGenericType>();
auto GenericArgs = ResultType->getGenericArgs();
assert(GenericArgs.size() == 2 && "Result should have two params");
// The second (last) generic parameter of the Result type is the error
// type.
return GenericArgs.back();
}
} else {
return None;
}
}
/// The `CallExpr` if the given node is a call to the `Handler`
CallExpr *getAsHandlerCall(ASTNode Node) const {
if (!isValid())
@@ -5319,6 +5349,262 @@ private:
}
}
};
/// When adding an async alternative method for the function declaration \c FD,
/// this class tries to create a function body for the legacy function (the one
/// with a completion handler), which calls the newly converted async function.
/// There are certain situations in which we fail to create such a body, e.g.
/// if the completion handler has the signature `(String, Error?) -> Void` in
/// which case we can't synthesize the result of type \c String in the error
/// case.
class LegacyAlternativeBodyCreator {
/// The old function declaration for which an async alternative has been added
/// and whose body shall be rewritten to call the newly added async
/// alternative.
FuncDecl *FD;
/// The description of the completion handler in the old function declaration.
AsyncHandlerDesc HandlerDesc;
std::string Buffer;
llvm::raw_string_ostream OS;
/// Adds the call to the refactored 'async' method without the 'await'
/// keyword to the output stream.
void addCallToAsyncMethod() {
OS << FD->getBaseName() << "(";
bool FirstParam = true;
for (auto Param : *FD->getParameters()) {
if (Param == HandlerDesc.Handler) {
/// We don't need to pass the completion handler to the async method.
continue;
}
if (!FirstParam) {
OS << ", ";
} else {
FirstParam = false;
}
if (!Param->getArgumentName().empty()) {
OS << Param->getArgumentName() << ": ";
}
OS << Param->getParameterName();
}
OS << ")";
}
/// If the returned error type is more specialized than \c Error, adds an
/// 'as! CustomError' cast to the more specialized error type to the output
/// stream.
void addCastToCustomErrorTypeIfNecessary() {
auto ErrorType = *HandlerDesc.getErrorType();
if (ErrorType->getCanonicalType() !=
FD->getASTContext().getExceptionType()) {
OS << " as! ";
ErrorType->lookThroughSingleOptionalType()->print(OS);
}
}
/// Adds the \c Index -th parameter to the completion handler.
/// If \p HasResult is \c true, it is assumed that a variable named 'result'
/// contains the result returned from the async alternative. If the callback
/// also takes an error parameter, \c nil passed to the completion handler for
/// the error.
/// If \p HasResult is \c false, it is a assumed that a variable named 'error'
/// contains the error thrown from the async method and 'nil' will be passed
/// to the completion handler for all result parameters.
void addCompletionHandlerArgument(size_t Index, bool HasResult) {
if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) {
// The error parameter is the last argument of the completion handler.
if (!HasResult) {
OS << "error";
addCastToCustomErrorTypeIfNecessary();
} else {
OS << "nil";
}
} else {
if (!HasResult) {
OS << "nil";
} else if (HandlerDesc
.getSuccessParamAsyncReturnType(
HandlerDesc.params()[Index].getPlainType())
->isVoid()) {
// Void return types are not returned by the async function, synthesize
// a Void instance.
OS << "()";
} else if (HandlerDesc.getSuccessParams().size() > 1) {
// If the async method returns a tuple, we need to pass its elements to
// the completion handler separately. For example:
//
// func foo() async -> (String, Int) {}
//
// causes the following legacy body to be created:
//
// func foo(completion: (String, Int) -> Void) {
// async {
// let result = await foo()
// completion(result.0, result.1)
// }
// }
OS << "result." << Index;
} else {
OS << "result";
}
}
}
/// Adds the call to the completion handler. See \c
/// getCompletionHandlerArgument for how the arguments are synthesized if the
/// completion handler takes arguments, not a \c Result type.
void addCallToCompletionHandler(bool HasResult) {
OS << HandlerDesc.Handler->getParameterName() << "(";
// Construct arguments to pass to the completion handler
switch (HandlerDesc.Type) {
case HandlerType::INVALID:
llvm_unreachable("Cannot be rewritten");
break;
case HandlerType::PARAMS: {
for (size_t I = 0; I < HandlerDesc.params().size(); ++I) {
if (I > 0) {
OS << ", ";
}
addCompletionHandlerArgument(I, HasResult);
}
break;
}
case HandlerType::RESULT: {
if (HasResult) {
OS << ".success(result)";
} else {
OS << ".failure(error";
addCastToCustomErrorTypeIfNecessary();
OS << ")";
}
break;
}
}
OS << ")"; // Close the call to the completion handler
}
/// Adds the result type of the converted async function.
void addAsyncFuncReturnType() {
SmallVector<Type, 2> Scratch;
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
if (ReturnTypes.size() > 1) {
OS << "(";
}
llvm::interleave(
ReturnTypes, [&](Type Ty) { Ty->print(OS); }, [&]() { OS << ", "; });
if (ReturnTypes.size() > 1) {
OS << ")";
}
}
/// If the async alternative function is generic, adds the type annotation
/// to the 'return' variable in the legacy function so that the generic
/// parameters of the legacy function are passed to the generic function.
/// For example for
/// \code
/// func foo<GenericParam>() async -> GenericParam {}
/// \endcode
/// we generate
/// \code
/// func foo<GenericParam>(completion: (T) -> Void) {
/// async {
/// let result: GenericParam = await foo()
/// <------------>
/// completion(result)
/// }
/// }
/// \endcode
/// This function adds the range marked by \c <----->
void addResultTypeAnnotationIfNecessary() {
if (FD->isGeneric()) {
OS << ": ";
addAsyncFuncReturnType();
}
}
public:
LegacyAlternativeBodyCreator(FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
: FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
bool canRewriteLegacyBody() {
if (FD == nullptr || FD->getBody() == nullptr) {
return false;
}
if (FD->hasThrows()) {
assert(!HandlerDesc.isValid() && "We shouldn't have found a handler desc "
"if the original function throws");
return false;
}
switch (HandlerDesc.Type) {
case HandlerType::INVALID:
return false;
case HandlerType::PARAMS: {
if (HandlerDesc.HasError) {
// The non-error parameters must be optional so that we can set them to
// nil in the error case.
// The error parameter must be optional so we can set it to nil in the
// success case.
// Otherwise we can't synthesize the values to return for these
// parameters.
return llvm::all_of(HandlerDesc.params(),
[](AnyFunctionType::Param Param) -> bool {
return Param.getPlainType()->isOptional();
});
} else {
return true;
}
}
case HandlerType::RESULT:
return true;
}
}
std::string create() {
assert(Buffer.empty() &&
"LegacyAlternativeBodyCreator can only be used once");
assert(canRewriteLegacyBody() &&
"Cannot create a legacy body if the body can't be rewritten");
OS << "{\n"; // start function body
OS << "async {\n";
if (HandlerDesc.HasError) {
OS << "do {\n";
if (!HandlerDesc.willAsyncReturnVoid()) {
OS << "let result";
addResultTypeAnnotationIfNecessary();
OS << " = ";
}
OS << "try await ";
addCallToAsyncMethod();
OS << "\n";
addCallToCompletionHandler(/*HasResult=*/true);
OS << "\n"
<< "} catch {\n";
addCallToCompletionHandler(/*HasResult=*/false);
OS << "\n"
<< "}\n"; // end catch
} else {
if (!HandlerDesc.willAsyncReturnVoid()) {
OS << "let result";
addResultTypeAnnotationIfNecessary();
OS << " = ";
}
OS << "await ";
addCallToAsyncMethod();
OS << "\n";
addCallToCompletionHandler(/*HasResult=*/true);
OS << "\n";
}
OS << "}\n"; // end 'async'
OS << "}\n"; // end function body
return Buffer;
}
};
} // namespace asyncrefactorings
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
@@ -5425,6 +5711,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
"@available(*, deprecated, message: \"Prefer async "
"alternative instead\")\n");
LegacyAlternativeBodyCreator LegacyBody(FD, HandlerDesc);
if (LegacyBody.canRewriteLegacyBody()) {
EditConsumer.accept(SM,
Lexer::getCharSourceRangeFromSourceRange(
SM, FD->getBody()->getSourceRange()),
LegacyBody.create());
}
Converter.insertAfter(FD, EditConsumer);
return false;