Requestify FallthroughStmt source and destination lookup

Follow a similar pattern to BreakTargetRequest
and ContinueTargetRequest.
This commit is contained in:
Hamish Knight
2024-08-14 19:59:05 +01:00
parent 4470814db8
commit 55aed16ee6
12 changed files with 103 additions and 66 deletions

View File

@@ -1491,10 +1491,10 @@ BridgedDoCatchStmt BridgedDoCatchStmt_createParsed(
BridgedNullableTypeRepr cThrownType, BridgedStmt cBody,
BridgedArrayRef cCatches);
SWIFT_NAME("BridgedFallthroughStmt.createParsed(_:loc:)")
SWIFT_NAME("BridgedFallthroughStmt.createParsed(loc:declContext:)")
BridgedFallthroughStmt
BridgedFallthroughStmt_createParsed(BridgedASTContext cContext,
BridgedSourceLoc cLoc);
BridgedFallthroughStmt_createParsed(BridgedSourceLoc cLoc,
BridgedDeclContext cDC);
SWIFT_NAME("BridgedForEachStmt.createParsed(_:labelInfo:forLoc:tryLoc:awaitLoc:"
"pattern:inLoc:sequence:whereLoc:whereExpr:body:)")

View File

@@ -1159,36 +1159,29 @@ public:
/// FallthroughStmt - The keyword "fallthrough".
class FallthroughStmt : public Stmt {
SourceLoc Loc;
CaseStmt *FallthroughSource;
CaseStmt *FallthroughDest;
DeclContext *DC;
public:
FallthroughStmt(SourceLoc Loc, std::optional<bool> implicit = std::nullopt)
FallthroughStmt(SourceLoc Loc, DeclContext *DC,
std::optional<bool> implicit = std::nullopt)
: Stmt(StmtKind::Fallthrough, getDefaultImplicitFlag(implicit, Loc)),
Loc(Loc), FallthroughSource(nullptr), FallthroughDest(nullptr) {}
Loc(Loc), DC(DC) {}
public:
static FallthroughStmt *createParsed(SourceLoc Loc, DeclContext *DC);
SourceLoc getLoc() const { return Loc; }
SourceRange getSourceRange() const { return Loc; }
DeclContext *getDeclContext() const { return DC; }
void setDeclContext(DeclContext *newDC) { DC = newDC; }
/// Get the CaseStmt block from which the fallthrough transfers control.
/// Set during Sema. (May stay null if fallthrough is invalid.)
CaseStmt *getFallthroughSource() const { return FallthroughSource; }
void setFallthroughSource(CaseStmt *C) {
assert(!FallthroughSource && "fallthrough source already set?!");
FallthroughSource = C;
}
/// Returns \c nullptr if the fallthrough is invalid.
CaseStmt *getFallthroughSource() const;
/// Get the CaseStmt block to which the fallthrough transfers control.
/// Set during Sema.
CaseStmt *getFallthroughDest() const {
assert(FallthroughDest && "fallthrough dest is not set until Sema");
return FallthroughDest;
}
void setFallthroughDest(CaseStmt *C) {
assert(!FallthroughDest && "fallthrough dest already set?!");
FallthroughDest = C;
}
/// Returns \c nullptr if the fallthrough is invalid.
CaseStmt *getFallthroughDest() const;
static bool classof(const Stmt *S) {
return S->getKind() == StmtKind::Fallthrough;

View File

@@ -4162,6 +4162,29 @@ public:
bool isCached() const { return true; }
};
struct FallthroughSourceAndDest {
CaseStmt *Source;
CaseStmt *Dest;
};
/// Lookup the source and destination of a 'fallthrough'.
class FallthroughSourceAndDestRequest
: public SimpleRequest<FallthroughSourceAndDestRequest,
FallthroughSourceAndDest(const FallthroughStmt *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;
private:
friend SimpleRequest;
FallthroughSourceAndDest evaluate(Evaluator &evaluator,
const FallthroughStmt *FS) const;
public:
bool isCached() const { return true; }
};
/// Precheck a ReturnStmt, which involves some initial validation, as well as
/// applying a conversion to a FailStmt if needed.
class PreCheckReturnStmtRequest

View File

@@ -482,6 +482,9 @@ SWIFT_REQUEST(TypeChecker, BreakTargetRequest,
SWIFT_REQUEST(TypeChecker, ContinueTargetRequest,
LabeledStmt *(const ContinueStmt *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, FallthroughSourceAndDestRequest,
FallthroughSourceAndDest(const FallthroughStmt *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, PreCheckReturnStmtRequest,
Stmt *(ReturnStmt *, DeclContext *),
Cached, NoLocationInfo)

View File

@@ -2005,9 +2005,9 @@ BridgedDoCatchStmt BridgedDoCatchStmt_createParsed(
}
BridgedFallthroughStmt
BridgedFallthroughStmt_createParsed(BridgedASTContext cContext,
BridgedSourceLoc cLoc) {
return new (cContext.unbridged()) FallthroughStmt(cLoc.unbridged());
BridgedFallthroughStmt_createParsed(BridgedSourceLoc cLoc,
BridgedDeclContext cDC) {
return FallthroughStmt::createParsed(cLoc.unbridged(), cDC.unbridged());
}
BridgedForEachStmt BridgedForEachStmt_createParsed(

View File

@@ -989,6 +989,23 @@ LabeledStmt *ContinueStmt::getTarget() const {
return evaluateOrDefault(eval, ContinueTargetRequest{this}, nullptr);
}
FallthroughStmt *FallthroughStmt::createParsed(SourceLoc Loc, DeclContext *DC) {
auto &ctx = DC->getASTContext();
return new (ctx) FallthroughStmt(Loc, DC);
}
CaseStmt *FallthroughStmt::getFallthroughSource() const {
auto &eval = getDeclContext()->getASTContext().evaluator;
return evaluateOrDefault(eval, FallthroughSourceAndDestRequest{this}, {})
.Source;
}
CaseStmt *FallthroughStmt::getFallthroughDest() const {
auto &eval = getDeclContext()->getASTContext().evaluator;
return evaluateOrDefault(eval, FallthroughSourceAndDestRequest{this}, {})
.Dest;
}
SourceLoc swift::extractNearestSourceLoc(const Stmt *S) {
return S->getStartLoc();
}

View File

@@ -339,8 +339,8 @@ extension ASTGenVisitor {
func generate(fallThroughStmt node: FallThroughStmtSyntax) -> BridgedFallthroughStmt {
return .createParsed(
self.ctx,
loc: self.generateSourceLoc(node.fallthroughKeyword)
loc: self.generateSourceLoc(node.fallthroughKeyword),
declContext: self.declContext
)
}

View File

@@ -648,8 +648,8 @@ ParserResult<Stmt> Parser::parseStmt(bool fromASTGen) {
if (LabelInfo) diagnose(LabelInfo.Loc, diag::invalid_label_on_stmt);
if (tryLoc.isValid()) diagnose(tryLoc, diag::try_on_stmt, Tok.getText());
return makeParserResult(
new (Context) FallthroughStmt(consumeToken(tok::kw_fallthrough)));
auto loc = consumeToken(tok::kw_fallthrough);
return makeParserResult(FallthroughStmt::createParsed(loc, CurDeclContext));
}
case tok::pound_assert:
if (LabelInfo) diagnose(LabelInfo.Loc, diag::invalid_label_on_stmt);

View File

@@ -1794,7 +1794,7 @@ private:
}
ASTNode visitFallthroughStmt(FallthroughStmt *fallthroughStmt) {
if (checkFallthroughStmt(context.getAsDeclContext(), fallthroughStmt))
if (checkFallthroughStmt(fallthroughStmt))
hadError = true;
return fallthroughStmt;
}

View File

@@ -742,6 +742,25 @@ ContinueTargetRequest::evaluate(Evaluator &evaluator,
CS->getTargetName(), CS->getTargetLoc(), /*isContinue*/ true, DC);
}
FallthroughSourceAndDest
FallthroughSourceAndDestRequest::evaluate(Evaluator &evaluator,
const FallthroughStmt *FS) const {
auto *SF = FS->getDeclContext()->getParentSourceFile();
auto &ctx = SF->getASTContext();
auto loc = FS->getLoc();
auto [src, dest] = ASTScope::lookupFallthroughSourceAndDest(SF, loc);
if (!src) {
ctx.Diags.diagnose(loc, diag::fallthrough_outside_switch);
return {};
}
if (!dest) {
ctx.Diags.diagnose(loc, diag::fallthrough_from_last_case);
return {};
}
return {src, dest};
}
static Expr *getDeclRefProvidingExpressionForHasSymbol(Expr *E) {
// Strip coercions, which are necessary in source to disambiguate overloaded
// functions or generic functions, e.g.
@@ -925,12 +944,18 @@ static bool typeCheckConditionForStatement(LabeledConditionalStmt *stmt,
return false;
}
/// Verify that the pattern bindings for the cases that we're falling through
/// from and to are equivalent.
static void checkFallthroughPatternBindingsAndTypes(
ASTContext &ctx,
CaseStmt *caseBlock, CaseStmt *previousBlock,
FallthroughStmt *fallthrough) {
/// Check the correctness of a 'fallthrough' statement.
///
/// \returns true if an error occurred.
bool swift::checkFallthroughStmt(FallthroughStmt *FS) {
auto &ctx = FS->getDeclContext()->getASTContext();
auto *caseBlock = FS->getFallthroughDest();
auto *previousBlock = FS->getFallthroughSource();
if (!previousBlock || !caseBlock)
return true;
// Verify that the pattern bindings for the cases that we're falling through
// from and to are equivalent.
auto firstPattern = caseBlock->getCaseLabelItems()[0].getPattern();
SmallVector<VarDecl *, 4> vars;
firstPattern->collectVariables(vars);
@@ -969,36 +994,10 @@ static void checkFallthroughPatternBindingsAndTypes(
if (!matched) {
ctx.Diags.diagnose(
fallthrough->getLoc(), diag::fallthrough_into_case_with_var_binding,
FS->getLoc(), diag::fallthrough_into_case_with_var_binding,
expected->getName());
}
}
}
/// Check the correctness of a 'fallthrough' statement.
///
/// \returns true if an error occurred.
bool swift::checkFallthroughStmt(DeclContext *dc, FallthroughStmt *stmt) {
CaseStmt *fallthroughSource;
CaseStmt *fallthroughDest;
ASTContext &ctx = dc->getASTContext();
auto sourceFile = dc->getParentSourceFile();
std::tie(fallthroughSource, fallthroughDest) =
ASTScope::lookupFallthroughSourceAndDest(sourceFile, stmt->getLoc());
if (!fallthroughSource) {
ctx.Diags.diagnose(stmt->getLoc(), diag::fallthrough_outside_switch);
return true;
}
if (!fallthroughDest) {
ctx.Diags.diagnose(stmt->getLoc(), diag::fallthrough_from_last_case);
return true;
}
stmt->setFallthroughSource(fallthroughSource);
stmt->setFallthroughDest(fallthroughDest);
checkFallthroughPatternBindingsAndTypes(
ctx, fallthroughDest, fallthroughSource, stmt);
return false;
}
@@ -1457,7 +1456,7 @@ public:
}
Stmt *visitFallthroughStmt(FallthroughStmt *S) {
if (checkFallthroughStmt(DC, S))
if (checkFallthroughStmt(S))
return nullptr;
return S;

View File

@@ -1593,6 +1593,8 @@ namespace {
BS->setDeclContext(NewDC);
if (auto *CS = dyn_cast<ContinueStmt>(S))
CS->setDeclContext(NewDC);
if (auto *FS = dyn_cast<FallthroughStmt>(S))
FS->setDeclContext(NewDC);
return Action::Continue(S);
}

View File

@@ -1433,7 +1433,7 @@ LabeledStmt *findBreakOrContinueStmtTarget(ASTContext &ctx,
/// Check the correctness of a 'fallthrough' statement.
///
/// \returns true if an error occurred.
bool checkFallthroughStmt(DeclContext *dc, FallthroughStmt *stmt);
bool checkFallthroughStmt(FallthroughStmt *stmt);
/// Check for restrictions on the use of the @unknown attribute on a
/// case statement.