Files
swift-mirror/lib/Sema/BuilderTransform.cpp
Pavel Yaskevich 9dc2dfde98 [ConstraintSystem] Allow single-statement closures be used with function builders
While resolving a closure, use contextual locator information to
determine whether the given contextual type comes from a "function builder"
parameter and, if so, use the special `applyFunctionBuilder` to open the
closure body instead of regular constraint generation.
2020-01-14 00:09:32 -08:00

800 lines
28 KiB
C++

//===--- BuilderTransform.cpp - Function-builder transformation -----------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2018 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements routines associated with the function-builder
// transformation.
//
//===----------------------------------------------------------------------===//
#include "ConstraintSystem.h"
#include "SolutionResult.h"
#include "TypeChecker.h"
#include "swift/AST/ASTVisitor.h"
#include "swift/AST/ASTWalker.h"
#include "swift/AST/NameLookup.h"
#include "swift/AST/NameLookupRequests.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/TypeCheckRequests.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <iterator>
#include <map>
#include <memory>
#include <utility>
#include <tuple>
using namespace swift;
using namespace constraints;
namespace {
/// Visitor to classify the contents of the given closure.
///
/// The visitor walks the statements of a function-builder body. Constructs it
/// cannot handle are recorded in \c unhandledNode (only the first one is
/// kept). When \c wantExpr is true it additionally builds the expression
/// formed from calls to the builder's static methods (`buildBlock`,
/// `buildExpression`, `buildDo`, `buildIf`, `buildEither`); when it is false
/// every visit returns null and the visitor is used for classification only.
class BuilderClosureVisitor
    : public StmtVisitor<BuilderClosureVisitor, Expr *> {
  /// Constraint system in which generated type expressions are given types;
  /// may be null if the builder type contains no type variables.
  ConstraintSystem *cs;
  ASTContext &ctx;
  /// Whether builder-call expressions should actually be produced.
  bool wantExpr;
  Type builderType;
  /// The nominal type declaration of \c builderType.
  NominalTypeDecl *builder = nullptr;
  /// Cache of `builderSupports` answers for queries made *without* argument
  /// labels, keyed by the function's base name.
  llvm::SmallDenseMap<Identifier, bool> supportedOps;

public:
  /// The first statement or declaration the builder cannot handle, if any.
  SkipUnhandledConstructInFunctionBuilder::UnhandledNode unhandledNode;

private:
  /// Produce a builder call to the given named function with the given
  /// arguments.
  ///
  /// \param loc location used for the implicit type and member references.
  /// \param fnName name of the static builder function to call.
  /// \param args the call arguments.
  /// \param argLabels labels for the arguments; may be empty.
  /// \param oneWay whether to wrap the call in a \c OneWayExpr.
  ///
  /// \returns the call expression, or null if \c wantExpr is false.
  Expr *buildCallIfWanted(SourceLoc loc,
                          Identifier fnName, ArrayRef<Expr *> args,
                          ArrayRef<Identifier> argLabels,
                          bool oneWay) {
    if (!wantExpr)
      return nullptr;

    // FIXME: Setting a TypeLoc on this expression is necessary in order
    // to get diagnostics if something about this builder call fails,
    // e.g. if there isn't a matching overload for `buildBlock`.
    // But we can only do this if there isn't a type variable in the type.
    TypeLoc typeLoc;
    if (!builderType->hasTypeVariable()) {
      typeLoc = TypeLoc(new (ctx) FixedTypeRepr(builderType, loc), builderType);
    }

    auto typeExpr = new (ctx) TypeExpr(typeLoc);
    if (cs) {
      cs->setType(typeExpr, MetatypeType::get(builderType));
      cs->setType(&typeExpr->getTypeLoc(), builderType);
    }

    // Each label is located at the start of its corresponding argument.
    SmallVector<SourceLoc, 4> argLabelLocs;
    for (auto i : indices(argLabels)) {
      argLabelLocs.push_back(args[i]->getStartLoc());
    }

    typeExpr->setImplicit();
    auto memberRef = new (ctx) UnresolvedDotExpr(
        typeExpr, loc, DeclNameRef(fnName), DeclNameLoc(loc),
        /*implicit=*/true);
    SourceLoc openLoc = args.empty() ? loc : args.front()->getStartLoc();
    SourceLoc closeLoc = args.empty() ? loc : args.back()->getEndLoc();
    Expr *result = CallExpr::create(ctx, memberRef, openLoc, args,
                                    argLabels, argLabelLocs, closeLoc,
                                    /*trailing closure*/ nullptr,
                                    /*implicit*/true);

    if (oneWay) {
      // Form a one-way constraint to prevent backward propagation.
      result = new (ctx) OneWayExpr(result);
    }

    return result;
  }

  /// Check whether the builder supports the given operation.
  ///
  /// The builder supports an operation when it directly declares a static
  /// function named \p fnName whose leading argument labels match
  /// \p argLabels (when any are provided).
  bool builderSupports(Identifier fnName,
                       ArrayRef<Identifier> argLabels = {}) {
    // The cache is keyed on the base name alone, so it can only answer
    // queries that place no requirement on the argument labels. Consulting
    // it for labelled queries would make `buildEither(second:)` appear to
    // exist whenever `buildEither(first:)` had been found earlier.
    const bool canUseCache = argLabels.empty();
    if (canUseCache) {
      auto known = supportedOps.find(fnName);
      if (known != supportedOps.end()) {
        return known->second;
      }
    }

    bool found = false;
    for (auto decl : builder->lookupDirect(fnName)) {
      if (auto func = dyn_cast<FuncDecl>(decl)) {
        // Function must be static.
        if (!func->isStatic())
          continue;

        // Function must have the right argument labels, if provided.
        if (!argLabels.empty()) {
          auto funcLabels = func->getFullName().getArgumentNames();
          if (argLabels.size() > funcLabels.size() ||
              funcLabels.slice(0, argLabels.size()) != argLabels)
            continue;
        }

        // Okay, it's a good-enough match.
        found = true;
        break;
      }
    }

    if (canUseCache)
      supportedOps[fnName] = found;
    return found;
  }

public:
  /// \param ctx the AST context used to allocate generated expressions.
  /// \param cs constraint system used to record types of generated type
  /// expressions; required when \p builderType contains type variables.
  /// \param wantExpr whether to build the transformed expression or only
  /// classify the body.
  /// \param builderType the function-builder type.
  BuilderClosureVisitor(ASTContext &ctx, ConstraintSystem *cs,
                        bool wantExpr, Type builderType)
      : cs(cs), ctx(ctx), wantExpr(wantExpr), builderType(builderType) {
    assert((cs || !builderType->hasTypeVariable()) &&
           "cannot handle builder type with type variables without "
           "constraint system");
    builder = builderType->getAnyNominal();
  }

  /// Define a visitor for a statement kind that the transform never
  /// handles: record it as the first unhandled node (if none has been
  /// recorded yet) and produce no expression.
#define CONTROL_FLOW_STMT(StmtClass)                       \
  Expr *visit##StmtClass##Stmt(StmtClass##Stmt *stmt) {    \
    if (!unhandledNode)                                    \
      unhandledNode = stmt;                                \
                                                           \
    return nullptr;                                        \
  }

  /// Visit each element of the brace statement and combine the resulting
  /// expressions into a call to `buildBlock`.
  Expr *visitBraceStmt(BraceStmt *braceStmt) {
    SmallVector<Expr *, 4> expressions;
    for (const auto &node : braceStmt->getElements()) {
      if (auto stmt = node.dyn_cast<Stmt *>()) {
        auto expr = visit(stmt);
        if (expr)
          expressions.push_back(expr);
        continue;
      }

      if (auto decl = node.dyn_cast<Decl *>()) {
        // Just ignore #if; the chosen children should appear in the
        // surrounding context.  This isn't good for source tools but it
        // at least works.
        if (isa<IfConfigDecl>(decl))
          continue;

        // Emit #warning/#error but don't build anything for it.
        if (auto poundDiag = dyn_cast<PoundDiagnosticDecl>(decl)) {
          TypeChecker::typeCheckDecl(poundDiag);
          continue;
        }

        // Any other declaration is something the builder cannot handle.
        if (!unhandledNode)
          unhandledNode = decl;

        continue;
      }

      auto expr = node.get<Expr *>();
      if (wantExpr) {
        // Route the expression through `buildExpression` when the builder
        // declares it.
        if (builderSupports(ctx.Id_buildExpression)) {
          expr = buildCallIfWanted(expr->getLoc(), ctx.Id_buildExpression,
                                   { expr }, { Identifier() },
                                   /*oneWay=*/false);
        }

        expr = new (ctx) OneWayExpr(expr);
      }

      expressions.push_back(expr);
    }

    // Call Builder.buildBlock(... args ...)
    return buildCallIfWanted(braceStmt->getStartLoc(),
                             ctx.Id_buildBlock, expressions,
                             /*argLabels=*/{ },
                             /*oneWay=*/true);
  }

  /// Only implicit returns that carry a result are accepted; their result
  /// expression is returned as-is. Anything else is an unhandled node.
  Expr *visitReturnStmt(ReturnStmt *stmt) {
    // Allow implicit returns due to 'return' elision.
    if (!stmt->isImplicit() || !stmt->hasResult()) {
      if (!unhandledNode)
        unhandledNode = stmt;
      return nullptr;
    }

    return stmt->getResult();
  }

  /// Transform a `do` statement into a call to `buildDo`, provided the
  /// builder declares it; otherwise the statement is unhandled.
  Expr *visitDoStmt(DoStmt *doStmt) {
    if (!builderSupports(ctx.Id_buildDo)) {
      if (!unhandledNode)
        unhandledNode = doStmt;
      return nullptr;
    }

    auto arg = visit(doStmt->getBody());
    if (!arg)
      return nullptr;

    return buildCallIfWanted(doStmt->getStartLoc(), ctx.Id_buildDo, arg,
                             /*argLabels=*/{ }, /*oneWay=*/true);
  }

  CONTROL_FLOW_STMT(Yield)
  CONTROL_FLOW_STMT(Defer)

  /// \returns the boolean expression of a condition that consists of exactly
  /// one element, or null if the condition has any other form.
  static Expr *getTrivialBooleanCondition(StmtCondition condition) {
    if (condition.size() != 1)
      return nullptr;

    return condition.front().getBooleanOrNull();
  }

  /// Walk an if/else-if/else chain, counting the payloads it contributes in
  /// \p numPayloads and setting \p isOptional if the chain ends without an
  /// `else`.
  ///
  /// \returns false if any condition in the chain is not trivial.
  static bool isBuildableIfChainRecursive(IfStmt *ifStmt,
                                          unsigned &numPayloads,
                                          bool &isOptional) {
    // The conditional must be trivial.
    if (!getTrivialBooleanCondition(ifStmt->getCond()))
      return false;

    // The 'then' clause contributes a payload.
    numPayloads++;

    // If there's an 'else' clause, it contributes payloads:
    if (auto elseStmt = ifStmt->getElseStmt()) {
      // If it's 'else if', it contributes payloads recursively.
      if (auto elseIfStmt = dyn_cast<IfStmt>(elseStmt)) {
        return isBuildableIfChainRecursive(elseIfStmt, numPayloads,
                                           isOptional);
      // Otherwise it's just the one.
      } else {
        numPayloads++;
      }

    // If not, the chain result is at least optional.
    } else {
      isOptional = true;
    }

    return true;
  }

  /// Determine whether the if-chain rooted at \p ifStmt has a supported
  /// shape and whether the builder declares the operations needed for it.
  bool isBuildableIfChain(IfStmt *ifStmt, unsigned &numPayloads,
                          bool &isOptional) {
    if (!isBuildableIfChainRecursive(ifStmt, numPayloads, isOptional))
      return false;

    // If there's a missing 'else', we need 'buildIf' to exist.
    if (isOptional && !builderSupports(ctx.Id_buildIf))
      return false;

    // If there are multiple clauses, we need 'buildEither(first:)' and
    // 'buildEither(second:)' to both exist.
    if (numPayloads > 1) {
      if (!builderSupports(ctx.Id_buildEither, {ctx.Id_first}) ||
          !builderSupports(ctx.Id_buildEither, {ctx.Id_second}))
        return false;
    }

    return true;
  }

  /// Transform an if-chain into a conditional expression whose branches are
  /// wrapped with `buildEither` as needed, and with `buildIf` applied at the
  /// top level when the chain has no final `else`.
  Expr *visitIfStmt(IfStmt *ifStmt) {
    // Check whether the chain is buildable and whether it terminates
    // without an `else`.
    bool isOptional = false;
    unsigned numPayloads = 0;
    if (!isBuildableIfChain(ifStmt, numPayloads, isOptional)) {
      if (!unhandledNode)
        unhandledNode = ifStmt;
      return nullptr;
    }

    // Attempt to build the chain, propagating short-circuits, which
    // might arise either due to error or not wanting an expression.
    auto chainExpr =
      buildIfChainRecursive(ifStmt, 0, numPayloads, isOptional);
    if (!chainExpr)
      return nullptr;
    assert(wantExpr);

    // The operand should have optional type if we had optional results,
    // so we just need to call `buildIf` now, since we're at the top level.
    if (isOptional) {
      chainExpr = buildCallIfWanted(ifStmt->getStartLoc(),
                                    ctx.Id_buildIf, chainExpr,
                                    /*argLabels=*/{ },
                                    /*oneWay=*/true);
    } else {
      // Form a one-way constraint to prevent backward propagation.
      chainExpr = new (ctx) OneWayExpr(chainExpr);
    }

    return chainExpr;
  }

  /// Recursively build an if-chain: build an expression which will have
  /// a value of the chain result type before any call to `buildIf`.
  /// The expression will perform any necessary calls to `buildEither`,
  /// and the result will have optional type if `isOptional` is true.
  Expr *buildIfChainRecursive(IfStmt *ifStmt, unsigned payloadIndex,
                              unsigned numPayloads, bool isOptional) {
    assert(payloadIndex < numPayloads);
    // Make sure we recursively visit both sides even if we're not
    // building expressions.

    // Build the then clause.  This will have the corresponding payload
    // type (i.e. not wrapped in any way).
    Expr *thenArg = visit(ifStmt->getThenStmt());

    // Build the else clause, if present.  If this is from an else-if,
    // this will be fully wrapped; otherwise it will have the corresponding
    // payload type (at index `payloadIndex + 1`).
    assert(ifStmt->getElseStmt() || isOptional);
    bool isElseIf = false;
    Optional<Expr *> elseChain;
    if (auto elseStmt = ifStmt->getElseStmt()) {
      if (auto elseIfStmt = dyn_cast<IfStmt>(elseStmt)) {
        isElseIf = true;
        elseChain = buildIfChainRecursive(elseIfStmt, payloadIndex + 1,
                                          numPayloads, isOptional);
      } else {
        elseChain = visit(elseStmt);
      }
    }

    // Short-circuit if appropriate.
    if (!wantExpr || !thenArg || (elseChain && !*elseChain))
      return nullptr;

    // Okay, build the conditional expression.

    // Prepare the `then` operand by wrapping it to produce a chain result.
    SourceLoc thenLoc = ifStmt->getThenStmt()->getStartLoc();
    Expr *thenExpr = buildWrappedChainPayload(thenArg, payloadIndex,
                                              numPayloads, isOptional);

    // Prepare the `else` operand:
    Expr *elseExpr;
    SourceLoc elseLoc;

    // - If there's no `else` clause, use `Optional.none`.
    if (!elseChain) {
      assert(isOptional);
      elseLoc = ifStmt->getEndLoc();
      elseExpr = buildNoneExpr(elseLoc);

    // - If there's an `else if`, the chain expression from that
    //   should already be producing a chain result.
    } else if (isElseIf) {
      elseExpr = *elseChain;
      elseLoc = ifStmt->getElseLoc();

    // - Otherwise, wrap it to produce a chain result.
    } else {
      elseLoc = ifStmt->getElseLoc();
      elseExpr = buildWrappedChainPayload(*elseChain,
                                          payloadIndex + 1, numPayloads,
                                          isOptional);
    }

    Expr *condition = getTrivialBooleanCondition(ifStmt->getCond());
    assert(condition && "checked by isBuildableIfChain");

    auto ifExpr = new (ctx) IfExpr(condition, thenLoc, thenExpr,
                                   elseLoc, elseExpr);
    ifExpr->setImplicit();
    return ifExpr;
  }

  /// Wrap a payload value in an expression which will produce a chain
  /// result (without `buildIf`).
  Expr *buildWrappedChainPayload(Expr *operand, unsigned payloadIndex,
                                 unsigned numPayloads, bool isOptional) {
    assert(payloadIndex < numPayloads);

    // Inject into the appropriate chain position.
    //
    // We produce a (left-biased) balanced binary tree of Eithers in order
    // to prevent requiring a linear number of injections in the worst case.
    // That is, if we have 13 clauses, we want to produce:
    //
    //                      /------------------Either------------\
    //           /-------Either-------\                     /--Either--\
    //     /--Either--\          /--Either--\          /--Either--\     \
    //   /-E-\      /-E-\      /-E-\      /-E-\      /-E-\      /-E-\    \
    // 0000 0001  0010 0011  0100 0101  0110 0111  1000 1001  1010 1011 1100
    //
    // Note that a prefix of length D of the payload index acts as a path
    // through the tree to the node at depth D.  On the rightmost path
    // through the tree (when this prefix is equal to the corresponding
    // prefix of the maximum payload index), the bits of the index mark
    // where Eithers are required.
    //
    // Since we naturally want to build from the innermost Either out, and
    // therefore work with progressively shorter prefixes, we can do it all
    // with right-shifts.
    for (auto path = payloadIndex, maxPath = numPayloads - 1;
         maxPath != 0; path >>= 1, maxPath >>= 1) {
      // Skip making Eithers on the rightmost path where they aren't required.
      // This isn't just an optimization: adding spurious Eithers could
      // leave us with unresolvable type variables if `buildEither` has
      // a signature like:
      //    static func buildEither<T,U>(first value: T) -> Either<T,U>
      // which relies on unification to work.
      if (path == maxPath && !(maxPath & 1)) continue;

      bool isSecond = (path & 1);
      operand = buildCallIfWanted(operand->getStartLoc(),
                                  ctx.Id_buildEither, operand,
                                  {isSecond ? ctx.Id_second : ctx.Id_first},
                                  /*oneWay=*/false);
    }

    // Inject into Optional if required.  We'll be adding the call to
    // `buildIf` after all the recursive calls are complete.
    if (isOptional) {
      operand = buildSomeExpr(operand);
    }

    return operand;
  }

  /// Build an implicit `Optional.some(arg)` call expression.
  Expr *buildSomeExpr(Expr *arg) {
    auto optionalDecl = ctx.getOptionalDecl();
    auto optionalType = optionalDecl->getDeclaredType();

    auto loc = arg->getStartLoc();
    auto optionalTypeExpr =
      TypeExpr::createImplicitHack(loc, optionalType, ctx);
    auto someRef = new (ctx) UnresolvedDotExpr(
        optionalTypeExpr, loc, DeclNameRef(ctx.getIdentifier("some")),
        DeclNameLoc(loc), /*implicit=*/true);
    return CallExpr::createImplicit(ctx, someRef, arg, { });
  }

  /// Build an implicit `Optional.none` member reference located at
  /// \p endLoc.
  Expr *buildNoneExpr(SourceLoc endLoc) {
    auto optionalDecl = ctx.getOptionalDecl();
    auto optionalType = optionalDecl->getDeclaredType();

    auto optionalTypeExpr =
      TypeExpr::createImplicitHack(endLoc, optionalType, ctx);
    return new (ctx) UnresolvedDotExpr(
        optionalTypeExpr, endLoc, DeclNameRef(ctx.getIdentifier("none")),
        DeclNameLoc(endLoc), /*implicit=*/true);
  }

  CONTROL_FLOW_STMT(Guard)
  CONTROL_FLOW_STMT(While)
  CONTROL_FLOW_STMT(DoCatch)
  CONTROL_FLOW_STMT(RepeatWhile)
  CONTROL_FLOW_STMT(ForEach)
  CONTROL_FLOW_STMT(Switch)
  CONTROL_FLOW_STMT(Case)
  CONTROL_FLOW_STMT(Catch)
  CONTROL_FLOW_STMT(Break)
  CONTROL_FLOW_STMT(Continue)
  CONTROL_FLOW_STMT(Fallthrough)
  CONTROL_FLOW_STMT(Fail)
  CONTROL_FLOW_STMT(Throw)
  CONTROL_FLOW_STMT(PoundAssert)

#undef CONTROL_FLOW_STMT
};
} // end anonymous namespace
/// Collect the explicit 'return' statements in the body of \p fn. The
/// presence of such statements blocks the application of a function builder.
///
/// Declared ahead of its definition, which appears later in this file.
static std::vector<ReturnStmt *> findReturnStatements(AnyFunctionRef fn);
/// Apply the function-builder transform for \p builderType to the body of
/// \p func, type-checking the transformed body in its own constraint system.
///
/// \returns
///   - \c None if the body contains explicit 'return' statements; the
///     transform is then not applied and diagnostics with fix-its are
///     emitted instead.
///   - a null statement if pre-checking failed, the constraint system could
///     not be solved or salvaged, or applying the solution yielded no body.
///   - otherwise, the body produced by applying the solution.
Optional<BraceStmt *> TypeChecker::applyFunctionBuilderBodyTransform(
FuncDecl *func, Type builderType) {
// Pre-check the body: pre-check any expressions in it and look
// for return statements.
//
// If we encountered an error or there was an explicit 'return' statement,
// bail out and report that to the caller.
auto &ctx = func->getASTContext();
auto request = PreCheckFunctionBuilderRequest{func};
// 'Error' is supplied as the default result for the request.
switch (evaluateOrDefault(
ctx.evaluator, request, FunctionBuilderClosurePreCheck::Error)) {
case FunctionBuilderClosurePreCheck::Okay:
// If the pre-check was okay, apply the function-builder transform.
break;
case FunctionBuilderClosurePreCheck::Error:
return nullptr;
case FunctionBuilderClosurePreCheck::HasReturnStmt: {
// One or more explicit 'return' statements were encountered, which
// disables the function builder transform. Warn when we do this.
auto returnStmts = findReturnStatements(func);
assert(!returnStmts.empty());
ctx.Diags.diagnose(
returnStmts.front()->getReturnLoc(),
diag::function_builder_disabled_by_return, builderType);
// Note that one can remove the function builder attribute.
// The attribute may be attached to the function itself or, for an
// accessor, to the storage declaration it belongs to.
auto attr = func->getAttachedFunctionBuilder();
if (!attr) {
if (auto accessor = dyn_cast<AccessorDecl>(func)) {
attr = accessor->getStorage()->getAttachedFunctionBuilder();
}
}
if (attr) {
ctx.Diags.diagnose(
attr->getLocation(), diag::function_builder_remove_attr)
.fixItRemove(attr->getRangeWithAt());
attr->setInvalid();
}
// Note that one can remove all of the return statements.
// The diagnostic lives in its own scope while one fix-it per return
// statement is attached to it.
{
auto diag = ctx.Diags.diagnose(
returnStmts.front()->getReturnLoc(),
diag::function_builder_remove_returns);
for (auto returnStmt : returnStmts) {
diag.fixItRemove(returnStmt->getReturnLoc());
}
}
return None;
}
}
ConstraintSystemOptions options = ConstraintSystemFlags::AllowFixes;
auto resultInterfaceTy = func->getResultInterfaceType();
auto resultContextType = func->mapTypeIntoContext(resultInterfaceTy);
// Determine whether we're inferring the underlying type for the opaque
// result type of this function.
ConstraintKind resultConstraintKind = ConstraintKind::Conversion;
if (auto opaque = resultContextType->getAs<OpaqueTypeArchetypeType>()) {
if (opaque->getDecl()->isOpaqueReturnTypeOfFunction(func)) {
options |= ConstraintSystemFlags::UnderlyingTypeForOpaqueReturnType;
resultConstraintKind = ConstraintKind::OpaqueUnderlyingType;
}
}
// Build a constraint system in which we can check the body of the function.
ConstraintSystem cs(func, options);
// FIXME: check the result
// Note: no callee locator is available here, so a null one is passed.
cs.matchFunctionBuilder(func, builderType, resultContextType,
resultConstraintKind,
/*calleeLocator=*/nullptr,
/*FIXME:*/ConstraintLocatorBuilder(nullptr));
// Solve the constraint system.
SmallVector<Solution, 4> solutions;
// Salvage when solve() returns true or we did not end up with exactly one
// solution.
if (cs.solve(solutions) || solutions.size() != 1) {
// Try to fix the system or provide a decent diagnostic.
auto salvagedResult = cs.salvage();
switch (salvagedResult.getKind()) {
case SolutionResult::Kind::Success:
// Replace whatever solve() produced with the salvaged solution.
solutions.clear();
solutions.push_back(std::move(salvagedResult).takeSolution());
break;
case SolutionResult::Kind::Error:
case SolutionResult::Kind::Ambiguous:
return nullptr;
case SolutionResult::Kind::UndiagnosedError:
cs.diagnoseFailureFor(SolutionApplicationTarget(func));
salvagedResult.markAsDiagnosed();
return nullptr;
case SolutionResult::Kind::TooComplex:
func->diagnose(diag::expression_too_complex)
.highlight(func->getBodySourceRange());
salvagedResult.markAsDiagnosed();
return nullptr;
}
// The system was salvaged; continue on as if nothing happened.
}
// Apply the solution to the function body.
return cast_or_null<BraceStmt>(
cs.applySolutionToBody(solutions.front(), func));
}
/// Attempt to apply the function-builder transform for \p builderType to the
/// body of \p fn within this constraint system.
///
/// On success, the generated expression is recorded in
/// \c functionBuilderTransformed and a constraint of kind
/// \p bodyResultConstraintKind is added between the type of the transformed
/// expression and \p bodyResultType.
///
/// \param calleeLocator locator used to find the opened types that replace
/// any type parameters in \p builderType.
/// \param locator locator used for reported failures, the recorded fix, and
/// the body-result constraint.
///
/// \returns a failure if pre-checking fails, if an unhandled construct cannot
/// be recorded as a fix, or if pre-checking or constraint generation of the
/// generated expression fails; success otherwise. Success is also returned,
/// without transforming anything, when the body has a 'return' statement.
ConstraintSystem::TypeMatchResult ConstraintSystem::matchFunctionBuilder(
AnyFunctionRef fn, Type builderType, Type bodyResultType,
ConstraintKind bodyResultConstraintKind,
ConstraintLocator *calleeLocator, ConstraintLocatorBuilder locator) {
auto builder = builderType->getAnyNominal();
assert(builder && "Bad function builder type");
assert(builder->getAttrs().hasAttribute<FunctionBuilderAttr>());
// Pre-check the body: pre-check any expressions in it and look
// for return statements.
auto request = PreCheckFunctionBuilderRequest{fn};
switch (evaluateOrDefault(getASTContext().evaluator, request,
FunctionBuilderClosurePreCheck::Error)) {
case FunctionBuilderClosurePreCheck::Okay:
// If the pre-check was okay, apply the function-builder transform.
break;
case FunctionBuilderClosurePreCheck::Error:
// If the pre-check had an error, flag that.
return getTypeMatchFailure(locator);
case FunctionBuilderClosurePreCheck::HasReturnStmt:
// If the closure has a return statement, suppress the transform but
// continue solving the constraint system.
return getTypeMatchSuccess();
}
// Check the form of this closure to see if we can apply the
// function-builder translation at all.
{
// Check whether we can apply this specific function builder.
// This first pass runs with wantExpr=false: it only classifies the body
// and builds no expressions.
BuilderClosureVisitor visitor(getASTContext(), this,
/*wantExpr=*/false, builderType);
(void)visitor.visit(fn.getBody());
// If we saw a control-flow statement or declaration that the builder
// cannot handle, we don't have a well-formed function builder application.
if (visitor.unhandledNode) {
// If we aren't supposed to attempt fixes, fail.
if (!shouldAttemptFixes()) {
return getTypeMatchFailure(locator);
}
// Record the first unhandled construct as a fix.
if (recordFix(
SkipUnhandledConstructInFunctionBuilder::create(
*this, visitor.unhandledNode, builder,
getConstraintLocator(locator)))) {
return getTypeMatchFailure(locator);
}
}
}
// If the builder type has a type parameter, substitute in the type
// variables.
if (builderType->hasTypeParameter()) {
// Find the opened type for this callee and substitute in the type
// parameters.
for (const auto &opened : OpenedTypes) {
if (opened.first == calleeLocator) {
OpenedTypeMap replacements(opened.second.begin(),
opened.second.end());
builderType = openType(builderType, replacements);
break;
}
}
assert(!builderType->hasTypeParameter());
}
// Second pass, with wantExpr=true: build the expression made of builder
// calls for the whole body.
BuilderClosureVisitor visitor(getASTContext(), this,
/*wantExpr=*/true, builderType);
Expr *singleExpr = visitor.visit(fn.getBody());
// We've already pre-checked all the original expressions, but do the
// pre-check to the generated expression just to set up any preconditions
// that CSGen might have.
//
// TODO: just build the AST the way we want it in the first place.
auto dc = fn.getAsDeclContext();
if (ConstraintSystem::preCheckExpression(singleExpr, dc))
return getTypeMatchFailure(locator);
singleExpr = generateConstraints(singleExpr, dc);
if (!singleExpr)
return getTypeMatchFailure(locator);
Type transformedType = getType(singleExpr);
assert(transformedType && "Missing type");
// Record the transformation.
// Each function/closure may be transformed at most once along a path.
assert(std::find_if(
functionBuilderTransformed.begin(),
functionBuilderTransformed.end(),
[&](const std::pair<AnyFunctionRef, AppliedBuilderTransform> &elt) {
return elt.first == fn;
}) == functionBuilderTransformed.end() &&
"already transformed this closure along this path!?!");
functionBuilderTransformed.push_back(
std::make_pair(
fn,
AppliedBuilderTransform{builderType, singleExpr, bodyResultType}));
// Bind the body result type to the type of the transformed expression.
addConstraint(bodyResultConstraintKind, transformedType, bodyResultType,
locator);
return getTypeMatchSuccess();
}
namespace {
/// Pre-check all the expressions in the closure body.
///
/// The walker visits the statements of a function or closure body,
/// pre-checking each expression it reaches (unless told to skip that step)
/// and collecting every explicit 'return' statement it encounters.
class PreCheckFunctionBuilderApplication : public ASTWalker {
  /// The function or closure whose body is walked.
  AnyFunctionRef Fn;
  /// When true, expressions are left untouched and only 'return' statements
  /// are collected.
  bool SkipPrecheck = false;
  /// Explicit 'return' statements found so far, in walk order.
  std::vector<ReturnStmt *> ReturnStmts;
  /// Set when pre-checking an expression failed.
  bool HasError = false;

  bool hasReturnStmt() const { return !ReturnStmts.empty(); }

public:
  PreCheckFunctionBuilderApplication(AnyFunctionRef fn, bool skipPrecheck)
    : Fn(fn), SkipPrecheck(skipPrecheck) {}

  /// \returns the explicit 'return' statements collected by \c run().
  ///
  /// Returned by reference so that callers decide whether a copy is needed;
  /// the reference is valid for as long as this walker is alive.
  const std::vector<ReturnStmt *> &getReturnStmts() const {
    return ReturnStmts;
  }

  /// Walk the body and report the outcome: \c Error if pre-checking an
  /// expression failed, \c HasReturnStmt if an explicit 'return' was found,
  /// and \c Okay otherwise.
  FunctionBuilderClosurePreCheck run() {
    Stmt *oldBody = Fn.getBody();

    Stmt *newBody = oldBody->walk(*this);
    // 'newBody' is otherwise only read by the assertions below.
    (void)newBody;

    // If the walk was aborted, it was because we had a problem of some kind.
    assert((newBody == nullptr) == HasError &&
           "unexpected short-circuit while walking body");
    if (HasError)
      return FunctionBuilderClosurePreCheck::Error;

    if (hasReturnStmt())
      return FunctionBuilderClosurePreCheck::HasReturnStmt;

    assert(oldBody == newBody && "pre-check walk wasn't in-place?");

    return FunctionBuilderClosurePreCheck::Okay;
  }

  std::pair<bool, Expr *> walkToExprPre(Expr *E) override {
    // Pre-check the expression.  If this fails, abort the walk immediately.
    // Otherwise, replace the expression with the result of pre-checking.
    // In either case, don't recurse into the expression.
    if (!SkipPrecheck &&
        ConstraintSystem::preCheckExpression(E, /*DC*/ Fn.getAsDeclContext())) {
      HasError = true;
      return std::make_pair(false, nullptr);
    }

    return std::make_pair(false, E);
  }

  std::pair<bool, Stmt *> walkToStmtPre(Stmt *S) override {
    // If we see an explicit return statement, note it and don't recurse
    // into it.
    if (auto returnStmt = dyn_cast<ReturnStmt>(S)) {
      if (!returnStmt->isImplicit()) {
        ReturnStmts.push_back(returnStmt);
        return std::make_pair(false, S);
      }
    }

    // Otherwise, recurse into the statement normally.
    return std::make_pair(true, S);
  }
};
}
/// Evaluate the pre-check request for \p fn: either short-circuit for
/// single-expression closures or walk the body with the pre-check walker.
llvm::Expected<FunctionBuilderClosurePreCheck>
PreCheckFunctionBuilderRequest::evaluate(Evaluator &eval,
                                         AnyFunctionRef fn) const {
  // Single-expression closures should already have been pre-checked, so
  // there is nothing left to do for them.
  auto *closureExpr = fn.getAbstractClosureExpr();
  const bool alreadyPreChecked =
      closureExpr && closureExpr->hasSingleExpressionBody();
  if (alreadyPreChecked)
    return FunctionBuilderClosurePreCheck::Okay;

  // Everything else gets the full walk, including expression pre-checking.
  PreCheckFunctionBuilderApplication walker(fn, /*skipPrecheck=*/false);
  return walker.run();
}
/// Collect the explicit 'return' statements in the body of \p fn.
///
/// The body is walked with expression pre-checking disabled; the walk's own
/// result is deliberately ignored because only the statements are of
/// interest.
std::vector<ReturnStmt *> findReturnStatements(AnyFunctionRef fn) {
  PreCheckFunctionBuilderApplication walker(fn, /*skipPrecheck=*/true);
  (void)walker.run();

  std::vector<ReturnStmt *> found = walker.getReturnStmts();
  return found;
}