mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
1902 lines
66 KiB
C++
1902 lines
66 KiB
C++
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "AsyncRefactoring.h"
|
|
#include "Utils.h"
|
|
#include "swift/Basic/Assertions.h"
|
|
|
|
using namespace swift;
|
|
using namespace swift::refactoring::asyncrefactorings;
|
|
|
|
|
|
/// Base name of a decl if it has one, an empty \c DeclBaseName otherwise.
|
|
static DeclBaseName getDeclName(const Decl *D) {
|
|
if (auto *VD = dyn_cast<ValueDecl>(D)) {
|
|
if (VD->hasName())
|
|
return VD->getBaseName();
|
|
}
|
|
return DeclBaseName();
|
|
}
|
|
|
|
bool AsyncConverter::convert() {
|
|
assert(Buffer.empty() && "AsyncConverter can only be used once");
|
|
|
|
if (auto *FD = dyn_cast_or_null<FuncDecl>(StartNode.dyn_cast<Decl *>())) {
|
|
addFuncDecl(FD);
|
|
if (FD->getBody()) {
|
|
convertNode(FD->getBody());
|
|
}
|
|
} else {
|
|
convertNode(StartNode, /*StartOverride=*/{}, /*ConvertCalls=*/true,
|
|
/*IncludeComments=*/false);
|
|
}
|
|
return !DiagEngine.hadAnyError();
|
|
}
|
|
|
|
bool AsyncConverter::createLegacyBody() {
|
|
assert(Buffer.empty() && "AsyncConverter can only be used once");
|
|
|
|
if (!canCreateLegacyBody())
|
|
return false;
|
|
|
|
FuncDecl *FD = cast<FuncDecl>(cast<Decl *>(StartNode));
|
|
OS << tok::l_brace << "\n"; // start function body
|
|
OS << "Task " << tok::l_brace << "\n";
|
|
addHoistedNamedCallback(FD, TopHandler, TopHandler.getNameStr(), [&]() {
|
|
if (TopHandler.HasError) {
|
|
OS << tok::kw_try << " ";
|
|
}
|
|
OS << "await ";
|
|
|
|
// Since we're *creating* the async alternative here, there shouldn't
|
|
// already be one. Thus, just assume that the call to the alternative is
|
|
// the same as the call to the old completion handler function, minus the
|
|
// completion handler arg.
|
|
addForwardingCallTo(FD, /*HandlerReplacement=*/"");
|
|
});
|
|
OS << "\n";
|
|
OS << tok::r_brace << "\n"; // end 'Task'
|
|
OS << tok::r_brace << "\n"; // end function body
|
|
return true;
|
|
}
|
|
|
|
bool AsyncConverter::createAsyncWrapper() {
|
|
assert(Buffer.empty() && "AsyncConverter can only be used once");
|
|
auto *FD = cast<FuncDecl>(cast<Decl *>(StartNode));
|
|
|
|
// First add the new async function declaration.
|
|
addFuncDecl(FD);
|
|
OS << tok::l_brace << "\n";
|
|
|
|
// Then add the body.
|
|
OS << tok::kw_return << " ";
|
|
if (TopHandler.HasError)
|
|
OS << tok::kw_try << " ";
|
|
|
|
OS << "await ";
|
|
|
|
// withChecked[Throwing]Continuation { continuation in
|
|
if (TopHandler.HasError) {
|
|
OS << "withCheckedThrowingContinuation";
|
|
} else {
|
|
OS << "withCheckedContinuation";
|
|
}
|
|
OS << " " << tok::l_brace << " continuation " << tok::kw_in << "\n";
|
|
|
|
// fnWithHandler(args...) { ... }
|
|
auto ClosureStr =
|
|
getAsyncWrapperCompletionClosure("continuation", TopHandler);
|
|
addForwardingCallTo(FD, /*HandlerReplacement=*/ClosureStr);
|
|
|
|
OS << "\n";
|
|
OS << tok::r_brace << "\n"; // end continuation closure
|
|
OS << tok::r_brace << "\n"; // end function body
|
|
return true;
|
|
}
|
|
|
|
void AsyncConverter::replace(ASTNode Node, SourceEditConsumer &EditConsumer,
|
|
SourceLoc StartOverride) {
|
|
SourceRange Range = Node.getSourceRange();
|
|
if (StartOverride.isValid()) {
|
|
Range = SourceRange(StartOverride, Range.End);
|
|
}
|
|
CharSourceRange CharRange =
|
|
Lexer::getCharSourceRangeFromSourceRange(SM, Range);
|
|
EditConsumer.accept(SM, CharRange, Buffer.str());
|
|
Buffer.clear();
|
|
}
|
|
|
|
void AsyncConverter::insertAfter(ASTNode Node,
|
|
SourceEditConsumer &EditConsumer) {
|
|
EditConsumer.insertAfter(SM, Node.getEndLoc(), "\n\n");
|
|
EditConsumer.insertAfter(SM, Node.getEndLoc(), Buffer.str());
|
|
Buffer.clear();
|
|
}
|
|
|
|
bool AsyncConverter::canCreateLegacyBody() {
|
|
FuncDecl *FD = dyn_cast<FuncDecl>(StartNode.dyn_cast<Decl *>());
|
|
if (!FD) {
|
|
return false;
|
|
}
|
|
if (FD == nullptr || FD->getBody() == nullptr) {
|
|
return false;
|
|
}
|
|
if (FD->hasThrows()) {
|
|
assert(!TopHandler.isValid() && "We shouldn't have found a handler desc "
|
|
"if the original function throws");
|
|
return false;
|
|
}
|
|
return TopHandler.isValid();
|
|
}
|
|
|
|
std::string AsyncConverter::getAsyncWrapperCompletionClosure(
|
|
StringRef ContName, const AsyncHandlerParamDesc &HandlerDesc) {
|
|
std::string OutputStr;
|
|
llvm::raw_string_ostream OS(OutputStr);
|
|
|
|
OS << tok::l_brace; // start closure
|
|
|
|
// Prepare parameter names for the closure.
|
|
auto SuccessParams = HandlerDesc.getSuccessParams();
|
|
SmallVector<SmallString<4>, 2> SuccessParamNames;
|
|
for (auto idx : indices(SuccessParams)) {
|
|
SuccessParamNames.emplace_back("result");
|
|
|
|
// If we have multiple success params, number them e.g res1, res2...
|
|
if (SuccessParams.size() > 1)
|
|
SuccessParamNames.back().append(std::to_string(idx + 1));
|
|
}
|
|
std::optional<SmallString<4>> ErrName;
|
|
if (HandlerDesc.getErrorParam())
|
|
ErrName.emplace("error");
|
|
|
|
auto HasAnyParams = !SuccessParamNames.empty() || ErrName;
|
|
if (HasAnyParams)
|
|
OS << " ";
|
|
|
|
// res1, res2
|
|
llvm::interleave(
|
|
SuccessParamNames, [&](auto Name) { OS << Name; },
|
|
[&]() { OS << tok::comma << " "; });
|
|
|
|
// , err
|
|
if (ErrName) {
|
|
if (!SuccessParamNames.empty())
|
|
OS << tok::comma << " ";
|
|
|
|
OS << *ErrName;
|
|
}
|
|
if (HasAnyParams)
|
|
OS << " " << tok::kw_in;
|
|
|
|
OS << "\n";
|
|
|
|
// The closure body.
|
|
switch (HandlerDesc.Type) {
|
|
case HandlerType::PARAMS: {
|
|
// For a (Success?, Error?) -> Void handler, we do an if let on the error.
|
|
if (ErrName) {
|
|
// if let err = err {
|
|
OS << tok::kw_if << " " << tok::kw_let << " ";
|
|
OS << *ErrName << " " << tok::equal << " " << *ErrName << " ";
|
|
OS << tok::l_brace << "\n";
|
|
for (auto Idx : indices(SuccessParamNames)) {
|
|
auto ParamTy = SuccessParams[Idx].getParameterType();
|
|
if (!HandlerDesc.shouldUnwrap(ParamTy))
|
|
continue;
|
|
}
|
|
|
|
// continuation.resume(throwing: err)
|
|
OS << ContName << tok::period << "resume" << tok::l_paren;
|
|
OS << "throwing" << tok::colon << " " << *ErrName;
|
|
OS << tok::r_paren << "\n";
|
|
|
|
// return }
|
|
OS << tok::kw_return << "\n";
|
|
OS << tok::r_brace << "\n";
|
|
}
|
|
|
|
// If we have any success params that we need to unwrap, insert a guard.
|
|
for (auto Idx : indices(SuccessParamNames)) {
|
|
auto &Name = SuccessParamNames[Idx];
|
|
auto ParamTy = SuccessParams[Idx].getParameterType();
|
|
if (!HandlerDesc.shouldUnwrap(ParamTy))
|
|
continue;
|
|
|
|
// guard let res = res else {
|
|
OS << tok::kw_guard << " " << tok::kw_let << " ";
|
|
OS << Name << " " << tok::equal << " " << Name << " " << tok::kw_else;
|
|
OS << " " << tok::l_brace << "\n";
|
|
|
|
// fatalError(...)
|
|
OS << "fatalError" << tok::l_paren;
|
|
OS << "\"Expected non-nil result '" << Name << "' for nil error\"";
|
|
OS << tok::r_paren << "\n";
|
|
|
|
// End guard.
|
|
OS << tok::r_brace << "\n";
|
|
}
|
|
|
|
// continuation.resume(returning: (res1, res2, ...))
|
|
OS << ContName << tok::period << "resume" << tok::l_paren;
|
|
OS << "returning" << tok::colon << " ";
|
|
addTupleOf(SuccessParamNames, OS, [&](auto Ref) { OS << Ref; });
|
|
OS << tok::r_paren << "\n";
|
|
break;
|
|
}
|
|
case HandlerType::RESULT: {
|
|
// continuation.resume(with: res)
|
|
assert(SuccessParamNames.size() == 1);
|
|
OS << ContName << tok::period << "resume" << tok::l_paren;
|
|
OS << "with" << tok::colon << " " << SuccessParamNames[0];
|
|
OS << tok::r_paren << "\n";
|
|
break;
|
|
}
|
|
case HandlerType::INVALID:
|
|
llvm_unreachable("Should not have an invalid handler here");
|
|
}
|
|
|
|
OS << tok::r_brace; // end closure
|
|
return OutputStr;
|
|
}
|
|
|
|
CharSourceRange AsyncConverter::getPrecedingCommentRange(SourceLoc Loc) {
|
|
auto Tokens = SF->getAllTokens();
|
|
auto TokenIter = token_lower_bound(Tokens, Loc);
|
|
if (TokenIter == Tokens.end() || !TokenIter->hasComment())
|
|
return CharSourceRange();
|
|
return TokenIter->getCommentRange();
|
|
}
|
|
|
|
SourceLoc AsyncConverter::getLocIncludingPrecedingComment(SourceLoc Loc) {
|
|
auto CommentRange = getPrecedingCommentRange(Loc);
|
|
if (CommentRange.isInvalid())
|
|
return Loc;
|
|
return CommentRange.getStart();
|
|
}
|
|
|
|
void AsyncConverter::printCommentIfNeeded(SourceLoc Loc) {
|
|
auto CommentRange = getPrecedingCommentRange(Loc);
|
|
if (CommentRange.isValid())
|
|
OS << "\n" << CommentRange.str();
|
|
}
|
|
|
|
void AsyncConverter::convertNodes(const NodesToPrint &ToPrint) {
|
|
// Sort the possible comment locs in reverse order so we can pop them as we
|
|
// go.
|
|
SmallVector<SourceLoc, 2> CommentLocs;
|
|
CommentLocs.append(ToPrint.getPossibleCommentLocs().begin(),
|
|
ToPrint.getPossibleCommentLocs().end());
|
|
llvm::sort(CommentLocs.begin(), CommentLocs.end(), [](auto lhs, auto rhs) {
|
|
return lhs.getOpaquePointerValue() > rhs.getOpaquePointerValue();
|
|
});
|
|
|
|
// First print the nodes we've been asked to print.
|
|
for (auto Node : ToPrint.getNodes()) {
|
|
// If we need to print comments, do so now.
|
|
while (!CommentLocs.empty()) {
|
|
auto CommentLoc = CommentLocs.back().getOpaquePointerValue();
|
|
auto NodeLoc = Node.getStartLoc().getOpaquePointerValue();
|
|
assert(CommentLoc != NodeLoc &&
|
|
"Added node to both comment locs and nodes to print?");
|
|
|
|
// If the comment occurs after the node, don't print now. Wait until
|
|
// the right node comes along.
|
|
if (CommentLoc > NodeLoc)
|
|
break;
|
|
|
|
printCommentIfNeeded(CommentLocs.pop_back_val());
|
|
}
|
|
OS << "\n";
|
|
convertNode(Node);
|
|
}
|
|
|
|
// We're done printing nodes. Make sure to output the remaining comments.
|
|
while (!CommentLocs.empty())
|
|
printCommentIfNeeded(CommentLocs.pop_back_val());
|
|
}
|
|
|
|
void AsyncConverter::convertNode(ASTNode Node, SourceLoc StartOverride,
|
|
bool ConvertCalls,
|
|
bool IncludePrecedingComment) {
|
|
if (!StartOverride.isValid())
|
|
StartOverride = Node.getStartLoc();
|
|
|
|
// Make sure to include any preceding comments attached to the loc
|
|
if (IncludePrecedingComment)
|
|
StartOverride = getLocIncludingPrecedingComment(StartOverride);
|
|
|
|
llvm::SaveAndRestore<SourceLoc> RestoreLoc(LastAddedLoc, StartOverride);
|
|
llvm::SaveAndRestore<int> RestoreCount(NestedExprCount, ConvertCalls ? 0 : 1);
|
|
|
|
walk(Node);
|
|
addRange(LastAddedLoc, Node.getEndLoc(), /*ToEndOfToken=*/true);
|
|
}
|
|
|
|
void AsyncConverter::convertPattern(const Pattern *P) {
|
|
// Only print semantic patterns. This cleans up the output of the transform
|
|
// and works around some bogus source locs that can appear with typed
|
|
// patterns in if let statements.
|
|
P = P->getSemanticsProvidingPattern();
|
|
|
|
// Set up the start of the pattern as the last loc printed to make sure we
|
|
// accurately fill in the gaps as we customize the printing of sub-patterns.
|
|
llvm::SaveAndRestore<SourceLoc> RestoreLoc(LastAddedLoc, P->getStartLoc());
|
|
llvm::SaveAndRestore<bool> RestoreFlag(ConvertingPattern, true);
|
|
|
|
walk(const_cast<Pattern *>(P));
|
|
addRange(LastAddedLoc, P->getEndLoc(), /*ToEndOfToken*/ true);
|
|
}
|
|
|
|
void AsyncConverter::wrapScopeInContinationIfNecessary(ASTNode Node) {
|
|
if (NestedExprCount != 0) {
|
|
// We can't start a continuation in the middle of an expression
|
|
return;
|
|
}
|
|
if (Scopes.back().isWrappedInContination()) {
|
|
// We are already in a continuation. No need to add another one.
|
|
return;
|
|
}
|
|
if (!DeclReferenceFinder::containsReference(Node, TopHandler.getHandler())) {
|
|
// The node doesn't have a reference to the function's completion handler.
|
|
// It can stay a call with a completion handler, because we don't need to
|
|
// promote a completion handler call to a 'return'.
|
|
return;
|
|
}
|
|
|
|
// Wrap the current call in a continuation
|
|
|
|
Identifier contName = createUniqueName("continuation");
|
|
Scopes.back().Names.insert(contName);
|
|
Scopes.back().ContinuationName = contName;
|
|
|
|
insertCustom(Node.getStartLoc(), [&]() {
|
|
OS << tok::kw_return << ' ';
|
|
if (TopHandler.HasError) {
|
|
OS << tok::kw_try << ' ';
|
|
}
|
|
OS << "await ";
|
|
if (TopHandler.HasError) {
|
|
OS << "withCheckedThrowingContinuation ";
|
|
} else {
|
|
OS << "withCheckedContinuation ";
|
|
}
|
|
OS << tok::l_brace << ' ' << contName << ' ' << tok::kw_in << '\n';
|
|
});
|
|
}
|
|
|
|
bool AsyncConverter::walkToPatternPre(Pattern *P) {
|
|
// If we're not converting a pattern, there's nothing extra to do.
|
|
if (!ConvertingPattern)
|
|
return true;
|
|
|
|
// When converting a pattern, don't print the 'let' or 'var' of binding
|
|
// subpatterns, as they're illegal when nested in PBDs, and we print a
|
|
// top-level one.
|
|
if (auto *BP = dyn_cast<BindingPattern>(P)) {
|
|
return addCustom(BP->getSourceRange(),
|
|
[&]() { convertPattern(BP->getSubPattern()); });
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool AsyncConverter::walkToDeclPre(Decl *D, CharSourceRange Range) {
|
|
if (isa<PatternBindingDecl>(D)) {
|
|
// We can't hoist a closure inside a PatternBindingDecl. If it contains
|
|
// a call to the completion handler, wrap it in a continuation.
|
|
wrapScopeInContinationIfNecessary(D);
|
|
NestedExprCount++;
|
|
return true;
|
|
}
|
|
|
|
// Functions and types already have their names in \c Scopes.Names, only
|
|
// variables should need to be renamed.
|
|
if (isa<VarDecl>(D)) {
|
|
// If we don't already have a name for the var, assign it one. Note that
|
|
// vars in binding patterns may already have assigned names here.
|
|
if (Names.find(D) == Names.end()) {
|
|
auto Ident = assignUniqueName(D, StringRef());
|
|
Scopes.back().Names.insert(Ident);
|
|
}
|
|
addCustom(D->getSourceRange(), [&]() { OS << newNameFor(D); });
|
|
}
|
|
|
|
// Note we don't walk into any nested local function decls. If we start
|
|
// doing so in the future, be sure to update the logic that deals with
|
|
// converting unhandled returns into placeholders in walkToStmtPre.
|
|
return false;
|
|
}
|
|
|
|
bool AsyncConverter::walkToDeclPost(Decl *D) {
|
|
NestedExprCount--;
|
|
return true;
|
|
}
|
|
|
|
#define PLACEHOLDER_START "<#"
|
|
#define PLACEHOLDER_END "#>"
|
|
bool AsyncConverter::walkToExprPre(Expr *E) {
|
|
// We've already added any shorthand if declaration, don't add its
|
|
// synthesized initializer as well.
|
|
if (shorthandIfInits.contains(E))
|
|
return true;
|
|
|
|
// TODO: Handle Result.get as well
|
|
if (auto *DRE = dyn_cast<DeclRefExpr>(E)) {
|
|
if (auto *D = DRE->getDecl()) {
|
|
// Look through to the parent var decl if we have one. This ensures we
|
|
// look at the var in a case stmt's pattern rather than the var that's
|
|
// implicitly declared in the body.
|
|
if (auto *VD = dyn_cast<VarDecl>(D)) {
|
|
if (auto *Parent = VD->getParentVarDecl())
|
|
D = Parent;
|
|
}
|
|
|
|
bool AddPlaceholder = Placeholders.count(D);
|
|
StringRef Name = newNameFor(D, false);
|
|
if (AddPlaceholder || !Name.empty())
|
|
return addCustom(DRE->getSourceRange(), [&]() {
|
|
if (AddPlaceholder)
|
|
OS << PLACEHOLDER_START;
|
|
if (!Name.empty())
|
|
OS << Name;
|
|
else
|
|
D->getName().print(OS);
|
|
if (AddPlaceholder)
|
|
OS << PLACEHOLDER_END;
|
|
});
|
|
}
|
|
} else if (isa<ForceValueExpr>(E) || isa<BindOptionalExpr>(E)) {
|
|
// Remove a force unwrap or optional chain of a returned success value,
|
|
// as it will no longer be optional. For force unwraps, this is always a
|
|
// valid transform. For optional chains, it is a locally valid transform
|
|
// within the optional chain e.g foo?.x -> foo.x, but may change the type
|
|
// of the overall chain, which could cause errors elsewhere in the code.
|
|
// However this is generally more useful to the user than just leaving
|
|
// 'foo' as a placeholder. Note this is only the case when no other
|
|
// optionals are involved in the chain, e.g foo?.x?.y -> foo.x?.y is
|
|
// completely valid.
|
|
if (auto *D = E->getReferencedDecl().getDecl()) {
|
|
if (Unwraps.count(D))
|
|
return addCustom(E->getSourceRange(),
|
|
[&]() { OS << newNameFor(D, true); });
|
|
}
|
|
} else if (CallExpr *CE = TopHandler.getAsHandlerCall(E)) {
|
|
if (Scopes.back().isWrappedInContination()) {
|
|
return addCustom(E->getSourceRange(),
|
|
[&]() { convertHandlerToContinuationResume(CE); });
|
|
} else if (NestedExprCount == 0) {
|
|
return addCustom(E->getSourceRange(),
|
|
[&]() { convertHandlerToReturnOrThrows(CE); });
|
|
}
|
|
} else if (auto *CE = dyn_cast<CallExpr>(E)) {
|
|
// Try and hoist a call's completion handler. Don't do so if
|
|
// - the current expression is nested (we can't start hoisting in the
|
|
// middle of an expression)
|
|
// - the current scope is wrapped in a continuation (we can't have await
|
|
// calls in the continuation block)
|
|
if (NestedExprCount == 0 && !Scopes.back().isWrappedInContination()) {
|
|
// If the refactoring is on the call itself, do not require the callee
|
|
// to have the @available attribute or a completion-like name.
|
|
auto HandlerDesc = AsyncHandlerParamDesc::find(
|
|
getUnderlyingFunc(CE->getFn()),
|
|
/*RequireAttributeOrName=*/StartNode.dyn_cast<Expr *>() != CE);
|
|
if (HandlerDesc.isValid()) {
|
|
return addCustom(CE->getSourceRange(),
|
|
[&]() { addHoistedCallback(CE, HandlerDesc); });
|
|
}
|
|
}
|
|
}
|
|
|
|
// A void SingleValueStmtExpr is semantically more like a statement than
|
|
// an expression, so recurse without bumping the expr depth or wrapping in
|
|
// continuation.
|
|
if (auto *SVE = dyn_cast<SingleValueStmtExpr>(E)) {
|
|
auto ty = SVE->getType();
|
|
if (!ty || ty->isVoid())
|
|
return true;
|
|
}
|
|
|
|
// We didn't do any special conversion for this expression. If needed, wrap
|
|
// it in a continuation.
|
|
wrapScopeInContinationIfNecessary(E);
|
|
|
|
NestedExprCount++;
|
|
return true;
|
|
}
|
|
|
|
bool AsyncConverter::replaceRangeWithPlaceholder(SourceRange range) {
|
|
return addCustom(range, [&]() {
|
|
OS << PLACEHOLDER_START;
|
|
addRange(range, /*toEndOfToken*/ true);
|
|
OS << PLACEHOLDER_END;
|
|
});
|
|
}
|
|
|
|
bool AsyncConverter::walkToExprPost(Expr *E) {
|
|
if (auto *SVE = dyn_cast<SingleValueStmtExpr>(E)) {
|
|
auto ty = SVE->getType();
|
|
if (!ty || ty->isVoid())
|
|
return true;
|
|
}
|
|
NestedExprCount--;
|
|
return true;
|
|
}
|
|
|
|
#undef PLACEHOLDER_START
|
|
#undef PLACEHOLDER_END
|
|
|
|
bool AsyncConverter::walkToStmtPre(Stmt *S) {
|
|
// Keep track of any shorthand initializer expressions
|
|
if (auto *labeledConditional = dyn_cast<LabeledConditionalStmt>(S)) {
|
|
for (const auto &condition : labeledConditional->getCond()) {
|
|
if (auto *init = condition.getSynthesizedShorthandInitOrNull()) {
|
|
shorthandIfInits.insert(init);
|
|
}
|
|
}
|
|
}
|
|
|
|
// CaseStmt has an implicit BraceStmt inside it, which *should* start a new
|
|
// scope, so don't check isImplicit here.
|
|
if (startsNewScope(S)) {
|
|
// Add all names of decls referenced within this statement that aren't
|
|
// also declared first, plus any contexts. Note that \c getReferencedDecl
|
|
// will only return a value for a \c BraceStmt. This means that \c IfStmt
|
|
// (and other statements with conditions) will have their own empty scope,
|
|
// which is fine for our purposes - their existing names are always valid.
|
|
// The body of those statements will include the decls if they've been
|
|
// referenced, so shadowing is still avoided there.
|
|
if (auto *ReferencedDecls = ScopedDecls.getReferencedDecls(S)) {
|
|
llvm::DenseSet<const Decl *> Decls;
|
|
for (auto DeclAndNumRefs : *ReferencedDecls)
|
|
Decls.insert(DeclAndNumRefs.first);
|
|
addNewScope(Decls);
|
|
} else {
|
|
addNewScope({});
|
|
}
|
|
} else if (Hoisting && !S->isImplicit()) {
|
|
// Some break and return statements need to be turned into placeholders,
|
|
// as they may no longer perform the control flow that the user is
|
|
// expecting.
|
|
if (auto *BS = dyn_cast<BreakStmt>(S)) {
|
|
// For a break, if it's jumping out of a switch statement that we've
|
|
// re-written as a part of the transform, turn it into a placeholder, as
|
|
// it would have been lifted out of the switch statement.
|
|
if (auto *SS = dyn_cast<SwitchStmt>(BS->getTarget())) {
|
|
if (HandledSwitches.contains(SS))
|
|
return replaceRangeWithPlaceholder(S->getSourceRange());
|
|
}
|
|
} else if (isa<ReturnStmt>(S) && NestedExprCount == 0) {
|
|
// For a return, if it's not nested inside another closure or function,
|
|
// turn it into a placeholder, as it will be lifted out of the callback.
|
|
// Note that we only turn the 'return' token into a placeholder as we
|
|
// still want to be able to apply transforms to the argument.
|
|
replaceRangeWithPlaceholder(S->getStartLoc());
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool AsyncConverter::walkToStmtPost(Stmt *S) {
|
|
if (startsNewScope(S)) {
|
|
bool ClosedScopeWasWrappedInContinuation =
|
|
Scopes.back().isWrappedInContination();
|
|
Scopes.pop_back();
|
|
if (ClosedScopeWasWrappedInContinuation &&
|
|
!Scopes.back().isWrappedInContination()) {
|
|
// The nested scope was wrapped in a continuation but the current one
|
|
// isn't anymore. Add the '}' that corresponds to the call to
|
|
// withChecked(Throwing)Continuation.
|
|
insertCustom(S->getEndLoc(), [&]() { OS << tok::r_brace << '\n'; });
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool AsyncConverter::addCustom(SourceRange Range,
|
|
llvm::function_ref<void()> Custom) {
|
|
addRange(LastAddedLoc, Range.Start);
|
|
Custom();
|
|
LastAddedLoc = Lexer::getLocForEndOfToken(SM, Range.End);
|
|
return false;
|
|
}
|
|
|
|
bool AsyncConverter::insertCustom(SourceLoc Loc,
|
|
llvm::function_ref<void()> Custom) {
|
|
addRange(LastAddedLoc, Loc);
|
|
Custom();
|
|
LastAddedLoc = Loc;
|
|
return false;
|
|
}
|
|
|
|
void AsyncConverter::addRange(SourceLoc Start, SourceLoc End,
|
|
bool ToEndOfToken) {
|
|
if (ToEndOfToken) {
|
|
OS << Lexer::getCharSourceRangeFromSourceRange(SM, SourceRange(Start, End))
|
|
.str();
|
|
} else {
|
|
OS << CharSourceRange(SM, Start, End).str();
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::addRange(SourceRange Range, bool ToEndOfToken) {
|
|
addRange(Range.Start, Range.End, ToEndOfToken);
|
|
}
|
|
|
|
void AsyncConverter::addFuncDecl(const FuncDecl *FD) {
|
|
auto *Params = FD->getParameters();
|
|
auto *HandlerParam = TopHandler.getHandlerParam();
|
|
auto ParamPos = TopHandler.handlerParamPosition();
|
|
|
|
// If the completion handler parameter has a default argument, the async
|
|
// version is effectively @discardableResult, as not all the callers care
|
|
// about receiving the completion call.
|
|
if (HandlerParam && HandlerParam->isDefaultArgument())
|
|
OS << tok::at_sign << "discardableResult"
|
|
<< "\n";
|
|
|
|
// First chunk: start -> the parameter to remove (if any)
|
|
SourceLoc LeftEndLoc;
|
|
switch (ParamPos) {
|
|
case AsyncHandlerParamDesc::Position::None:
|
|
case AsyncHandlerParamDesc::Position::Only:
|
|
case AsyncHandlerParamDesc::Position::First:
|
|
// Handler is the first param (or there is none), so only include the (
|
|
LeftEndLoc = Params->getLParenLoc().getAdvancedLoc(1);
|
|
break;
|
|
case AsyncHandlerParamDesc::Position::Middle:
|
|
// Handler is somewhere in the middle of the params, so we need to
|
|
// include any comments and comma up until the handler
|
|
LeftEndLoc = Params->get(TopHandler.Index)->getStartLoc();
|
|
LeftEndLoc = getLocIncludingPrecedingComment(LeftEndLoc);
|
|
break;
|
|
case AsyncHandlerParamDesc::Position::Last:
|
|
// Handler is the last param, which means we don't want the comma. This
|
|
// is a little annoying since we *do* want the comments past for the
|
|
// last parameter
|
|
LeftEndLoc = Lexer::getLocForEndOfToken(
|
|
SM, Params->get(TopHandler.Index - 1)->getEndLoc());
|
|
// Skip to the end of any comments
|
|
Token Next =
|
|
Lexer::getTokenAtLocation(SM, LeftEndLoc, CommentRetentionMode::None);
|
|
if (Next.getKind() != tok::NUM_TOKENS)
|
|
LeftEndLoc = Next.getLoc();
|
|
break;
|
|
}
|
|
addRange(FD->getSourceRangeIncludingAttrs().Start, LeftEndLoc);
|
|
|
|
// Second chunk: end of the parameter to remove -> right parenthesis
|
|
SourceLoc MidStartLoc;
|
|
SourceLoc MidEndLoc = Params->getRParenLoc().getAdvancedLoc(1);
|
|
switch (ParamPos) {
|
|
case AsyncHandlerParamDesc::Position::None:
|
|
// No handler param, so make sure to include them all
|
|
MidStartLoc = LeftEndLoc;
|
|
break;
|
|
case AsyncHandlerParamDesc::Position::First:
|
|
case AsyncHandlerParamDesc::Position::Middle:
|
|
// Handler param is either the first or one of the middle params. Skip
|
|
// past it but make sure to include comments preceding the param after
|
|
// the handler
|
|
MidStartLoc = Params->get(TopHandler.Index + 1)->getStartLoc();
|
|
MidStartLoc = getLocIncludingPrecedingComment(MidStartLoc);
|
|
break;
|
|
case AsyncHandlerParamDesc::Position::Only:
|
|
case AsyncHandlerParamDesc::Position::Last:
|
|
// Handler param is last, this is easy since there's no other params
|
|
// to copy over
|
|
MidStartLoc = Params->getRParenLoc();
|
|
break;
|
|
}
|
|
addRange(MidStartLoc, MidEndLoc);
|
|
|
|
// Third chunk: add in async and throws if necessary
|
|
if (!FD->hasAsync())
|
|
OS << " async";
|
|
if (FD->hasThrows() || TopHandler.HasError)
|
|
// TODO: Add throws if converting a function and it has a converted call
|
|
// without a do/catch
|
|
OS << " " << tok::kw_throws;
|
|
|
|
// Fourth chunk: if no parent handler (ie. not adding an async
|
|
// alternative), the rest of the decl. Otherwise, add in the new return
|
|
// type
|
|
if (!TopHandler.isValid()) {
|
|
SourceLoc RightStartLoc = MidEndLoc;
|
|
if (FD->hasThrows()) {
|
|
RightStartLoc = Lexer::getLocForEndOfToken(SM, FD->getThrowsLoc());
|
|
}
|
|
SourceLoc RightEndLoc =
|
|
FD->getBody() ? FD->getBody()->getLBraceLoc() : RightStartLoc;
|
|
addRange(RightStartLoc, RightEndLoc);
|
|
return;
|
|
}
|
|
|
|
SmallVector<LabeledReturnType, 2> Scratch;
|
|
auto ReturnTypes = TopHandler.getAsyncReturnTypes(Scratch);
|
|
if (ReturnTypes.empty()) {
|
|
OS << " ";
|
|
return;
|
|
}
|
|
|
|
// Print the function result type, making sure to omit a '-> Void' return.
|
|
if (!TopHandler.willAsyncReturnVoid()) {
|
|
OS << " -> ";
|
|
addAsyncFuncReturnType(TopHandler);
|
|
}
|
|
|
|
if (FD->hasBody())
|
|
OS << " ";
|
|
|
|
// TODO: Should remove the generic param and where clause for the error
|
|
// param if it exists (and no other parameter uses that type)
|
|
TrailingWhereClause *TWC = FD->getTrailingWhereClause();
|
|
if (TWC && TWC->getWhereLoc().isValid()) {
|
|
auto Range = TWC->getSourceRange();
|
|
OS << Lexer::getCharSourceRangeFromSourceRange(SM, Range).str();
|
|
if (FD->hasBody())
|
|
OS << " ";
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::addFallbackVars(ArrayRef<const ParamDecl *> FallbackParams,
|
|
const ClosureCallbackParams &AllParams) {
|
|
for (auto *Param : FallbackParams) {
|
|
auto Ty = Param->getTypeInContext();
|
|
auto ParamName = newNameFor(Param);
|
|
|
|
// If this is the known bool success param, we can use 'let' and type it
|
|
// as non-optional, as it gets bound in both blocks.
|
|
if (AllParams.isKnownBoolFlagParam(Param)) {
|
|
OS << tok::kw_let << " " << ParamName << ": ";
|
|
Ty->print(OS);
|
|
OS << "\n";
|
|
continue;
|
|
}
|
|
|
|
OS << tok::kw_var << " " << ParamName << ": ";
|
|
Ty->print(OS);
|
|
if (!Ty->getOptionalObjectType())
|
|
OS << "?";
|
|
|
|
OS << " = " << tok::kw_nil << "\n";
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::addDo() {
|
|
OS << tok::kw_do << " " << tok::l_brace << "\n";
|
|
}
|
|
|
|
bool AsyncConverter::isErrorAlreadyHandled(HandlerResult Result) {
|
|
assert(Result.isError());
|
|
assert(Result.args().size() == 1 &&
|
|
"There should only be one error parameter");
|
|
// We assume that the error has already been handled if its variable
|
|
// declaration doesn't exist anymore, which is the case if it's in
|
|
// Placeholders but not in Unwraps (if it's in Placeholders and Unwraps
|
|
// an optional Error has simply been promoted to a non-optional Error).
|
|
if (auto *DRE = dyn_cast<DeclRefExpr>(Result.args().back().getExpr())) {
|
|
if (Placeholders.count(DRE->getDecl()) && !Unwraps.count(DRE->getDecl())) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool AsyncConverter::isExpressionOptional(Expr *E) {
|
|
if (isa<InjectIntoOptionalExpr>(E)) {
|
|
// E is downgrading a non-Optional result to an Optional. Its source
|
|
// representation isn't Optional.
|
|
return false;
|
|
}
|
|
if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
|
|
if (Unwraps.count(DRE->getDecl())) {
|
|
// E has been promoted to a non-Optional value. It can't be used as an
|
|
// Optional anymore.
|
|
return false;
|
|
}
|
|
}
|
|
if (!E->getType().isNull() && E->getType()->isOptional()) {
|
|
return true;
|
|
}
|
|
// We couldn't determine the type. Assume non-Optional.
|
|
return false;
|
|
}
|
|
|
|
void AsyncConverter::convertHandlerCall(
|
|
const CallExpr *CE,
|
|
llvm::function_ref<void(HandlerResult)> AddConvertedHandlerCall,
|
|
llvm::function_ref<void(StringRef)> AddConvertedErrorCall) {
|
|
auto Result =
|
|
TopHandler.extractResultArgs(CE, /*ReturnErrorArgsIfAmbiguous=*/true);
|
|
if (!TopHandler.isAmbiguousCallToParamHandler(CE)) {
|
|
if (Result.isError()) {
|
|
if (!isErrorAlreadyHandled(Result)) {
|
|
// If the error has already been handled, we don't need to add another
|
|
// throwing call.
|
|
AddConvertedHandlerCall(Result);
|
|
}
|
|
} else {
|
|
AddConvertedHandlerCall(Result);
|
|
}
|
|
} else {
|
|
assert(Result.isError() && "If the call was ambiguous, we should have "
|
|
"retrieved its error representation");
|
|
assert(Result.args().size() == 1 &&
|
|
"There should only be one error parameter");
|
|
Expr *ErrorExpr = Result.args().back().getExpr();
|
|
if (isErrorAlreadyHandled(Result)) {
|
|
// The error has already been handled, interpret the call as a success
|
|
// call.
|
|
auto SuccessExprs = TopHandler.extractResultArgs(
|
|
CE, /*ReturnErrorArgsIfAmbiguous=*/false);
|
|
AddConvertedHandlerCall(SuccessExprs);
|
|
} else if (!isExpressionOptional(ErrorExpr)) {
|
|
// The error is never nil. No matter what the success param is, we
|
|
// interpret it as an error call.
|
|
AddConvertedHandlerCall(Result);
|
|
} else {
|
|
// The call was truly ambiguous. Add an
|
|
// if let error = <convert error arg> {
|
|
// throw error // or equivalent
|
|
// } else {
|
|
// <interpret call as success call>
|
|
// }
|
|
auto SuccessExprs = TopHandler.extractResultArgs(
|
|
CE, /*ReturnErrorArgsIfAmbiguous=*/false);
|
|
|
|
// The variable 'error' is only available in the 'if let' scope, so we
|
|
// don't need to create a new unique one.
|
|
StringRef ErrorName = "error";
|
|
OS << tok::kw_if << ' ' << tok::kw_let << ' ' << ErrorName << ' '
|
|
<< tok::equal << ' ';
|
|
convertNode(ErrorExpr, /*StartOverride=*/{}, /*ConvertCalls=*/false);
|
|
OS << ' ' << tok::l_brace << '\n';
|
|
AddConvertedErrorCall(ErrorName);
|
|
OS << tok::r_brace << ' ' << tok::kw_else << ' ' << tok::l_brace << '\n';
|
|
AddConvertedHandlerCall(SuccessExprs);
|
|
OS << '\n' << tok::r_brace;
|
|
}
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::convertHandlerToReturnOrThrows(const CallExpr *CE) {
|
|
return convertHandlerCall(
|
|
CE,
|
|
[&](HandlerResult Exprs) {
|
|
convertHandlerToReturnOrThrowsImpl(CE, Exprs);
|
|
},
|
|
[&](StringRef ErrorName) {
|
|
OS << tok::kw_throw << ' ' << ErrorName << '\n';
|
|
});
|
|
}
|
|
|
|
void AsyncConverter::convertHandlerToReturnOrThrowsImpl(const CallExpr *CE,
|
|
HandlerResult Result) {
|
|
bool AddedReturnOrThrow = true;
|
|
if (!Result.isError()) {
|
|
// It's possible the user has already written an explicit return statement
|
|
// for the completion handler call, e.g 'return completion(args...)'. In
|
|
// that case, be sure not to add another return.
|
|
auto *parent = getWalker().Parent.getAsStmt();
|
|
if (isa_and_nonnull<ReturnStmt>(parent) &&
|
|
!cast<ReturnStmt>(parent)->isImplicit()) {
|
|
// The statement already has a return keyword. Don't add another one.
|
|
AddedReturnOrThrow = false;
|
|
} else {
|
|
OS << tok::kw_return;
|
|
}
|
|
} else {
|
|
OS << tok::kw_throw;
|
|
}
|
|
|
|
auto Args = Result.args();
|
|
if (!Args.empty()) {
|
|
if (AddedReturnOrThrow)
|
|
OS << ' ';
|
|
|
|
addTupleOf(Args, OS, [&](Argument Arg) {
|
|
// Special case: If the completion handler is a params handler that
|
|
// takes an error, we could pass arguments to it without unwrapping
|
|
// them. E.g.
|
|
// simpleWithError { (res: String?, error: Error?) in
|
|
// completion(res, nil)
|
|
// }
|
|
// But after refactoring `simpleWithError` to an async function we have
|
|
// let res: String = await simple()
|
|
// and `res` is no longer an `Optional`. Thus it's in `Placeholders` and
|
|
// `Unwraps` and any reference to it will be replaced by a placeholder
|
|
// unless it is wrapped in an unwrapping expression. This would cause us
|
|
// to create `return <#res# >`.
|
|
// Under our assumption that either the error or the result parameter
|
|
// are non-nil, the above call to the completion handler is equivalent
|
|
// to
|
|
// completion(res!, nil)
|
|
// which correctly yields
|
|
// return res
|
|
// Synthesize the force unwrap so that we get the expected results.
|
|
auto *E = Arg.getExpr();
|
|
if (TopHandler.getHandlerType() == HandlerType::PARAMS &&
|
|
TopHandler.HasError) {
|
|
if (auto DRE = dyn_cast<DeclRefExpr>(E->getSemanticsProvidingExpr())) {
|
|
auto D = DRE->getDecl();
|
|
if (Unwraps.count(D)) {
|
|
E = new (getASTContext()) ForceValueExpr(E, SourceLoc());
|
|
}
|
|
}
|
|
}
|
|
// Can't just add the range as we need to perform replacements
|
|
convertNode(E, /*StartOverride=*/Arg.getLabelLoc(),
|
|
/*ConvertCalls=*/false);
|
|
});
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::convertHandlerToContinuationResume(const CallExpr *CE) {
|
|
return convertHandlerCall(
|
|
CE,
|
|
[&](HandlerResult Exprs) {
|
|
convertHandlerToContinuationResumeImpl(CE, Exprs);
|
|
},
|
|
[&](StringRef ErrorName) {
|
|
Identifier ContinuationName = Scopes.back().ContinuationName;
|
|
OS << ContinuationName << tok::period << "resume" << tok::l_paren
|
|
<< "throwing" << tok::colon << ' ' << ErrorName;
|
|
OS << tok::r_paren << '\n';
|
|
});
|
|
}
|
|
|
|
void AsyncConverter::convertHandlerToContinuationResumeImpl(
|
|
const CallExpr *CE, HandlerResult Result) {
|
|
assert(Scopes.back().isWrappedInContination());
|
|
|
|
std::vector<Argument> Args;
|
|
StringRef ResumeArgumentLabel;
|
|
switch (TopHandler.getHandlerType()) {
|
|
case HandlerType::PARAMS: {
|
|
Args = Result.args();
|
|
if (!Result.isError()) {
|
|
ResumeArgumentLabel = "returning";
|
|
} else {
|
|
ResumeArgumentLabel = "throwing";
|
|
}
|
|
break;
|
|
}
|
|
case HandlerType::RESULT: {
|
|
Args = {CE->getArgs()->begin(), CE->getArgs()->end()};
|
|
ResumeArgumentLabel = "with";
|
|
break;
|
|
}
|
|
case HandlerType::INVALID:
|
|
llvm_unreachable("Invalid top handler");
|
|
}
|
|
|
|
// A vector in which each argument of Result has an entry. If the entry is
|
|
// not empty then that argument has been unwrapped using 'guard let' into
|
|
// a variable with that name.
|
|
SmallVector<Identifier, 4> ArgNames;
|
|
ArgNames.reserve(Args.size());
|
|
|
|
/// When unwrapping a result argument \p Arg into a variable using
|
|
/// 'guard let' return a suitable name for the unwrapped variable.
|
|
/// \p ArgIndex is the index of \p Arg in the results passed to the
|
|
/// completion handler.
|
|
auto GetSuitableNameForGuardUnwrap = [&](Expr *Arg,
|
|
unsigned ArgIndex) -> Identifier {
|
|
// If Arg is a DeclRef, use its name for the guard unwrap.
|
|
// guard let myVar1 = myVar.
|
|
if (auto DRE = dyn_cast<DeclRefExpr>(Arg)) {
|
|
return createUniqueName(DRE->getDecl()->getBaseIdentifier().str());
|
|
} else if (auto IIOE = dyn_cast<InjectIntoOptionalExpr>(Arg)) {
|
|
if (auto DRE = dyn_cast<DeclRefExpr>(IIOE->getSubExpr())) {
|
|
return createUniqueName(DRE->getDecl()->getBaseIdentifier().str());
|
|
}
|
|
}
|
|
if (Args.size() == 1) {
|
|
// We only have a single result. 'result' seems a resonable name.
|
|
return createUniqueName("result");
|
|
} else {
|
|
// We are returning a tuple. Name the result elements 'result' +
|
|
// index in tuple.
|
|
return createUniqueName("result" + std::to_string(ArgIndex));
|
|
}
|
|
};
|
|
|
|
unsigned ArgIndex = 0;
|
|
for (auto Arg : Args) {
|
|
auto *ArgExpr = Arg.getExpr();
|
|
Identifier ArgName;
|
|
if (isExpressionOptional(ArgExpr) && TopHandler.HasError) {
|
|
ArgName = GetSuitableNameForGuardUnwrap(ArgExpr, ArgIndex);
|
|
Scopes.back().Names.insert(ArgName);
|
|
OS << tok::kw_guard << ' ' << tok::kw_let << ' ' << ArgName << ' '
|
|
<< tok::equal << ' ';
|
|
|
|
// If the argument is a call with a trailing closure, the generated
|
|
// guard statement will not compile.
|
|
// e.g. 'guard let result1 = value.map { $0 + 1 } else { ... }'
|
|
// doesn't compile. Adding parentheses makes the code compile.
|
|
auto HasTrailingClosure = false;
|
|
if (auto *CE = dyn_cast<CallExpr>(ArgExpr)) {
|
|
if (CE->getArgs()->hasAnyTrailingClosures())
|
|
HasTrailingClosure = true;
|
|
}
|
|
|
|
if (HasTrailingClosure)
|
|
OS << tok::l_paren;
|
|
|
|
convertNode(ArgExpr, /*StartOverride=*/Arg.getLabelLoc(),
|
|
/*ConvertCalls=*/false);
|
|
|
|
if (HasTrailingClosure)
|
|
OS << tok::r_paren;
|
|
|
|
OS << ' ' << tok::kw_else << ' ' << tok::l_brace << '\n';
|
|
OS << "fatalError" << tok::l_paren;
|
|
OS << "\"Expected non-nil result ";
|
|
if (ArgName.str() != "result") {
|
|
OS << "'" << ArgName << "' ";
|
|
}
|
|
OS << "in the non-error case\"";
|
|
OS << tok::r_paren << '\n';
|
|
OS << tok::r_brace << '\n';
|
|
}
|
|
ArgNames.push_back(ArgName);
|
|
ArgIndex++;
|
|
}
|
|
|
|
Identifier ContName = Scopes.back().ContinuationName;
|
|
OS << ContName << tok::period << "resume" << tok::l_paren
|
|
<< ResumeArgumentLabel << tok::colon << ' ';
|
|
|
|
ArgIndex = 0;
|
|
addTupleOf(Args, OS, [&](Argument Arg) {
|
|
Identifier ArgName = ArgNames[ArgIndex];
|
|
if (!ArgName.empty()) {
|
|
OS << ArgName;
|
|
} else {
|
|
// Can't just add the range as we need to perform replacements
|
|
convertNode(Arg.getExpr(), /*StartOverride=*/Arg.getLabelLoc(),
|
|
/*ConvertCalls=*/false);
|
|
}
|
|
ArgIndex++;
|
|
});
|
|
OS << tok::r_paren;
|
|
}
|
|
|
|
ClosureExpr *AsyncConverter::extractCallback(Expr *E) {
|
|
E = lookThroughFunctionConversionExpr(E);
|
|
if (auto Closure = dyn_cast<ClosureExpr>(E)) {
|
|
return Closure;
|
|
} else if (auto CaptureList = dyn_cast<CaptureListExpr>(E)) {
|
|
return dyn_cast<ClosureExpr>(CaptureList->getClosureBody());
|
|
} else {
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
Expr *AsyncConverter::lookThroughFunctionConversionExpr(Expr *E) {
|
|
if (auto FunctionConversion = dyn_cast<FunctionConversionExpr>(E)) {
|
|
return lookThroughFunctionConversionExpr(FunctionConversion->getSubExpr());
|
|
} else {
|
|
return E;
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::addHoistedCallback(
|
|
const CallExpr *CE, const AsyncHandlerParamDesc &HandlerDesc) {
|
|
llvm::SaveAndRestore<bool> RestoreHoisting(Hoisting, true);
|
|
|
|
auto *ArgList = CE->getArgs();
|
|
if (HandlerDesc.Index >= ArgList->size()) {
|
|
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
|
|
return;
|
|
}
|
|
|
|
Expr *CallbackArg =
|
|
lookThroughFunctionConversionExpr(ArgList->getExpr(HandlerDesc.Index));
|
|
if (ClosureExpr *Callback = extractCallback(CallbackArg)) {
|
|
// The user is using a closure for the completion handler
|
|
addHoistedClosureCallback(CE, HandlerDesc, Callback);
|
|
return;
|
|
}
|
|
if (auto CallbackDecl =
|
|
getReferencedDeclLookingThroughAutoclosures(CallbackArg)) {
|
|
if (CallbackDecl == TopHandler.getHandler()) {
|
|
// We are refactoring the function that declared the completion handler
|
|
// that would be called here. We can't call the completion handler
|
|
// anymore because it will be removed. But since the function that
|
|
// declared it is being refactored to async, we can just return the
|
|
// values.
|
|
if (!HandlerDesc.willAsyncReturnVoid()) {
|
|
OS << tok::kw_return << " ";
|
|
}
|
|
InlinePatternsToPrint InlinePatterns;
|
|
addAwaitCall(CE, ClassifiedBlock(), {}, InlinePatterns, HandlerDesc,
|
|
/*AddDeclarations*/ false);
|
|
return;
|
|
}
|
|
// We are not removing the completion handler, so we can call it once the
|
|
// async function returns.
|
|
|
|
// The completion handler that is called as part of the \p CE call.
|
|
// This will be called once the async function returns.
|
|
auto CompletionHandler =
|
|
AsyncHandlerDesc::get(CallbackDecl, /*RequireAttributeOrName=*/false);
|
|
if (CompletionHandler.isValid()) {
|
|
if (auto CalledFunc = getUnderlyingFunc(CE->getFn())) {
|
|
StringRef HandlerName = Lexer::getCharSourceRangeFromSourceRange(
|
|
SM, CallbackArg->getSourceRange())
|
|
.str();
|
|
addHoistedNamedCallback(
|
|
CalledFunc, CompletionHandler, HandlerName, [&] {
|
|
InlinePatternsToPrint InlinePatterns;
|
|
addAwaitCall(CE, ClassifiedBlock(), {}, InlinePatterns,
|
|
HandlerDesc, /*AddDeclarations*/ false);
|
|
});
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
|
|
}
|
|
|
|
void AsyncConverter::addBoolFlagParamBindingIfNeeded(
|
|
std::optional<KnownBoolFlagParam> Flag, BlockKind Block) {
|
|
if (!Flag)
|
|
return;
|
|
// Figure out the polarity of the binding based on the block we're in and
|
|
// whether the flag indicates success.
|
|
auto Polarity = true;
|
|
switch (Block) {
|
|
case BlockKind::SUCCESS:
|
|
break;
|
|
case BlockKind::ERROR:
|
|
Polarity = !Polarity;
|
|
break;
|
|
case BlockKind::FALLBACK:
|
|
llvm_unreachable("Not a valid place to bind");
|
|
}
|
|
if (!Flag->IsSuccessFlag)
|
|
Polarity = !Polarity;
|
|
|
|
OS << newNameFor(Flag->Param) << " " << tok::equal << " ";
|
|
OS << (Polarity ? tok::kw_true : tok::kw_false) << "\n";
|
|
}
|
|
|
|
void AsyncConverter::addHoistedClosureCallback(
|
|
const CallExpr *CE, const AsyncHandlerParamDesc &HandlerDesc,
|
|
const ClosureExpr *Callback) {
|
|
if (HandlerDesc.params().size() != Callback->getParameters()->size()) {
|
|
DiagEngine.diagnose(CE->getStartLoc(), diag::mismatched_callback_args);
|
|
return;
|
|
}
|
|
ClosureCallbackParams CallbackParams(HandlerDesc, Callback);
|
|
ClassifiedBlocks Blocks;
|
|
auto *CallbackBody = Callback->getBody();
|
|
if (!HandlerDesc.HasError) {
|
|
Blocks.SuccessBlock.addNodesInBraceStmt(CallbackBody);
|
|
} else if (!CallbackBody->getElements().empty()) {
|
|
CallbackClassifier::classifyInto(Blocks, CallbackParams, HandledSwitches,
|
|
DiagEngine, CallbackBody);
|
|
}
|
|
|
|
auto SuccessBindings = CallbackParams.getParamsToBind(BlockKind::SUCCESS);
|
|
auto *ErrParam = CallbackParams.getErrParam();
|
|
if (DiagEngine.hadAnyError()) {
|
|
// For now, only fallback when the results are params with an error param,
|
|
// in which case only the names are used (defaulted to the names of the
|
|
// params if none).
|
|
if (HandlerDesc.Type != HandlerType::PARAMS || !HandlerDesc.HasError)
|
|
return;
|
|
DiagEngine.resetHadAnyError();
|
|
|
|
// Note that we don't print any inline patterns here as we just want
|
|
// assignments to the names in the outer scope.
|
|
InlinePatternsToPrint InlinePatterns;
|
|
|
|
auto AllBindings = CallbackParams.getParamsToBind(BlockKind::FALLBACK);
|
|
|
|
prepareNames(ClassifiedBlock(), AllBindings, InlinePatterns);
|
|
preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
|
|
BlockKind::FALLBACK);
|
|
addFallbackVars(AllBindings, CallbackParams);
|
|
addDo();
|
|
addAwaitCall(CE, Blocks.SuccessBlock, SuccessBindings, InlinePatterns,
|
|
HandlerDesc, /*AddDeclarations*/ false);
|
|
OS << "\n";
|
|
|
|
// If we have a known Bool success param, we need to bind it.
|
|
addBoolFlagParamBindingIfNeeded(CallbackParams.getKnownBoolFlagParam(),
|
|
BlockKind::SUCCESS);
|
|
addFallbackCatch(CallbackParams);
|
|
OS << "\n";
|
|
convertNodes(NodesToPrint::inBraceStmt(CallbackBody));
|
|
|
|
clearNames(AllBindings);
|
|
return;
|
|
}
|
|
|
|
auto *ErrOrResultParam = ErrParam;
|
|
if (auto *ResultParam = CallbackParams.getResultParam())
|
|
ErrOrResultParam = ResultParam;
|
|
|
|
auto ErrorNodes = Blocks.ErrorBlock.nodesToPrint().getNodes();
|
|
bool RequireDo = !ErrorNodes.empty();
|
|
// Check if we *actually* need a do/catch (see class comment)
|
|
if (ErrorNodes.size() == 1) {
|
|
auto Node = ErrorNodes[0];
|
|
if (auto *HandlerCall = TopHandler.getAsHandlerCall(Node)) {
|
|
auto Res = TopHandler.extractResultArgs(
|
|
HandlerCall, /*ReturnErrorArgsIfAmbiguous=*/true);
|
|
if (Res.args().size() == 1) {
|
|
// Skip if we have the param itself or the name it's bound to
|
|
auto *ArgExpr = Res.args()[0].getExpr();
|
|
auto *SingleDecl = ArgExpr->getReferencedDecl().getDecl();
|
|
auto ErrName = Blocks.ErrorBlock.boundName(ErrOrResultParam);
|
|
RequireDo = SingleDecl != ErrOrResultParam &&
|
|
!(Res.isError() && SingleDecl &&
|
|
SingleDecl->getName().isSimpleName(ErrName));
|
|
}
|
|
}
|
|
}
|
|
|
|
// If we're not requiring a 'do', we'll be dropping the error block. But
|
|
// let's make sure we at least preserve the comments in the error block by
|
|
// transplanting them into the success block. This should make sure they
|
|
// maintain a sensible ordering.
|
|
if (!RequireDo) {
|
|
auto ErrorNodes = Blocks.ErrorBlock.nodesToPrint();
|
|
for (auto CommentLoc : ErrorNodes.getPossibleCommentLocs())
|
|
Blocks.SuccessBlock.addPossibleCommentLoc(CommentLoc);
|
|
}
|
|
|
|
if (RequireDo) {
|
|
addDo();
|
|
}
|
|
|
|
auto InlinePatterns =
|
|
getInlinePatternsToPrint(Blocks.SuccessBlock, SuccessBindings, Callback);
|
|
|
|
prepareNames(Blocks.SuccessBlock, SuccessBindings, InlinePatterns);
|
|
preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
|
|
BlockKind::SUCCESS);
|
|
|
|
addAwaitCall(CE, Blocks.SuccessBlock, SuccessBindings, InlinePatterns,
|
|
HandlerDesc, /*AddDeclarations=*/true);
|
|
printOutOfLineBindingPatterns(Blocks.SuccessBlock, InlinePatterns);
|
|
convertNodes(Blocks.SuccessBlock.nodesToPrint());
|
|
clearNames(SuccessBindings);
|
|
|
|
if (RequireDo) {
|
|
// We don't use inline patterns for the error path.
|
|
InlinePatternsToPrint ErrInlinePatterns;
|
|
|
|
// Always use the ErrParam name if none is bound.
|
|
prepareNames(Blocks.ErrorBlock, llvm::ArrayRef(ErrOrResultParam),
|
|
ErrInlinePatterns,
|
|
/*AddIfMissing=*/HandlerDesc.Type != HandlerType::RESULT);
|
|
preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
|
|
BlockKind::ERROR);
|
|
|
|
addCatch(ErrOrResultParam);
|
|
convertNodes(Blocks.ErrorBlock.nodesToPrint());
|
|
OS << "\n" << tok::r_brace;
|
|
clearNames(llvm::ArrayRef(ErrOrResultParam));
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::addHoistedNamedCallback(
|
|
const FuncDecl *FD, const AsyncHandlerDesc &HandlerDesc,
|
|
StringRef HandlerName, std::function<void(void)> AddAwaitCall) {
|
|
if (HandlerDesc.HasError) {
|
|
// "result" and "error" always okay to use here since they're added
|
|
// in their own scope, which only contains new code.
|
|
addDo();
|
|
if (!HandlerDesc.willAsyncReturnVoid()) {
|
|
OS << tok::kw_let << " result";
|
|
addResultTypeAnnotationIfNecessary(FD, HandlerDesc);
|
|
OS << " " << tok::equal << " ";
|
|
}
|
|
AddAwaitCall();
|
|
OS << "\n";
|
|
addCallToCompletionHandler("result", HandlerDesc, HandlerName);
|
|
OS << "\n";
|
|
OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n";
|
|
addCallToCompletionHandler(StringRef(), HandlerDesc, HandlerName);
|
|
OS << "\n" << tok::r_brace; // end catch
|
|
} else {
|
|
// This code may be placed into an existing scope, in that case create
|
|
// a unique "result" name so that it doesn't cause shadowing or redecls.
|
|
StringRef ResultName;
|
|
if (!HandlerDesc.willAsyncReturnVoid()) {
|
|
Identifier Unique = createUniqueName("result");
|
|
Scopes.back().Names.insert(Unique);
|
|
ResultName = Unique.str();
|
|
|
|
OS << tok::kw_let << " " << ResultName;
|
|
addResultTypeAnnotationIfNecessary(FD, HandlerDesc);
|
|
OS << " " << tok::equal << " ";
|
|
} else {
|
|
// The name won't end up being used, just give it a bogus one so that
|
|
// the result path is taken (versus the error path).
|
|
ResultName = "result";
|
|
}
|
|
AddAwaitCall();
|
|
OS << "\n";
|
|
addCallToCompletionHandler(ResultName, HandlerDesc, HandlerName);
|
|
}
|
|
}
|
|
|
|
const Pattern *AsyncConverter::bindingPatternToPrintInline(
|
|
const Decl *D, const ClassifiedBlock &Block,
|
|
const ClosureExpr *CallbackClosure) {
|
|
// Only currently done for callback closures.
|
|
if (!CallbackClosure)
|
|
return nullptr;
|
|
|
|
// If we can reduce the pattern bindings down to a single pattern, we may
|
|
// be able to print it inline.
|
|
auto *P = Block.getSinglePatternFor(D);
|
|
if (!P)
|
|
return nullptr;
|
|
|
|
// Patterns that bind a single var are always printed inline.
|
|
if (P->getSingleVar())
|
|
return P;
|
|
|
|
// If we have a multi-var binding, and the decl being bound is referenced
|
|
// elsewhere in the block, we cannot print the pattern immediately in the
|
|
// await call. Instead, we'll print it out of line.
|
|
auto *Decls = ScopedDecls.getReferencedDecls(CallbackClosure->getBody());
|
|
assert(Decls);
|
|
auto NumRefs = Decls->lookup(D);
|
|
return NumRefs == 1 ? P : nullptr;
|
|
}
|
|
|
|
AsyncConverter::InlinePatternsToPrint
|
|
AsyncConverter::getInlinePatternsToPrint(const ClassifiedBlock &Block,
|
|
ArrayRef<const ParamDecl *> Params,
|
|
const ClosureExpr *CallbackClosure) {
|
|
InlinePatternsToPrint Patterns;
|
|
for (auto *Param : Params) {
|
|
if (auto *P = bindingPatternToPrintInline(Param, Block, CallbackClosure))
|
|
Patterns[Param] = P;
|
|
}
|
|
return Patterns;
|
|
}
|
|
|
|
void AsyncConverter::printOutOfLineBindingPatterns(
|
|
const ClassifiedBlock &Block, const InlinePatternsToPrint &InlinePatterns) {
|
|
for (auto &Entry : Block.paramPatternBindings()) {
|
|
auto *D = Entry.first;
|
|
auto Aliases = Block.getAliasesFor(D);
|
|
|
|
for (auto *P : Entry.second) {
|
|
// If we already printed this as an inline pattern, there's nothing else
|
|
// to do.
|
|
if (InlinePatterns.lookup(D) == P)
|
|
continue;
|
|
|
|
// If this is an alias binding, it can be elided.
|
|
if (auto *SingleVar = P->getSingleVar()) {
|
|
if (Aliases.contains(SingleVar))
|
|
continue;
|
|
}
|
|
|
|
auto HasMutable = P->hasAnyMutableBindings();
|
|
OS << "\n" << (HasMutable ? tok::kw_var : tok::kw_let) << " ";
|
|
convertPattern(P);
|
|
OS << " = ";
|
|
OS << newNameFor(D);
|
|
}
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::addAwaitCall(const CallExpr *CE,
|
|
const ClassifiedBlock &SuccessBlock,
|
|
ArrayRef<const ParamDecl *> SuccessParams,
|
|
const InlinePatternsToPrint &InlinePatterns,
|
|
const AsyncHandlerParamDesc &HandlerDesc,
|
|
bool AddDeclarations) {
|
|
auto *Args = CE->getArgs();
|
|
|
|
// Print the bindings to match the completion handler success parameters,
|
|
// making sure to omit in the case of a Void return.
|
|
if (!SuccessParams.empty() && !HandlerDesc.willAsyncReturnVoid()) {
|
|
auto AllLet = true;
|
|
|
|
// Gather the items to print for the variable bindings. This can either be
|
|
// a param decl, or a pattern that binds it.
|
|
using DeclOrPattern = llvm::PointerUnion<const Decl *, const Pattern *>;
|
|
SmallVector<DeclOrPattern, 4> ToPrint;
|
|
for (auto *Param : SuccessParams) {
|
|
// Check if we have an inline pattern to print.
|
|
if (auto *P = InlinePatterns.lookup(Param)) {
|
|
if (P->hasAnyMutableBindings())
|
|
AllLet = false;
|
|
ToPrint.push_back(P);
|
|
continue;
|
|
}
|
|
ToPrint.push_back(Param);
|
|
}
|
|
|
|
if (AddDeclarations) {
|
|
if (AllLet) {
|
|
OS << tok::kw_let;
|
|
} else {
|
|
OS << tok::kw_var;
|
|
}
|
|
OS << " ";
|
|
}
|
|
// 'res =' or '(res1, res2, ...) ='
|
|
addTupleOf(ToPrint, OS, [&](DeclOrPattern Elt) {
|
|
if (auto *P = Elt.dyn_cast<const Pattern *>()) {
|
|
convertPattern(P);
|
|
return;
|
|
}
|
|
OS << newNameFor(cast<const Decl *>(Elt));
|
|
});
|
|
OS << " " << tok::equal << " ";
|
|
}
|
|
|
|
if (HandlerDesc.HasError) {
|
|
OS << tok::kw_try << " ";
|
|
}
|
|
OS << "await ";
|
|
|
|
// Try to replace the name with that of the alternative. Use the existing
|
|
// name if for some reason that's not possible.
|
|
bool NameAdded = false;
|
|
if (HandlerDesc.Alternative) {
|
|
const ValueDecl *Named = HandlerDesc.Alternative;
|
|
if (auto *Accessor = dyn_cast<AccessorDecl>(HandlerDesc.Alternative))
|
|
Named = Accessor->getStorage();
|
|
if (!Named->getBaseName().isSpecial()) {
|
|
Names.try_emplace(HandlerDesc.Func, Named->getBaseName().getIdentifier());
|
|
convertNode(CE->getFn(), /*StartOverride=*/{}, /*ConvertCalls=*/false,
|
|
/*IncludeComments=*/false);
|
|
NameAdded = true;
|
|
}
|
|
}
|
|
if (!NameAdded) {
|
|
addRange(CE->getStartLoc(), CE->getFn()->getEndLoc(),
|
|
/*ToEndOfToken=*/true);
|
|
}
|
|
|
|
if (!HandlerDesc.alternativeIsAccessor())
|
|
OS << tok::l_paren;
|
|
|
|
size_t ConvertedArgIndex = 0;
|
|
ArrayRef<ParamDecl *> AlternativeParams;
|
|
if (HandlerDesc.Alternative)
|
|
AlternativeParams = HandlerDesc.Alternative->getParameters()->getArray();
|
|
|
|
for (auto I : indices(*Args)) {
|
|
auto Arg = Args->get(I);
|
|
auto *ArgExpr = Arg.getExpr();
|
|
if (I == HandlerDesc.Index || isa<DefaultArgumentExpr>(ArgExpr))
|
|
continue;
|
|
|
|
if (ConvertedArgIndex > 0)
|
|
OS << tok::comma << " ";
|
|
|
|
if (HandlerDesc.Alternative) {
|
|
// Skip argument if it's defaulted and has a different name
|
|
while (ConvertedArgIndex < AlternativeParams.size() &&
|
|
AlternativeParams[ConvertedArgIndex]->isDefaultArgument() &&
|
|
AlternativeParams[ConvertedArgIndex]->getArgumentName() !=
|
|
Arg.getLabel()) {
|
|
ConvertedArgIndex++;
|
|
}
|
|
|
|
if (ConvertedArgIndex < AlternativeParams.size()) {
|
|
// Could have a different argument label (or none), so add it instead
|
|
auto Name = AlternativeParams[ConvertedArgIndex]->getArgumentName();
|
|
if (!Name.empty())
|
|
OS << Name << ": ";
|
|
convertNode(ArgExpr, /*StartOverride=*/{}, /*ConvertCalls=*/false);
|
|
|
|
ConvertedArgIndex++;
|
|
continue;
|
|
}
|
|
|
|
// Fallthrough if arguments don't match up for some reason
|
|
}
|
|
|
|
// Can't just add the range as we need to perform replacements. Also
|
|
// make sure to include the argument label (if any)
|
|
convertNode(ArgExpr, /*StartOverride=*/Arg.getLabelLoc(),
|
|
/*ConvertCalls=*/false);
|
|
ConvertedArgIndex++;
|
|
}
|
|
|
|
if (!HandlerDesc.alternativeIsAccessor())
|
|
OS << tok::r_paren;
|
|
}
|
|
|
|
void AsyncConverter::addFallbackCatch(const ClosureCallbackParams &Params) {
|
|
auto *ErrParam = Params.getErrParam();
|
|
assert(ErrParam);
|
|
auto ErrName = newNameFor(ErrParam);
|
|
OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n"
|
|
<< ErrName << " = error\n";
|
|
|
|
// If we have a known Bool success param, we need to bind it.
|
|
addBoolFlagParamBindingIfNeeded(Params.getKnownBoolFlagParam(),
|
|
BlockKind::ERROR);
|
|
OS << tok::r_brace;
|
|
}
|
|
|
|
void AsyncConverter::addCatch(const ParamDecl *ErrParam) {
|
|
OS << "\n" << tok::r_brace << " " << tok::kw_catch << " ";
|
|
auto ErrName = newNameFor(ErrParam, false);
|
|
if (!ErrName.empty() && ErrName != "_") {
|
|
OS << tok::kw_let << " " << ErrName << " ";
|
|
}
|
|
OS << tok::l_brace;
|
|
}
|
|
|
|
void AsyncConverter::preparePlaceholdersAndUnwraps(
|
|
AsyncHandlerDesc HandlerDesc, const ClosureCallbackParams &Params,
|
|
BlockKind Block) {
|
|
// Params that have been dropped always need placeholdering.
|
|
for (auto *Param : Params.getAllParams()) {
|
|
if (!Params.hasBinding(Param, Block))
|
|
Placeholders.insert(Param);
|
|
}
|
|
// For the fallback case, no other params need placeholdering, as they are
|
|
// all freely accessible in the fallback case.
|
|
if (Block == BlockKind::FALLBACK)
|
|
return;
|
|
|
|
switch (HandlerDesc.Type) {
|
|
case HandlerType::PARAMS: {
|
|
auto *ErrParam = Params.getErrParam();
|
|
auto SuccessParams = Params.getSuccessParams();
|
|
switch (Block) {
|
|
case BlockKind::FALLBACK:
|
|
llvm_unreachable("Already handled");
|
|
case BlockKind::ERROR:
|
|
if (ErrParam) {
|
|
if (HandlerDesc.shouldUnwrap(ErrParam->getTypeInContext())) {
|
|
Placeholders.insert(ErrParam);
|
|
Unwraps.insert(ErrParam);
|
|
}
|
|
// Can't use success params in the error body
|
|
Placeholders.insert(SuccessParams.begin(), SuccessParams.end());
|
|
}
|
|
break;
|
|
case BlockKind::SUCCESS:
|
|
for (auto *SuccessParam : SuccessParams) {
|
|
auto Ty = SuccessParam->getTypeInContext();
|
|
if (HandlerDesc.shouldUnwrap(Ty)) {
|
|
// Either unwrap or replace with a placeholder if there's some other
|
|
// reference
|
|
Unwraps.insert(SuccessParam);
|
|
Placeholders.insert(SuccessParam);
|
|
}
|
|
|
|
// Void parameters get omitted where possible, so turn any reference
|
|
// into a placeholder, as its usage is unlikely what the user wants.
|
|
if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid())
|
|
Placeholders.insert(SuccessParam);
|
|
}
|
|
// Can't use the error param in the success body
|
|
if (ErrParam)
|
|
Placeholders.insert(ErrParam);
|
|
break;
|
|
}
|
|
break;
|
|
}
|
|
case HandlerType::RESULT: {
|
|
// Any uses of the result parameter in the current body (that aren't
|
|
// replaced) are invalid, so replace them with a placeholder.
|
|
auto *ResultParam = Params.getResultParam();
|
|
assert(ResultParam);
|
|
Placeholders.insert(ResultParam);
|
|
break;
|
|
}
|
|
default:
|
|
llvm_unreachable("Unhandled handler type");
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::prepareNames(const ClassifiedBlock &Block,
|
|
ArrayRef<const ParamDecl *> Params,
|
|
const InlinePatternsToPrint &InlinePatterns,
|
|
bool AddIfMissing) {
|
|
for (auto *PD : Params) {
|
|
// If this param is to be replaced by a pattern that binds multiple
|
|
// separate vars, it's not actually going to be added to the scope, and
|
|
// therefore doesn't need naming. This avoids needing to rename a var with
|
|
// the same name later on in the scope, as it's not actually clashing.
|
|
if (auto *P = InlinePatterns.lookup(PD)) {
|
|
if (!P->getSingleVar())
|
|
continue;
|
|
}
|
|
auto Name = Block.boundName(PD);
|
|
if (Name.empty() && !AddIfMissing)
|
|
continue;
|
|
|
|
auto Ident = assignUniqueName(PD, Name);
|
|
|
|
// Also propagate the name to any aliases.
|
|
for (auto *Alias : Block.getAliasesFor(PD))
|
|
Names[Alias] = Ident;
|
|
}
|
|
}
|
|
|
|
Identifier AsyncConverter::createUniqueName(StringRef Name) {
|
|
Identifier Ident = getASTContext().getIdentifier(Name);
|
|
if (Name == "_")
|
|
return Ident;
|
|
|
|
auto &CurrentNames = Scopes.back().Names;
|
|
if (CurrentNames.count(Ident)) {
|
|
// Add a number to the end of the name until it's unique given the current
|
|
// names in scope.
|
|
llvm::SmallString<32> UniquedName;
|
|
unsigned UniqueId = 1;
|
|
do {
|
|
UniquedName = Name;
|
|
UniquedName.append(std::to_string(UniqueId));
|
|
Ident = getASTContext().getIdentifier(UniquedName);
|
|
UniqueId++;
|
|
} while (CurrentNames.count(Ident));
|
|
}
|
|
return Ident;
|
|
}
|
|
|
|
Identifier AsyncConverter::assignUniqueName(const Decl *D,
|
|
StringRef BoundName) {
|
|
Identifier Ident;
|
|
if (BoundName.empty()) {
|
|
BoundName = getDeclName(D).userFacingName();
|
|
if (BoundName.empty())
|
|
return Ident;
|
|
}
|
|
|
|
if (BoundName.starts_with("$")) {
|
|
llvm::SmallString<8> NewName;
|
|
NewName.append("val");
|
|
NewName.append(BoundName.drop_front());
|
|
Ident = createUniqueName(NewName);
|
|
} else {
|
|
Ident = createUniqueName(BoundName);
|
|
}
|
|
|
|
Names.try_emplace(D, Ident);
|
|
Scopes.back().Names.insert(Ident);
|
|
return Ident;
|
|
}
|
|
|
|
StringRef AsyncConverter::newNameFor(const Decl *D, bool Required) {
|
|
auto Res = Names.find(D);
|
|
if (Res == Names.end()) {
|
|
assert(!Required && "Missing name for decl when one was required");
|
|
return StringRef();
|
|
}
|
|
return Res->second.str();
|
|
}
|
|
|
|
void AsyncConverter::addNewScope(const llvm::DenseSet<const Decl *> &Decls) {
|
|
if (Scopes.empty()) {
|
|
Scopes.emplace_back(/*ContinuationName=*/Identifier());
|
|
} else {
|
|
// If the parent scope is nested in a continuation, the new one is also.
|
|
// Carry over the continuation name.
|
|
Identifier PreviousContinuationName = Scopes.back().ContinuationName;
|
|
Scopes.emplace_back(PreviousContinuationName);
|
|
}
|
|
for (auto D : Decls) {
|
|
auto Name = getDeclName(D);
|
|
if (!Name.empty())
|
|
Scopes.back().Names.insert(Name);
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::clearNames(ArrayRef<const ParamDecl *> Params) {
|
|
for (auto *Param : Params) {
|
|
Unwraps.erase(Param);
|
|
Placeholders.erase(Param);
|
|
Names.erase(Param);
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::addForwardingCallTo(const FuncDecl *FD,
|
|
StringRef HandlerReplacement) {
|
|
OS << FD->getBaseName() << tok::l_paren;
|
|
|
|
auto *Params = FD->getParameters();
|
|
size_t ConvertedArgsIndex = 0;
|
|
for (size_t I = 0, E = Params->size(); I < E; ++I) {
|
|
if (I == TopHandler.Index) {
|
|
/// If we're not replacing the handler with anything, drop it.
|
|
if (HandlerReplacement.empty())
|
|
continue;
|
|
|
|
// Use a trailing closure if the handler is the last param
|
|
if (I == E - 1) {
|
|
OS << tok::r_paren << " ";
|
|
OS << HandlerReplacement;
|
|
return;
|
|
}
|
|
|
|
// Otherwise fall through to do the replacement.
|
|
}
|
|
|
|
if (ConvertedArgsIndex > 0)
|
|
OS << tok::comma << " ";
|
|
|
|
const auto *Param = Params->get(I);
|
|
if (!Param->getArgumentName().empty())
|
|
OS << Param->getArgumentName() << tok::colon << " ";
|
|
|
|
if (I == TopHandler.Index) {
|
|
OS << HandlerReplacement;
|
|
} else {
|
|
OS << Param->getParameterName();
|
|
}
|
|
|
|
ConvertedArgsIndex++;
|
|
}
|
|
OS << tok::r_paren;
|
|
}
|
|
|
|
void AsyncConverter::addForwardedErrorArgument(
|
|
StringRef ErrorName, const AsyncHandlerDesc &HandlerDesc) {
|
|
// If the error type is already Error, we can pass it as-is.
|
|
auto ErrorType = *HandlerDesc.getErrorType();
|
|
if (ErrorType->getCanonicalType() ==
|
|
getASTContext().getErrorExistentialType()) {
|
|
OS << ErrorName;
|
|
return;
|
|
}
|
|
|
|
// Otherwise we need to add a force cast to the destination custom error
|
|
// type. If this is for an Error? parameter, we'll need to add parens around
|
|
// the cast to silence a compiler warning about force casting never
|
|
// producing nil.
|
|
auto RequiresParens = HandlerDesc.getErrorParam().has_value();
|
|
if (RequiresParens)
|
|
OS << tok::l_paren;
|
|
|
|
OS << ErrorName << " " << tok::kw_as << tok::exclaim_postfix << " ";
|
|
ErrorType->lookThroughSingleOptionalType()->print(OS);
|
|
|
|
if (RequiresParens)
|
|
OS << tok::r_paren;
|
|
}
|
|
|
|
void AsyncConverter::addDefaultValueOrPlaceholder(Type T) {
|
|
if (T->isOptional()) {
|
|
OS << tok::kw_nil;
|
|
} else if (T->isVoid()) {
|
|
OS << "()";
|
|
} else {
|
|
OS << "<#";
|
|
T.print(OS);
|
|
OS << "#>";
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::addCompletionHandlerArgument(
|
|
size_t Index, StringRef ResultName, const AsyncHandlerDesc &HandlerDesc) {
|
|
if (HandlerDesc.HasError && Index == HandlerDesc.params().size() - 1) {
|
|
// The error parameter is the last argument of the completion handler.
|
|
if (ResultName.empty()) {
|
|
addForwardedErrorArgument("error", HandlerDesc);
|
|
} else {
|
|
addDefaultValueOrPlaceholder(HandlerDesc.params()[Index].getPlainType());
|
|
}
|
|
} else {
|
|
if (ResultName.empty()) {
|
|
addDefaultValueOrPlaceholder(HandlerDesc.params()[Index].getPlainType());
|
|
} else if (HandlerDesc
|
|
.getSuccessParamAsyncReturnType(
|
|
HandlerDesc.params()[Index].getPlainType())
|
|
->isVoid()) {
|
|
// Void return types are not returned by the async function, synthesize
|
|
// a Void instance.
|
|
OS << tok::l_paren << tok::r_paren;
|
|
} else if (HandlerDesc.getSuccessParams().size() > 1) {
|
|
// If the async method returns a tuple, we need to pass its elements to
|
|
// the completion handler separately. For example:
|
|
//
|
|
// func foo() async -> (String, Int) {}
|
|
//
|
|
// causes the following legacy body to be created:
|
|
//
|
|
// func foo(completion: (String, Int) -> Void) {
|
|
// Task {
|
|
// let result = await foo()
|
|
// completion(result.0, result.1)
|
|
// }
|
|
// }
|
|
OS << ResultName << tok::period;
|
|
|
|
auto Label = HandlerDesc.getAsyncReturnTypeLabel(Index);
|
|
if (!Label.empty()) {
|
|
OS << Label;
|
|
} else {
|
|
OS << Index;
|
|
}
|
|
} else {
|
|
OS << ResultName;
|
|
}
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::addCallToCompletionHandler(
|
|
StringRef ResultName, const AsyncHandlerDesc &HandlerDesc,
|
|
StringRef HandlerName) {
|
|
OS << HandlerName << tok::l_paren;
|
|
|
|
// Construct arguments to pass to the completion handler
|
|
switch (HandlerDesc.Type) {
|
|
case HandlerType::INVALID:
|
|
llvm_unreachable("Cannot be rewritten");
|
|
break;
|
|
case HandlerType::PARAMS: {
|
|
for (size_t I = 0; I < HandlerDesc.params().size(); ++I) {
|
|
if (I > 0) {
|
|
OS << tok::comma << " ";
|
|
}
|
|
addCompletionHandlerArgument(I, ResultName, HandlerDesc);
|
|
}
|
|
break;
|
|
}
|
|
case HandlerType::RESULT: {
|
|
if (!ResultName.empty()) {
|
|
OS << tok::period_prefix << "success" << tok::l_paren;
|
|
if (!HandlerDesc.willAsyncReturnVoid()) {
|
|
OS << ResultName;
|
|
} else {
|
|
OS << tok::l_paren << tok::r_paren;
|
|
}
|
|
OS << tok::r_paren;
|
|
} else {
|
|
OS << tok::period_prefix << "failure" << tok::l_paren;
|
|
addForwardedErrorArgument("error", HandlerDesc);
|
|
OS << tok::r_paren;
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
OS << tok::r_paren; // Close the call to the completion handler
|
|
}
|
|
|
|
void AsyncConverter::addAsyncFuncReturnType(
|
|
const AsyncHandlerDesc &HandlerDesc) {
|
|
// Type or (Type1, Type2, ...)
|
|
SmallVector<LabeledReturnType, 2> Scratch;
|
|
auto ReturnTypes = HandlerDesc.getAsyncReturnTypes(Scratch);
|
|
if (ReturnTypes.empty()) {
|
|
OS << "Void";
|
|
} else {
|
|
addTupleOf(ReturnTypes, OS, [&](LabeledReturnType LabelAndType) {
|
|
if (!LabelAndType.Label.empty()) {
|
|
OS << LabelAndType.Label << tok::colon << " ";
|
|
}
|
|
LabelAndType.Ty->print(OS);
|
|
});
|
|
}
|
|
}
|
|
|
|
void AsyncConverter::addResultTypeAnnotationIfNecessary(
|
|
const FuncDecl *FD, const AsyncHandlerDesc &HandlerDesc) {
|
|
if (FD->hasGenericParamList()) {
|
|
OS << tok::colon << " ";
|
|
addAsyncFuncReturnType(HandlerDesc);
|
|
}
|
|
}
|