[Sema] Split out result builder pre-checking

Use `preCheckTarget` to pre-check the body,
allowing us to replace `PreCheckResultBuilderRequest`
with a request that only checks the brace for
ReturnStmts.
This commit is contained in:
Hamish Knight
2024-08-03 16:05:09 +01:00
parent 23e8340a13
commit a73e44a78f
5 changed files with 36 additions and 223 deletions

View File

@@ -3028,60 +3028,8 @@ public:
void cacheResult(ProtocolConformanceRef value) const;
};
struct PreCheckResultBuilderDescriptor {
AnyFunctionRef Fn;
private:
// NOTE: Since source tooling (e.g. code completion) might replace the body,
// we need to take the body into account to calculate 'hash_value' and '=='.
// Also, we cannot 'getBody()' inside 'hash_value' and '==' because it invokes
// another request (even if it's cached).
BraceStmt *Body;
public:
PreCheckResultBuilderDescriptor(AnyFunctionRef Fn)
: Fn(Fn), Body(Fn.getBody()) {}
friend llvm::hash_code
hash_value(const PreCheckResultBuilderDescriptor &owner) {
return llvm::hash_combine(owner.Fn, owner.Body);
}
friend bool operator==(const PreCheckResultBuilderDescriptor &lhs,
const PreCheckResultBuilderDescriptor &rhs) {
return lhs.Fn == rhs.Fn && lhs.Body == rhs.Body;
}
friend bool operator!=(const PreCheckResultBuilderDescriptor &lhs,
const PreCheckResultBuilderDescriptor &rhs) {
return !(lhs == rhs);
}
friend SourceLoc extractNearestSourceLoc(PreCheckResultBuilderDescriptor d) {
return extractNearestSourceLoc(d.Fn);
}
friend void simple_display(llvm::raw_ostream &out,
const PreCheckResultBuilderDescriptor &d) {
simple_display(out, d.Fn);
}
};
enum class ResultBuilderBodyPreCheck : uint8_t {
/// There were no problems pre-checking the closure.
Okay,
/// There was an error pre-checking the closure.
Error,
/// The closure has a return statement.
HasReturnStmt,
};
class PreCheckResultBuilderRequest
: public SimpleRequest<PreCheckResultBuilderRequest,
ResultBuilderBodyPreCheck(
PreCheckResultBuilderDescriptor),
class BraceHasReturnRequest
: public SimpleRequest<BraceHasReturnRequest, bool(const BraceStmt *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;
@@ -3090,8 +3038,7 @@ private:
friend SimpleRequest;
// Evaluation.
ResultBuilderBodyPreCheck
evaluate(Evaluator &evaluator, PreCheckResultBuilderDescriptor owner) const;
bool evaluate(Evaluator &evaluator, const BraceStmt *BS) const;
public:
// Separate caching.

View File

@@ -381,8 +381,8 @@ SWIFT_REQUEST(TypeChecker, HasUserDefinedDesignatedInitRequest,
bool(NominalTypeDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, HasMemberwiseInitRequest,
bool(StructDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, PreCheckResultBuilderRequest,
ResultBuilderBodyPreCheck(PreCheckResultBuilderDescriptor),
SWIFT_REQUEST(TypeChecker, BraceHasReturnRequest,
bool(const BraceStmt *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, ResolveImplicitMemberRequest,
evaluator::SideEffect(NominalTypeDecl *, ImplicitMemberAction),

View File

@@ -1317,25 +1317,6 @@ void AssociatedConformanceRequest::cacheResult(
conformance->setAssociatedConformance(index, assocConf);
}
//----------------------------------------------------------------------------//
// PreCheckResultBuilderRequest computation.
//----------------------------------------------------------------------------//
void swift::simple_display(llvm::raw_ostream &out,
ResultBuilderBodyPreCheck value) {
switch (value) {
case ResultBuilderBodyPreCheck::Okay:
out << "okay";
break;
case ResultBuilderBodyPreCheck::HasReturnStmt:
out << "has return statement";
break;
case ResultBuilderBodyPreCheck::Error:
out << "error";
break;
}
}
//----------------------------------------------------------------------------//
// HasCircularInheritedProtocolsRequest computation.
//----------------------------------------------------------------------------//

View File

@@ -920,23 +920,10 @@ private:
std::optional<BraceStmt *>
TypeChecker::applyResultBuilderBodyTransform(FuncDecl *func, Type builderType) {
// Pre-check the body: pre-check any expressions in it and look
// for return statements.
//
// If we encountered an error or there was an explicit result type,
// bail out and report that to the caller.
// First look for any return statements, and bail if we have any.
auto &ctx = func->getASTContext();
auto request = PreCheckResultBuilderRequest{AnyFunctionRef(func)};
switch (evaluateOrDefault(ctx.evaluator, request,
ResultBuilderBodyPreCheck::Error)) {
case ResultBuilderBodyPreCheck::Okay:
// If the pre-check was okay, apply the result-builder transform.
break;
case ResultBuilderBodyPreCheck::Error:
return nullptr;
case ResultBuilderBodyPreCheck::HasReturnStmt: {
if (evaluateOrDefault(ctx.evaluator, BraceHasReturnRequest{func->getBody()},
false)) {
// One or more explicit 'return' statements were encountered, which
// disables the result builder transform. Warn when we do this.
auto returnStmts = findReturnStatements(func);
@@ -970,7 +957,10 @@ TypeChecker::applyResultBuilderBodyTransform(FuncDecl *func, Type builderType) {
return std::nullopt;
}
}
auto target = SyntacticElementTarget(func);
if (ConstraintSystem::preCheckTarget(target))
return nullptr;
ConstraintSystemOptions options = ConstraintSystemFlags::AllowFixes;
auto resultInterfaceTy = func->getResultInterfaceType();
@@ -1018,8 +1008,7 @@ TypeChecker::applyResultBuilderBodyTransform(FuncDecl *func, Type builderType) {
cs.Options |= ConstraintSystemFlags::ForCodeCompletion;
cs.solveForCodeCompletion(solutions);
SyntacticElementTarget funcTarget(func);
CompletionContextFinder analyzer(funcTarget, func->getDeclContext());
CompletionContextFinder analyzer(target, func->getDeclContext());
if (analyzer.hasCompletion()) {
filterSolutionsForCodeCompletion(solutions, analyzer);
for (const auto &solution : solutions) {
@@ -1066,7 +1055,7 @@ TypeChecker::applyResultBuilderBodyTransform(FuncDecl *func, Type builderType) {
case SolutionResult::Kind::UndiagnosedError:
reportSolutionsToSolutionCallback(salvagedResult);
cs.diagnoseFailureFor(SyntacticElementTarget(func));
cs.diagnoseFailureFor(target);
salvagedResult.markAsDiagnosed();
return nullptr;
@@ -1100,8 +1089,7 @@ TypeChecker::applyResultBuilderBodyTransform(FuncDecl *func, Type builderType) {
cs.applySolution(solutions.front());
// Apply the solution to the function body.
if (auto result =
cs.applySolution(solutions.front(), SyntacticElementTarget(func))) {
if (auto result = cs.applySolution(solutions.front(), target)) {
performSyntacticDiagnosticsForTarget(*result, /*isExprStmt*/ false);
auto *body = result->getFunctionBody();
@@ -1142,21 +1130,8 @@ ConstraintSystem::matchResultBuilder(AnyFunctionRef fn, Type builderType,
// not apply the result builder transform if it contained an explicit return.
// To maintain source compatibility, we still need to check for HasReturnStmt.
// https://github.com/apple/swift/issues/64332.
switch (evaluateOrDefault(getASTContext().evaluator,
PreCheckResultBuilderRequest{fn},
ResultBuilderBodyPreCheck::Error)) {
case ResultBuilderBodyPreCheck::Okay:
// If the pre-check was okay, apply the result-builder transform.
break;
case ResultBuilderBodyPreCheck::Error: {
llvm_unreachable(
"Running PreCheckResultBuilderRequest on a function shouldn't run "
"preCheckExpression and thus we should never enter this case.");
break;
}
case ResultBuilderBodyPreCheck::HasReturnStmt:
if (evaluateOrDefault(getASTContext().evaluator,
BraceHasReturnRequest{fn.getBody()}, false)) {
// Diagnostic mode means that solver couldn't reach any viable
// solution, so let's diagnose presence of a `return` statement
// in the closure body.
@@ -1257,42 +1232,14 @@ ConstraintSystem::matchResultBuilder(AnyFunctionRef fn, Type builderType,
}
namespace {
/// Pre-check all the expressions in the body.
class PreCheckResultBuilderApplication : public ASTWalker {
AnyFunctionRef Fn;
bool SkipPrecheck = false;
bool SuppressDiagnostics = false;
class ReturnStmtFinder : public ASTWalker {
std::vector<ReturnStmt *> ReturnStmts;
bool HasError = false;
bool hasReturnStmt() const { return !ReturnStmts.empty(); }
public:
PreCheckResultBuilderApplication(AnyFunctionRef fn, bool skipPrecheck,
bool suppressDiagnostics)
: Fn(fn), SkipPrecheck(skipPrecheck),
SuppressDiagnostics(suppressDiagnostics) {}
const std::vector<ReturnStmt *> getReturnStmts() const { return ReturnStmts; }
ResultBuilderBodyPreCheck run() {
Stmt *oldBody = Fn.getBody();
Stmt *newBody = oldBody->walk(*this);
// If the walk was aborted, it was because we had a problem of some kind.
assert((newBody == nullptr) == HasError &&
"unexpected short-circuit while walking body");
if (HasError)
return ResultBuilderBodyPreCheck::Error;
assert(oldBody == newBody && "pre-check walk wasn't in-place?");
if (hasReturnStmt())
return ResultBuilderBodyPreCheck::HasReturnStmt;
return ResultBuilderBodyPreCheck::Okay;
static std::vector<ReturnStmt *> find(const BraceStmt *BS) {
ReturnStmtFinder finder;
const_cast<BraceStmt *>(BS)->walk(finder);
return std::move(finder.ReturnStmts);
}
MacroWalking getMacroWalkingBehavior() const override {
@@ -1300,70 +1247,17 @@ public:
}
PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
if (SkipPrecheck)
return Action::SkipNode(E);
// Pre-check the expression. If this fails, abort the walk immediately.
// Otherwise, replace the expression with the result of pre-checking.
// In either case, don't recurse into the expression.
{
auto *DC = Fn.getAsDeclContext();
auto &diagEngine = DC->getASTContext().Diags;
// Suppress any diagnostics which could be produced by this expression.
DiagnosticTransaction transaction(diagEngine);
HasError |= ConstraintSystem::preCheckExpression(E, DC);
HasError |= transaction.hasErrors();
if (!HasError)
HasError |= containsErrorExpr(E);
if (SuppressDiagnostics)
transaction.abort();
if (HasError)
return Action::Stop();
return Action::SkipNode(E);
}
return Action::SkipNode(E);
}
PreWalkResult<Stmt *> walkToStmtPre(Stmt *S) override {
// If we see a return statement, note it..
if (auto returnStmt = dyn_cast<ReturnStmt>(S)) {
if (!returnStmt->isImplicit()) {
ReturnStmts.push_back(returnStmt);
return Action::SkipNode(S);
}
}
auto *returnStmt = dyn_cast<ReturnStmt>(S);
if (!returnStmt || returnStmt->isImplicit())
return Action::Continue(S);
// Otherwise, recurse into the statement normally.
return Action::Continue(S);
}
/// Check whether given expression (including single-statement
/// closures) contains `ErrorExpr` as one of its sub-expressions.
bool containsErrorExpr(Expr *expr) {
bool hasError = false;
expr->forEachChildExpr([&](Expr *expr) -> Expr * {
hasError |= isa<ErrorExpr>(expr);
if (hasError)
return nullptr;
if (auto *closure = dyn_cast<ClosureExpr>(expr)) {
if (closure->hasSingleExpressionBody()) {
hasError |= containsErrorExpr(closure->getSingleExpressionBody());
return hasError ? nullptr : expr;
}
}
return expr;
});
return hasError;
ReturnStmts.push_back(returnStmt);
return Action::SkipNode(S);
}
/// Ignore patterns.
@@ -1371,24 +1265,15 @@ public:
return Action::SkipNode(pat);
}
};
} // end anonymous namespace
}
ResultBuilderBodyPreCheck PreCheckResultBuilderRequest::evaluate(
Evaluator &evaluator, PreCheckResultBuilderDescriptor owner) const {
// Closures should already be pre-checked when we run this, so there's no need
// to pre-check them again.
bool skipPrecheck = owner.Fn.getAbstractClosureExpr();
return PreCheckResultBuilderApplication(
owner.Fn, skipPrecheck, /*suppressDiagnostics=*/false)
.run();
bool BraceHasReturnRequest::evaluate(Evaluator &evaluator,
const BraceStmt *BS) const {
return !ReturnStmtFinder::find(BS).empty();
}
std::vector<ReturnStmt *> TypeChecker::findReturnStatements(AnyFunctionRef fn) {
PreCheckResultBuilderApplication precheck(fn, /*skipPreCheck=*/true,
/*SuppressDiagnostics=*/true);
(void)precheck.run();
return precheck.getReturnStmts();
return ReturnStmtFinder::find(fn.getBody());
}
ResultBuilderOpSupport TypeChecker::checkBuilderOpSupport(

View File

@@ -79,7 +79,7 @@ struct TupleBuilderWithoutIf { // expected-note 3{{struct 'TupleBuilderWithoutIf
static func buildDo<T>(_ value: T) -> T { return value }
}
func tuplify<T>(_ cond: Bool, @TupleBuilder body: (Bool) -> T) { // expected-note {{'tuplify(_:body:)' declared here}}
func tuplify<T>(_ cond: Bool, @TupleBuilder body: (Bool) -> T) { // expected-note 2{{'tuplify(_:body:)' declared here}}
print(body(cond))
}
@@ -455,7 +455,7 @@ func testNonExhaustiveSwitch(e: E) {
// rdar://problem/59856491
struct TestConstraintGenerationErrors {
@TupleBuilder var buildTupleFnBody: String {
let a = nil // There is no diagnostic here because next line fails to pre-check, so body is invalid
let a = nil // expected-error {{'nil' requires a contextual type}}
String(nothing) // expected-error {{cannot find 'nothing' in scope}}
}
@@ -722,7 +722,7 @@ struct TuplifiedStructWithInvalidClosure {
@TupleBuilder var errorsDiagnosedByParser: some Any {
if let _ = condition {
tuplify { _ in
tuplify { _ in // expected-error {{missing argument for parameter #1 in call}}
self. // expected-error {{expected member name following '.'}}
}
42