[Refactoring] Only unwrap optionals if the handler has an optional error

Resolves rdar://73973459
This commit is contained in:
Ben Barham
2021-05-04 13:36:33 +10:00
parent 058613dd42
commit 398124c61a
4 changed files with 252 additions and 109 deletions

View File

@@ -4284,6 +4284,10 @@ struct AsyncHandlerDesc {
return getSuccessParamAsyncReturnType(Ty)->isVoid();
});
}
bool shouldUnwrap(swift::Type Ty) const {
return HasError && Ty->isOptional();
}
};
enum class ConditionType { INVALID, NIL, NOT_NIL };
@@ -4549,18 +4553,12 @@ struct CallbackClassifier {
static void classifyInto(ClassifiedBlocks &Blocks,
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
DiagnosticEngine &DiagEngine,
ArrayRef<const ParamDecl *> SuccessParams,
llvm::DenseSet<const Decl *> UnwrapParams,
const ParamDecl *ErrParam, HandlerType ResultType,
ArrayRef<ASTNode> Body) {
assert(!Body.empty() && "Cannot classify empty body");
auto ParamsSet = llvm::DenseSet<const Decl *>(SuccessParams.begin(),
SuccessParams.end());
if (ErrParam)
ParamsSet.insert(ErrParam);
CallbackClassifier Classifier(Blocks, HandledSwitches, DiagEngine,
ParamsSet, ErrParam,
UnwrapParams, ErrParam,
ResultType == HandlerType::RESULT);
Classifier.classifyNodes(Body);
}
@@ -4570,19 +4568,19 @@ private:
llvm::DenseSet<SwitchStmt *> &HandledSwitches;
DiagnosticEngine &DiagEngine;
ClassifiedBlock *CurrentBlock;
llvm::DenseSet<const Decl *> ParamsSet;
llvm::DenseSet<const Decl *> UnwrapParams;
const ParamDecl *ErrParam;
bool IsResultParam;
CallbackClassifier(ClassifiedBlocks &Blocks,
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
DiagnosticEngine &DiagEngine,
llvm::DenseSet<const Decl *> ParamsSet,
llvm::DenseSet<const Decl *> UnwrapParams,
const ParamDecl *ErrParam, bool IsResultParam)
: Blocks(Blocks), HandledSwitches(HandledSwitches),
DiagEngine(DiagEngine), CurrentBlock(&Blocks.SuccessBlock),
ParamsSet(ParamsSet), ErrParam(ErrParam), IsResultParam(IsResultParam) {
}
UnwrapParams(UnwrapParams), ErrParam(ErrParam),
IsResultParam(IsResultParam) {}
void classifyNodes(ArrayRef<ASTNode> Nodes) {
for (auto I = Nodes.begin(), E = Nodes.end(); I < E; ++I) {
@@ -4614,7 +4612,7 @@ private:
ArrayRef<ASTNode> ThenNodes, Stmt *ElseStmt) {
llvm::DenseMap<const Decl *, CallbackCondition> CallbackConditions;
bool UnhandledConditions =
!CallbackCondition::all(Condition, ParamsSet, CallbackConditions);
!CallbackCondition::all(Condition, UnwrapParams, CallbackConditions);
CallbackCondition ErrCondition = CallbackConditions.lookup(ErrParam);
if (UnhandledConditions) {
@@ -4942,7 +4940,7 @@ private:
getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE);
if (HandlerDesc.isValid())
return addCustom(CE->getSourceRange(),
[&]() { addAsyncAlternativeCall(CE, HandlerDesc); });
[&]() { addHoistedCallback(CE, HandlerDesc); });
}
}
@@ -5145,8 +5143,8 @@ private:
}
}
void addAsyncAlternativeCall(const CallExpr *CE,
const AsyncHandlerDesc &HandlerDesc) {
void addHoistedCallback(const CallExpr *CE,
const AsyncHandlerDesc &HandlerDesc) {
auto ArgList = callArgs(CE);
if ((size_t)HandlerDesc.Index >= ArgList.ref().size()) {
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
@@ -5159,53 +5157,63 @@ private:
return;
}
ParameterList *CallbackParams = Callback->getParameters();
ArrayRef<const ParamDecl *> CallbackParams =
Callback->getParameters()->getArray();
ArrayRef<ASTNode> CallbackBody = Callback->getBody()->getElements();
if (HandlerDesc.params().size() != CallbackParams->size()) {
if (HandlerDesc.params().size() != CallbackParams.size()) {
DiagEngine.diagnose(CE->getStartLoc(), diag::mismatched_callback_args);
return;
}
// Note that the `ErrParam` may be a Result (in which case it's also the
// only element in `SuccessParams`)
ArrayRef<const ParamDecl *> SuccessParams = CallbackParams->getArray();
ArrayRef<const ParamDecl *> SuccessParams = CallbackParams;
const ParamDecl *ErrParam = nullptr;
if (HandlerDesc.HasError) {
if (HandlerDesc.Type == HandlerType::RESULT) {
ErrParam = SuccessParams.back();
if (HandlerDesc.Type == HandlerType::PARAMS)
SuccessParams = SuccessParams.drop_back();
} else if (HandlerDesc.HasError) {
assert(HandlerDesc.Type == HandlerType::PARAMS);
ErrParam = SuccessParams.back();
SuccessParams = SuccessParams.drop_back();
}
ArrayRef<const ParamDecl *> ErrParams;
if (ErrParam)
ErrParams = llvm::makeArrayRef(ErrParam);
ClassifiedBlocks Blocks;
if (!HandlerDesc.HasError) {
Blocks.SuccessBlock.addAllNodes(CallbackBody);
} else if (!CallbackBody.empty()) {
CallbackClassifier::classifyInto(Blocks, HandledSwitches, DiagEngine,
SuccessParams, ErrParam,
HandlerDesc.Type, CallbackBody);
if (DiagEngine.hadAnyError()) {
// Can only fallback when the results are params, in which case only
// the names are used (defaulted to the names of the params if none)
if (HandlerDesc.Type != HandlerType::PARAMS)
return;
DiagEngine.resetHadAnyError();
setNames(ClassifiedBlock(), CallbackParams->getArray());
addFallbackVars(CallbackParams->getArray(), Blocks);
addDo();
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
HandlerDesc, /*AddDeclarations=*/!HandlerDesc.HasError);
addFallbackCatch(ErrParam);
OS << "\n";
convertNodes(CallbackBody);
clearParams(CallbackParams->getArray());
return;
llvm::DenseSet<const Decl *> UnwrapParams;
for (auto *Param : SuccessParams) {
if (HandlerDesc.shouldUnwrap(Param->getType()))
UnwrapParams.insert(Param);
}
if (ErrParam)
UnwrapParams.insert(ErrParam);
CallbackClassifier::classifyInto(Blocks, HandledSwitches, DiagEngine,
UnwrapParams, ErrParam, HandlerDesc.Type,
CallbackBody);
}
if (DiagEngine.hadAnyError()) {
// Can only fallback when the results are params, in which case only
// the names are used (defaulted to the names of the params if none)
if (HandlerDesc.Type != HandlerType::PARAMS)
return;
DiagEngine.resetHadAnyError();
// Don't do any unwrapping or placeholder replacement since all params
// are still valid in the fallback case
prepareNames(ClassifiedBlock(), CallbackParams);
addFallbackVars(CallbackParams, Blocks);
addDo();
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
HandlerDesc, /*AddDeclarations=*/!HandlerDesc.HasError);
addFallbackCatch(ErrParam);
OS << "\n";
convertNodes(CallbackBody);
clearNames(CallbackParams);
return;
}
bool RequireDo = !Blocks.ErrorBlock.nodes().empty();
@@ -5229,25 +5237,27 @@ private:
addDo();
}
setNames(Blocks.SuccessBlock, SuccessParams);
prepareNames(Blocks.SuccessBlock, SuccessParams);
preparePlaceholdersAndUnwraps(HandlerDesc, SuccessParams, ErrParam,
/*Success=*/true);
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
HandlerDesc, /*AddDeclarations=*/true);
prepareNamesForBody(HandlerDesc, SuccessParams, ErrParams);
convertNodes(Blocks.SuccessBlock.nodes());
clearNames(SuccessParams);
if (RequireDo) {
clearParams(SuccessParams);
// Always use the ErrParam name if none is bound
setNames(Blocks.ErrorBlock, ErrParams,
HandlerDesc.Type != HandlerType::RESULT);
prepareNames(Blocks.ErrorBlock, llvm::makeArrayRef(ErrParam),
HandlerDesc.Type != HandlerType::RESULT);
preparePlaceholdersAndUnwraps(HandlerDesc, SuccessParams, ErrParam,
/*Success=*/false);
addCatch(ErrParam);
prepareNamesForBody(HandlerDesc, ErrParams, SuccessParams);
addCatchBody(ErrParam, Blocks.ErrorBlock);
convertNodes(Blocks.ErrorBlock.nodes());
OS << "\n" << tok::r_brace;
clearNames(llvm::makeArrayRef(ErrParam));
}
clearParams(CallbackParams->getArray());
}
void addAwaitCall(const CallExpr *CE, ArrayRef<Expr *> Args,
@@ -5318,44 +5328,56 @@ private:
OS << tok::l_brace;
}
void addCatchBody(const ParamDecl *ErrParam,
const ClassifiedBlock &ErrorBlock) {
convertNodes(ErrorBlock.nodes());
OS << "\n" << tok::r_brace;
}
void prepareNamesForBody(const AsyncHandlerDesc &HandlerDesc,
ArrayRef<const ParamDecl *> CurrentParams,
ArrayRef<const ParamDecl *> OtherParams) {
void preparePlaceholdersAndUnwraps(AsyncHandlerDesc HandlerDesc,
ArrayRef<const ParamDecl *> SuccessParams,
const ParamDecl *ErrParam, bool Success) {
switch (HandlerDesc.Type) {
case HandlerType::PARAMS:
for (auto *Param : CurrentParams) {
auto Ty = Param->getType();
if (Ty->getOptionalObjectType()) {
Unwraps.insert(Param);
Placeholders.insert(Param);
if (!Success) {
if (ErrParam) {
if (HandlerDesc.shouldUnwrap(ErrParam->getType())) {
Placeholders.insert(ErrParam);
Unwraps.insert(ErrParam);
}
// Can't use success params in the error body
Placeholders.insert(SuccessParams.begin(), SuccessParams.end());
}
// Void parameters get omitted where possible, so turn any reference
// into a placeholder, as its usage is unlikely what the user wants.
if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid())
Placeholders.insert(Param);
} else {
for (auto *SuccessParam : SuccessParams) {
auto Ty = SuccessParam->getType();
if (HandlerDesc.shouldUnwrap(Ty)) {
// Either unwrap or replace with a placeholder if there's some other
// reference
Unwraps.insert(SuccessParam);
Placeholders.insert(SuccessParam);
}
// Void parameters get omitted where possible, so turn any reference
// into a placeholder, as its usage is unlikely what the user wants.
if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid())
Placeholders.insert(SuccessParam);
}
// Can't use the error param in the success body
if (ErrParam)
Placeholders.insert(ErrParam);
}
// Use of the other params is invalid within the current body
Placeholders.insert(OtherParams.begin(), OtherParams.end());
break;
case HandlerType::RESULT:
// Any uses of the result parameter in the current body (that
// isn't replaced) are invalid, so replace them with a placeholder
Placeholders.insert(CurrentParams.begin(), CurrentParams.end());
// Any uses of the result parameter in the current body (that aren't
// replaced) are invalid, so replace them with a placeholder.
assert(SuccessParams.size() == 1 && SuccessParams[0] == ErrParam);
Placeholders.insert(ErrParam);
break;
default:
llvm_unreachable("Unhandled handler type");
}
}
// TODO: Check for clashes with existing names
void setNames(const ClassifiedBlock &Block,
ArrayRef<const ParamDecl *> Params, bool AddIfMissing = true) {
// TODO: Check for clashes with existing names and add all decls, not just
// params
void prepareNames(const ClassifiedBlock &Block,
ArrayRef<const ParamDecl *> Params,
bool AddIfMissing = true) {
for (auto *Param : Params) {
StringRef Name = Block.boundName(Param);
if (!Name.empty()) {
@@ -5384,7 +5406,7 @@ private:
return StringRef(Res->second);
}
void clearParams(ArrayRef<const ParamDecl *> Params) {
void clearNames(ArrayRef<const ParamDecl *> Params) {
for (auto *Param : Params) {
Unwraps.erase(Param);
Placeholders.erase(Param);