Introduce if/switch expressions

Introduce SingleValueStmtExpr, which allows the
embedding of a statement in an expression context.
This then allows us to parse and type-check `if`
and `switch` statements as expressions, gated
behind the `IfSwitchExpression` experimental
feature for now. In the future,
SingleValueStmtExpr could also be used for e.g
`do` expressions.

For now, only single expression branches are
supported for producing a value from an
`if`/`switch` expression, and each branch is
type-checked independently. A multi-statement
branch may only appear if it ends with a `throw`,
and it may not `break`, `continue`, or `return`.

The placement of `if`/`switch` expressions is also
currently limited by a syntactic use diagnostic.
Currently they're only allowed in bindings,
assignments, throws, and returns. But this could
be lifted in the future if desired.
This commit is contained in:
Hamish Knight
2023-02-01 15:30:18 +00:00
parent df2b3b2880
commit a40f1abaff
70 changed files with 5520 additions and 188 deletions

View File

@@ -49,9 +49,11 @@ class TypeVariableRefFinder : public ASTWalker {
public:
TypeVariableRefFinder(
ConstraintSystem &cs, ASTNode parent,
ConstraintSystem &cs, ASTNode parent, ContextualTypeInfo context,
llvm::SmallPtrSetImpl<TypeVariableType *> &referencedVars)
: CS(cs), Parent(parent), ReferencedVars(referencedVars) {
if (auto ty = context.getType())
inferVariables(ty);
if (auto *closure = getAsExpr<ClosureExpr>(Parent))
ClosureDCs.push_back(closure);
}
@@ -331,6 +333,15 @@ static void createConjunction(ConstraintSystem &cs,
isIsolated = true;
}
if (locator->isForSingleValueStmtConjunction()) {
auto *SVE = castToExpr<SingleValueStmtExpr>(locator->getAnchor());
referencedVars.push_back(cs.getType(SVE)->castTo<TypeVariableType>());
// Single value statement conjunctions are always isolated, as we want to
// solve the branches independently of the rest of the system.
isIsolated = true;
}
UnresolvedClosureParameterCollector paramCollector(cs);
for (const auto &entry : elements) {
@@ -372,11 +383,23 @@ ElementInfo makeElement(ASTNode node, ConstraintLocator *locator,
return std::make_tuple(node, context, isDiscarded, locator);
}
ElementInfo makeJoinElement(ConstraintSystem &cs, TypeJoinExpr *join,
ConstraintLocator *locator) {
return makeElement(
join, cs.getConstraintLocator(locator,
{LocatorPathElt::SyntacticElement(join)}));
}
struct SyntacticElementContext
: public llvm::PointerUnion<AbstractFunctionDecl *, AbstractClosureExpr *> {
: public llvm::PointerUnion<AbstractFunctionDecl *, AbstractClosureExpr *,
SingleValueStmtExpr *> {
// Inherit the constructors from PointerUnion.
using PointerUnion::PointerUnion;
/// A join that should be applied to the elements of a SingleValueStmtExpr.
NullablePtr<TypeJoinExpr> ElementJoin;
static SyntacticElementContext forFunctionRef(AnyFunctionRef ref) {
if (auto *decl = ref.getAbstractFunctionDecl()) {
return {decl};
@@ -393,11 +416,21 @@ struct SyntacticElementContext
return {func};
}
static SyntacticElementContext
forSingleValueStmtExpr(SingleValueStmtExpr *SVE,
TypeJoinExpr *Join = nullptr) {
auto context = SyntacticElementContext{SVE};
context.ElementJoin = Join;
return context;
}
DeclContext *getAsDeclContext() const {
if (auto *fn = this->dyn_cast<AbstractFunctionDecl *>()) {
return fn;
} else if (auto *closure = this->dyn_cast<AbstractClosureExpr *>()) {
return closure;
} else if (auto *SVE = dyn_cast<SingleValueStmtExpr *>()) {
return SVE->getDeclContext();
} else {
llvm_unreachable("unsupported kind");
}
@@ -426,11 +459,13 @@ struct SyntacticElementContext
}
}
BraceStmt *getBody() const {
Stmt *getStmt() const {
if (auto *fn = this->dyn_cast<AbstractFunctionDecl *>()) {
return fn->getBody();
} else if (auto *closure = this->dyn_cast<AbstractClosureExpr *>()) {
return closure->getBody();
} else if (auto *SVE = dyn_cast<SingleValueStmtExpr *>()) {
return SVE->getStmt();
} else {
llvm_unreachable("unsupported kind");
}
@@ -734,17 +769,12 @@ private:
// Other declarations will be handled at application time.
}
void visitBreakStmt(BreakStmt *breakStmt) {
}
void visitContinueStmt(ContinueStmt *continueStmt) {
}
void visitDeferStmt(DeferStmt *deferStmt) {
}
void visitFallthroughStmt(FallthroughStmt *fallthroughStmt) {
}
// These statements don't require any type-checking.
void visitBreakStmt(BreakStmt *breakStmt) {}
void visitContinueStmt(ContinueStmt *continueStmt) {}
void visitDeferStmt(DeferStmt *deferStmt) {}
void visitFallthroughStmt(FallthroughStmt *fallthroughStmt) {}
void visitFailStmt(FailStmt *fail) {}
void visitStmtCondition(LabeledConditionalStmt *S,
SmallVectorImpl<ElementInfo> &elements,
@@ -775,6 +805,10 @@ private:
elements.push_back(makeElement(ifStmt->getElseStmt(), elseLoc));
}
// Inject a join if we have one.
if (auto *join = context.ElementJoin.getPtrOrNull())
elements.push_back(makeJoinElement(cs, join, locator));
createConjunction(cs, elements, locator);
}
@@ -876,6 +910,10 @@ private:
elements.push_back(makeElement(rawCase, switchLoc));
}
// Inject a join if we have one.
if (auto *join = context.ElementJoin.getPtrOrNull())
elements.push_back(makeJoinElement(cs, join, switchLoc));
createConjunction(cs, elements, switchLoc);
}
@@ -992,10 +1030,6 @@ private:
SmallVector<ElementInfo, 4> elements;
for (auto element : braceStmt->getElements()) {
bool isDiscarded =
element.is<Expr *>() &&
(!ctx.LangOpts.Playground && !ctx.LangOpts.DebuggerSupport);
if (auto *decl = element.dyn_cast<Decl *>()) {
if (auto *PDB = dyn_cast<PatternBindingDecl>(decl)) {
visitPatternBinding(PDB, elements);
@@ -1003,11 +1037,30 @@ private:
}
}
elements.push_back(
makeElement(element,
cs.getConstraintLocator(
locator, LocatorPathElt::SyntacticElement(element)),
/*contextualInfo=*/{}, isDiscarded));
bool isDiscarded = false;
auto contextInfo = cs.getContextualTypeInfo(element);
if (element.is<Expr *>() &&
!ctx.LangOpts.Playground && !ctx.LangOpts.DebuggerSupport) {
isDiscarded = !contextInfo || contextInfo->purpose == CTP_Unused;
}
// For an if/switch expression, if the contextual type for the branch is
// still a type variable, we can drop it. This avoids needlessly
// propagating the type of the branch to subsequent branches, instead
// we'll let the join handle the conversion.
if (contextInfo && isExpr<SingleValueStmtExpr>(locator->getAnchor())) {
auto contextualFixedTy = cs.getFixedTypeRecursive(
contextInfo->getType(), /*wantRValue*/ true);
if (contextualFixedTy->isTypeVariableOrMember())
contextInfo = None;
}
elements.push_back(makeElement(
element,
cs.getConstraintLocator(locator,
LocatorPathElt::SyntacticElement(element)),
contextInfo.getValueOr(ContextualTypeInfo()), isDiscarded));
}
createConjunction(cs, elements, locator);
@@ -1070,7 +1123,7 @@ private:
}
ContextualTypeInfo getContextualResultInfo() const {
auto funcRef = context.getAsAnyFunctionRef();
auto funcRef = AnyFunctionRef::fromDeclContext(context.getAsDeclContext());
if (!funcRef)
return {Type(), CTP_Unused};
@@ -1088,7 +1141,6 @@ private:
llvm_unreachable("Unsupported statement kind " #STMT); \
}
UNSUPPORTED_STMT(Yield)
UNSUPPORTED_STMT(Fail)
#undef UNSUPPORTED_STMT
private:
@@ -1195,6 +1247,92 @@ bool ConstraintSystem::generateConstraints(AnyFunctionRef fn, BraceStmt *body) {
return generator.hadError;
}
bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
auto *S = E->getStmt();
auto &ctx = getASTContext();
auto *loc = getConstraintLocator(E);
Type resultTy = createTypeVariable(loc, /*options*/ 0);
setType(E, resultTy);
// Assign contextual types for each of the expression branches.
SmallVector<Expr *, 4> scratch;
auto branches = E->getSingleExprBranches(scratch);
for (auto *branch : branches) {
setContextualType(branch, TypeLoc::withoutLoc(resultTy),
CTP_SingleValueStmtBranch);
}
TypeJoinExpr *join = nullptr;
if (branches.empty()) {
// If we only have statement branches, the expression is typed as Void. This
// should only be the case for 'if' and 'switch' statements that must be
// expressions that have branches that all end in a throw, and we'll warn
// that we've inferred Void.
addConstraint(ConstraintKind::Bind, resultTy, ctx.getVoidType(), loc);
} else {
// Otherwise, we join the result types for each of the branches.
join = TypeJoinExpr::forBranchesOfSingleValueStmtExpr(
ctx, resultTy, E, AllocationArena::ConstraintSolver);
}
// If this is the single expression body of a closure, we need to account
// for the fact that the result type may be bound to Void. This is necessary
// to correctly handle the following case:
//
// func foo<T>(_ fn: () -> T) {}
// foo {
// if .random() { 0 } else { "" }
// }
//
// Before if/switch expressions, this was treated as a regular statement,
// with the branches being discarded (and we'd warn). We need to ensure we
// maintain compatibility by continuing to infer T as Void in the case where
// the branches mismatch. This example is contrived, but can occur in the real
// world with e.g branches that insert and remove elements from a set, in both
// cases the methods have mismatching discardable returns.
//
// To maintain this behavior, form a disjunction that will attempt to either
// bind the expression type to the closure result type, or bind it to Void.
// Only if we fail to solve with the closure result type will we attempt with
// Void. We can't rely on the usual defaulting of the closure result type,
// as we need to solve the conjunction before trying defaults.
//
// This only needs to happen for cases where the return is implicit, we don't
// need to do this with 'return if'. We also don't need to do it for function
// decls, as we proactively avoid transforming the if/switch into an
// expression if the result is known to be Void.
if (auto *CE = dyn_cast<ClosureExpr>(E->getDeclContext())) {
if (CE->hasSingleExpressionBody() && !hasExplicitResult(CE) &&
CE->getSingleExpressionBody()->getSemanticsProvidingExpr() == E) {
assert(!getAppliedResultBuilderTransform(CE) &&
"Should have applied the builder with statement semantics");
// We may not have a closure type if we're solving a sub-expression
// independently for e.g code completion.
// TODO: This won't be necessary once we stop doing the fallback
// type-check.
if (auto *closureTy = getClosureTypeIfAvailable(CE)) {
auto closureResultTy = closureTy->getResult();
auto *bindToClosure = Constraint::create(
*this, ConstraintKind::Bind, resultTy, closureResultTy, loc);
bindToClosure->setFavored();
auto *bindToVoid = Constraint::create(*this, ConstraintKind::Bind,
resultTy, ctx.getVoidType(), loc);
addDisjunctionConstraint({bindToClosure, bindToVoid}, loc);
}
}
}
// Generate the conjunction for the branches.
auto context = SyntacticElementContext::forSingleValueStmtExpr(E, join);
SyntacticElementConstraintGenerator generator(*this, context, loc);
generator.visit(S);
return generator.hadError;
}
bool ConstraintSystem::isInResultBuilderContext(ClosureExpr *closure) const {
if (!closure->hasSingleExpressionBody()) {
auto *DC = closure->getParent();
@@ -1244,6 +1382,8 @@ ConstraintSystem::simplifySyntacticElementConstraint(
context = SyntacticElementContext::forClosure(closure);
} else if (auto *fn = getAsDecl<AbstractFunctionDecl>(anchor)) {
context = SyntacticElementContext::forFunction(fn);
} else if (auto *SVE = getAsExpr<SingleValueStmtExpr>(anchor)) {
context = SyntacticElementContext::forSingleValueStmtExpr(SVE);
} else {
return SolutionKind::Error;
}
@@ -1252,9 +1392,26 @@ ConstraintSystem::simplifySyntacticElementConstraint(
getConstraintLocator(locator));
if (auto *expr = element.dyn_cast<Expr *>()) {
auto ctpElt = LocatorPathElt::ContextualType(contextInfo.purpose);
auto *contextualTypeLoc = getConstraintLocator(expr, {ctpElt});
// If this is a branch expression in a SingleValueStmtExpr, form a locator
// based on the branch index.
if (auto *SVE = getAsExpr<SingleValueStmtExpr>(locator.getAnchor())) {
SmallVector<Expr *, 4> scratch;
auto branches = SVE->getSingleExprBranches(scratch);
for (auto idx : indices(branches)) {
if (expr == branches[idx]) {
contextualTypeLoc = getConstraintLocator(
SVE, {LocatorPathElt::SingleValueStmtBranch(idx), ctpElt});
break;
}
}
}
SolutionApplicationTarget target(expr, context->getAsDeclContext(),
contextInfo.purpose, contextInfo.getType(),
isDiscarded);
contextualTypeLoc, isDiscarded);
if (generateConstraints(target, FreeTypeVariableBinding::Disallow))
return SolutionKind::Error;
@@ -1302,10 +1459,21 @@ public:
SyntacticElementSolutionApplication(Solution &solution,
SyntacticElementContext context,
Type resultType,
RewriteTargetFn rewriteTarget)
: solution(solution), context(context), resultType(resultType),
rewriteTarget(rewriteTarget) {}
: solution(solution), context(context), rewriteTarget(rewriteTarget) {
if (auto fn = AnyFunctionRef::fromDeclContext(context.getAsDeclContext())) {
if (auto transform = solution.getAppliedBuilderTransform(*fn)) {
resultType = solution.simplifyType(transform->bodyResultType);
} else if (auto *closure =
getAsExpr<ClosureExpr>(fn->getAbstractClosureExpr())) {
resultType = solution.getResolvedType(closure)
->castTo<FunctionType>()
->getResult();
} else {
resultType = fn->getBodyResultType();
}
}
}
virtual ~SyntacticElementSolutionApplication() {}
@@ -1383,6 +1551,10 @@ private:
return fallthroughStmt;
}
ASTNode visitFailStmt(FailStmt *failStmt) {
return failStmt;
}
ASTNode visitDeferStmt(DeferStmt *deferStmt) {
TypeChecker::typeCheckDecl(deferStmt->getTempDecl());
@@ -1592,7 +1764,7 @@ private:
return caseStmt;
}
ASTNode visitBraceElement(ASTNode node) {
virtual ASTNode visitBraceElement(ASTNode node) {
auto &cs = solution.getConstraintSystem();
if (auto *expr = node.dyn_cast<Expr *>()) {
// Rewrite the expression.
@@ -1800,13 +1972,12 @@ private:
llvm_unreachable("Unsupported statement kind " #STMT); \
}
UNSUPPORTED_STMT(Yield)
UNSUPPORTED_STMT(Fail)
#undef UNSUPPORTED_STMT
public:
/// Apply solution to the closure and return updated body.
ASTNode apply() {
auto body = visit(context.getBody());
/// Apply the solution to the context and return updated statement.
Stmt *apply() {
auto body = visit(context.getStmt());
// Since local functions can capture variables that are declared
// after them, let's type-check them after all of the pattern
@@ -1814,7 +1985,7 @@ public:
for (auto *func : LocalFuncs)
TypeChecker::typeCheckDecl(func);
return body;
return body ? body.get<Stmt *>() : nullptr;
}
};
@@ -1827,11 +1998,11 @@ public:
RewriteTargetFn rewriteTarget)
: SyntacticElementSolutionApplication(
solution, SyntacticElementContext::forFunctionRef(context),
transform.bodyResultType, rewriteTarget),
rewriteTarget),
Transform(transform) {}
bool apply() {
auto body = visit(context.getBody());
auto body = visit(context.getStmt());
if (!body || hadError)
return true;
@@ -1863,6 +2034,15 @@ private:
return doStmt;
}
ASTNode visitBraceElement(ASTNode node) override {
if (auto *SVE = getAsExpr<SingleValueStmtExpr>(node)) {
// This should never be treated as an expression in a result builder,
// it should have statement semantics.
return visitBraceElement(SVE->getStmt());
}
return SyntacticElementSolutionApplication::visitBraceElement(node);
}
NullablePtr<Stmt> transformDo(DoStmt *doStmt) {
if (!doStmt->isImplicit())
return nullptr;
@@ -2208,29 +2388,15 @@ bool ConstraintSystem::applySolutionToBody(Solution &solution,
// transformations.
llvm::SaveAndRestore<DeclContext *> savedDC(currentDC, fn.getAsDeclContext());
Type resultTy;
if (auto transform = solution.getAppliedBuilderTransform(fn)) {
resultTy = solution.simplifyType(transform->bodyResultType);
} else if (auto *closure =
getAsExpr<ClosureExpr>(fn.getAbstractClosureExpr())) {
resultTy =
solution.getResolvedType(closure)->castTo<FunctionType>()->getResult();
} else {
resultTy = fn.getBodyResultType();
}
SyntacticElementSolutionApplication application(
solution, SyntacticElementContext::forFunctionRef(fn), resultTy,
rewriteTarget);
solution, SyntacticElementContext::forFunctionRef(fn), rewriteTarget);
auto body = application.apply();
auto *body = application.apply();
if (!body || application.hadError)
return true;
fn.setTypecheckedBody(castToStmt<BraceStmt>(body),
fn.hasSingleExpressionBody());
fn.setTypecheckedBody(cast<BraceStmt>(body), fn.hasSingleExpressionBody());
return false;
}
@@ -2249,6 +2415,31 @@ bool ConjunctionElement::mightContainCodeCompletionToken(
}
}
bool ConstraintSystem::applySolutionToSingleValueStmt(
Solution &solution, SingleValueStmtExpr *SVE, DeclContext *DC,
RewriteTargetFn rewriteTarget) {
auto context = SyntacticElementContext::forSingleValueStmtExpr(SVE);
SyntacticElementSolutionApplication application(solution, context,
rewriteTarget);
auto *stmt = application.apply();
if (!stmt || application.hadError)
return true;
// If the expression was typed as Void, its branches are effectively
// discarded, so treat them as ignored expressions. This doesn't happen in
// the solution application walker as we consider all the branches to have
// contextual types.
if (solution.getResolvedType(SVE)->lookThroughAllOptionalTypes()->isVoid()) {
SmallVector<Expr *, 4> scratch;
for (auto *branch : SVE->getSingleExprBranches(scratch))
TypeChecker::checkIgnoredExpr(branch);
}
SVE->setStmt(stmt);
return false;
}
void ConjunctionElement::findReferencedVariables(
ConstraintSystem &cs, SmallPtrSetImpl<TypeVariableType *> &typeVars) const {
auto referencedVars = Element->getTypeVariables();
@@ -2260,7 +2451,16 @@ void ConjunctionElement::findReferencedVariables(
ASTNode element = Element->getSyntacticElement();
auto *locator = Element->getLocator();
TypeVariableRefFinder refFinder(cs, locator->getAnchor(), typeVars);
ASTNode parent = locator->getAnchor();
if (auto *SVE = getAsExpr<SingleValueStmtExpr>(parent)) {
// Use a parent closure if we have one. This is needed to correctly handle
// return statements that refer to an outer closure.
if (auto *CE = dyn_cast<ClosureExpr>(SVE->getDeclContext()))
parent = CE;
}
TypeVariableRefFinder refFinder(cs, parent, Element->getElementContext(),
typeVars);
// If this is a pattern of `for-in` statement, let's walk into `for-in`
// sequence expression because both elements are type-checked together.