mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Merge pull request #37185 from ahoppen/pr/legacy-async-method-refactor
[Refactoring] When adding an async alternative refactor the old method to call the async method using `async`
This commit is contained in:
@@ -4149,6 +4149,36 @@ struct AsyncHandlerDesc {
|
||||
return params();
|
||||
}
|
||||
|
||||
/// Get the type of the error that will be thrown by the \c async method or \c
|
||||
/// None if the completion handler doesn't accept an error parameter.
|
||||
/// This may be more specialized than the generic 'Error' type if the
|
||||
/// completion handler of the converted function takes a more specialized
|
||||
/// error type.
|
||||
Optional<swift::Type> getErrorType() const {
|
||||
if (HasError) {
|
||||
switch (Type) {
|
||||
case HandlerType::INVALID:
|
||||
return None;
|
||||
case HandlerType::PARAMS:
|
||||
// The last parameter of the completion handler is the error param
|
||||
return params().back().getPlainType()->lookThroughSingleOptionalType();
|
||||
case HandlerType::RESULT:
|
||||
assert(
|
||||
params().size() == 1 &&
|
||||
"Result handler should have the Result type as the only parameter");
|
||||
auto ResultType =
|
||||
params().back().getPlainType()->getAs<BoundGenericType>();
|
||||
auto GenericArgs = ResultType->getGenericArgs();
|
||||
assert(GenericArgs.size() == 2 && "Result should have two params");
|
||||
// The second (last) generic parameter of the Result type is the error
|
||||
// type.
|
||||
return GenericArgs.back();
|
||||
}
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
/// The `CallExpr` if the given node is a call to the `Handler`
|
||||
CallExpr *getAsHandlerCall(ASTNode Node) const {
|
||||
if (!isValid())
|
||||
@@ -5319,6 +5349,262 @@ private:
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// When adding an async alternative method for the function declaration \c FD,
|
||||
/// this class tries to create a function body for the legacy function (the one
|
||||
/// with a completion handler), which calls the newly converted async function.
|
||||
/// There are certain situations in which we fail to create such a body, e.g.
|
||||
/// if the completion handler has the signature `(String, Error?) -> Void` in
|
||||
/// which case we can't synthesize the result of type \c String in the error
|
||||
/// case.
|
||||
class LegacyAlternativeBodyCreator {
|
||||
/// The old function declaration for which an async alternative has been added
|
||||
/// and whose body shall be rewritten to call the newly added async
|
||||
/// alternative.
|
||||
FuncDecl *FD;
|
||||
|
||||
/// The description of the completion handler in the old function declaration.
|
||||
AsyncHandlerDesc HandlerDesc;
|
||||
|
||||
std::string Buffer;
|
||||
llvm::raw_string_ostream OS;
|
||||
|
||||
/// Adds the call to the refactored 'async' method without the 'await'
|
||||
/// keyword to the output stream.
|
||||
void addCallToAsyncMethod() {
|
||||
OS << FD->getBaseName() << "(";
|
||||
bool FirstParam = true;
|
||||
for (auto Param : *FD->getParameters()) {
|
||||
if (Param == HandlerDesc.Handler) {
|
||||
/// We don't need to pass the completion handler to the async method.
|
||||
continue;
|
||||
}
|
||||
if (!FirstParam) {
|
||||
OS << ", ";
|
||||
} else {
|
||||
FirstParam = false;
|
||||
}
|
||||
if (!Param->getArgumentName().empty()) {
|
||||
OS << Param->getArgumentName() << ": ";
|
||||
}
|
||||
OS << Param->getParameterName();
|
||||
}
|
||||
OS << ")";
|
||||
}
|
||||
|
||||
/// If the returned error type is more specialized than \c Error, adds an
|
||||
/// 'as! CustomError' cast to the more specialized error type to the output
|
||||
/// stream.
|
||||
void addCastToCustomErrorTypeIfNecessary() {
|
||||
auto ErrorType = *HandlerDesc.getErrorType();
|
||||
if (ErrorType->getCanonicalType() !=
|
||||
FD->getASTContext().getExceptionType()) {
|
||||
OS << " as! ";
|
||||
ErrorType->lookThroughSingleOptionalType()->print(OS);
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds the \c Index -th parameter to the completion handler.
|
||||
/// If \p HasResult is \c true, it is assumed that a variable named 'result'
|
||||
/// contains the result returned from the async alternative. If the callback
|
||||
/// also takes an error parameter, \c nil passed to the completion handler for
|
||||
/// the error.
|
||||
/// If \p HasResult is \c false, it is a assumed that a variable named 'error'
|
||||
/// contains the error thrown from the async method and 'nil' will be passed
|
||||
/// to the completion handler for all result parameters.
|
||||
void addCompletionHandlerArgument(size_t Index, bool HasResult) {
|
||||
if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) {
|
||||
// The error parameter is the last argument of the completion handler.
|
||||
if (!HasResult) {
|
||||
OS << "error";
|
||||
addCastToCustomErrorTypeIfNecessary();
|
||||
} else {
|
||||
OS << "nil";
|
||||
}
|
||||
} else {
|
||||
if (!HasResult) {
|
||||
OS << "nil";
|
||||
} else if (HandlerDesc
|
||||
.getSuccessParamAsyncReturnType(
|
||||
HandlerDesc.params()[Index].getPlainType())
|
||||
->isVoid()) {
|
||||
// Void return types are not returned by the async function, synthesize
|
||||
// a Void instance.
|
||||
OS << "()";
|
||||
} else if (HandlerDesc.getSuccessParams().size() > 1) {
|
||||
// If the async method returns a tuple, we need to pass its elements to
|
||||
// the completion handler separately. For example:
|
||||
//
|
||||
// func foo() async -> (String, Int) {}
|
||||
//
|
||||
// causes the following legacy body to be created:
|
||||
//
|
||||
// func foo(completion: (String, Int) -> Void) {
|
||||
// async {
|
||||
// let result = await foo()
|
||||
// completion(result.0, result.1)
|
||||
// }
|
||||
// }
|
||||
OS << "result." << Index;
|
||||
} else {
|
||||
OS << "result";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds the call to the completion handler. See \c
|
||||
/// getCompletionHandlerArgument for how the arguments are synthesized if the
|
||||
/// completion handler takes arguments, not a \c Result type.
|
||||
void addCallToCompletionHandler(bool HasResult) {
|
||||
OS << HandlerDesc.Handler->getParameterName() << "(";
|
||||
|
||||
// Construct arguments to pass to the completion handler
|
||||
switch (HandlerDesc.Type) {
|
||||
case HandlerType::INVALID:
|
||||
llvm_unreachable("Cannot be rewritten");
|
||||
break;
|
||||
case HandlerType::PARAMS: {
|
||||
for (size_t I = 0; I < HandlerDesc.params().size(); ++I) {
|
||||
if (I > 0) {
|
||||
OS << ", ";
|
||||
}
|
||||
addCompletionHandlerArgument(I, HasResult);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HandlerType::RESULT: {
|
||||
if (HasResult) {
|
||||
OS << ".success(result)";
|
||||
} else {
|
||||
OS << ".failure(error";
|
||||
addCastToCustomErrorTypeIfNecessary();
|
||||
OS << ")";
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
OS << ")"; // Close the call to the completion handler
|
||||
}
|
||||
|
||||
/// Adds the result type of the converted async function.
|
||||
void addAsyncFuncReturnType() {
|
||||
SmallVector<Type, 2> Scratch;
|
||||
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
|
||||
if (ReturnTypes.size() > 1) {
|
||||
OS << "(";
|
||||
}
|
||||
|
||||
llvm::interleave(
|
||||
ReturnTypes, [&](Type Ty) { Ty->print(OS); }, [&]() { OS << ", "; });
|
||||
|
||||
if (ReturnTypes.size() > 1) {
|
||||
OS << ")";
|
||||
}
|
||||
}
|
||||
|
||||
/// If the async alternative function is generic, adds the type annotation
|
||||
/// to the 'return' variable in the legacy function so that the generic
|
||||
/// parameters of the legacy function are passed to the generic function.
|
||||
/// For example for
|
||||
/// \code
|
||||
/// func foo<GenericParam>() async -> GenericParam {}
|
||||
/// \endcode
|
||||
/// we generate
|
||||
/// \code
|
||||
/// func foo<GenericParam>(completion: (T) -> Void) {
|
||||
/// async {
|
||||
/// let result: GenericParam = await foo()
|
||||
/// <------------>
|
||||
/// completion(result)
|
||||
/// }
|
||||
/// }
|
||||
/// \endcode
|
||||
/// This function adds the range marked by \c <----->
|
||||
void addResultTypeAnnotationIfNecessary() {
|
||||
if (FD->isGeneric()) {
|
||||
OS << ": ";
|
||||
addAsyncFuncReturnType();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
LegacyAlternativeBodyCreator(FuncDecl *FD, AsyncHandlerDesc HandlerDesc)
|
||||
: FD(FD), HandlerDesc(HandlerDesc), OS(Buffer) {}
|
||||
|
||||
bool canRewriteLegacyBody() {
|
||||
if (FD == nullptr || FD->getBody() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (FD->hasThrows()) {
|
||||
assert(!HandlerDesc.isValid() && "We shouldn't have found a handler desc "
|
||||
"if the original function throws");
|
||||
return false;
|
||||
}
|
||||
switch (HandlerDesc.Type) {
|
||||
case HandlerType::INVALID:
|
||||
return false;
|
||||
case HandlerType::PARAMS: {
|
||||
if (HandlerDesc.HasError) {
|
||||
// The non-error parameters must be optional so that we can set them to
|
||||
// nil in the error case.
|
||||
// The error parameter must be optional so we can set it to nil in the
|
||||
// success case.
|
||||
// Otherwise we can't synthesize the values to return for these
|
||||
// parameters.
|
||||
return llvm::all_of(HandlerDesc.params(),
|
||||
[](AnyFunctionType::Param Param) -> bool {
|
||||
return Param.getPlainType()->isOptional();
|
||||
});
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
case HandlerType::RESULT:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
std::string create() {
|
||||
assert(Buffer.empty() &&
|
||||
"LegacyAlternativeBodyCreator can only be used once");
|
||||
assert(canRewriteLegacyBody() &&
|
||||
"Cannot create a legacy body if the body can't be rewritten");
|
||||
OS << "{\n"; // start function body
|
||||
OS << "async {\n";
|
||||
if (HandlerDesc.HasError) {
|
||||
OS << "do {\n";
|
||||
if (!HandlerDesc.willAsyncReturnVoid()) {
|
||||
OS << "let result";
|
||||
addResultTypeAnnotationIfNecessary();
|
||||
OS << " = ";
|
||||
}
|
||||
OS << "try await ";
|
||||
addCallToAsyncMethod();
|
||||
OS << "\n";
|
||||
addCallToCompletionHandler(/*HasResult=*/true);
|
||||
OS << "\n"
|
||||
<< "} catch {\n";
|
||||
addCallToCompletionHandler(/*HasResult=*/false);
|
||||
OS << "\n"
|
||||
<< "}\n"; // end catch
|
||||
} else {
|
||||
if (!HandlerDesc.willAsyncReturnVoid()) {
|
||||
OS << "let result";
|
||||
addResultTypeAnnotationIfNecessary();
|
||||
OS << " = ";
|
||||
}
|
||||
OS << "await ";
|
||||
addCallToAsyncMethod();
|
||||
OS << "\n";
|
||||
addCallToCompletionHandler(/*HasResult=*/true);
|
||||
OS << "\n";
|
||||
}
|
||||
OS << "}\n"; // end 'async'
|
||||
OS << "}\n"; // end function body
|
||||
return Buffer;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace asyncrefactorings
|
||||
|
||||
bool RefactoringActionConvertCallToAsyncAlternative::isApplicable(
|
||||
@@ -5425,6 +5711,13 @@ bool RefactoringActionAddAsyncAlternative::performChange() {
|
||||
EditConsumer.accept(SM, FD->getAttributeInsertionLoc(false),
|
||||
"@available(*, deprecated, message: \"Prefer async "
|
||||
"alternative instead\")\n");
|
||||
LegacyAlternativeBodyCreator LegacyBody(FD, HandlerDesc);
|
||||
if (LegacyBody.canRewriteLegacyBody()) {
|
||||
EditConsumer.accept(SM,
|
||||
Lexer::getCharSourceRangeFromSourceRange(
|
||||
SM, FD->getBody()->getSourceRange()),
|
||||
LegacyBody.create());
|
||||
}
|
||||
Converter.insertAfter(FD, EditConsumer);
|
||||
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user