diff --git a/include/swift/AST/ASTNode.h b/include/swift/AST/ASTNode.h index ac208f1bf06..5724ac649b0 100644 --- a/include/swift/AST/ASTNode.h +++ b/include/swift/AST/ASTNode.h @@ -43,11 +43,9 @@ namespace swift { enum class PatternKind : uint8_t; enum class StmtKind; - using StmtCondition = llvm::MutableArrayRef; - struct ASTNode : public llvm::PointerUnion { + StmtConditionElement *, CaseLabelItem *> { // Inherit the constructors from PointerUnion. using PointerUnion::PointerUnion; diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index dab00742450..d4820894a16 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -395,7 +395,7 @@ public: /// the "x" binding, one for the "y" binding, one for the where clause, one for /// "z"'s binding. A simple "if" statement is represented as a single binding. /// -class StmtConditionElement { +class alignas(1 << PatternAlignInBits) StmtConditionElement { /// If this is a pattern binding, it may be the first one in a declaration, in /// which case this is the location of the var/let/case keyword. If this is /// the second pattern (e.g. for 'y' in "var x = ..., y = ...") then this @@ -818,7 +818,7 @@ public: }; /// A pattern and an optional guard expression used in a 'case' statement. -class CaseLabelItem { +class alignas(1 << PatternAlignInBits) CaseLabelItem { enum class Kind { /// A normal pattern Normal = 0, diff --git a/include/swift/AST/TypeAlignments.h b/include/swift/AST/TypeAlignments.h index 1c0f5e3b868..911c29494ff 100644 --- a/include/swift/AST/TypeAlignments.h +++ b/include/swift/AST/TypeAlignments.h @@ -61,6 +61,7 @@ namespace swift { class TypeRepr; class ValueDecl; class CaseLabelItem; + class StmtConditionElement; /// We frequently use three tag bits on all of these types. constexpr size_t AttrAlignInBits = 3; @@ -155,6 +156,9 @@ LLVM_DECLARE_TYPE_ALIGNMENT(swift::TypeRepr, swift::TypeReprAlignInBits) LLVM_DECLARE_TYPE_ALIGNMENT(swift::CaseLabelItem, swift::PatternAlignInBits) +LLVM_DECLARE_TYPE_ALIGNMENT(swift::StmtConditionElement, + swift::PatternAlignInBits) + static_assert(alignof(void*) >= 2, "pointer alignment is too small"); #endif diff --git a/include/swift/Sema/ConstraintLocator.h b/include/swift/Sema/ConstraintLocator.h index 74d4d87df95..c1e19ab12ac 100644 --- a/include/swift/Sema/ConstraintLocator.h +++ b/include/swift/Sema/ConstraintLocator.h @@ -1036,7 +1036,7 @@ public: if (auto *repr = node.dyn_cast()) return repr; - if (auto *cond = node.dyn_cast()) + if (auto *cond = node.dyn_cast()) return cond; if (auto *caseItem = node.dyn_cast()) diff --git a/lib/AST/ASTNode.cpp b/lib/AST/ASTNode.cpp index bce57fb943d..a9b0d3ae334 100644 --- a/lib/AST/ASTNode.cpp +++ b/lib/AST/ASTNode.cpp @@ -35,15 +35,8 @@ SourceRange ASTNode::getSourceRange() const { return P->getSourceRange(); if (const auto *T = this->dyn_cast()) return T->getSourceRange(); - if (const auto *C = this->dyn_cast()) { - if (C->empty()) - return SourceRange(); - - auto first = C->front(); - auto last = C->back(); - - return {first.getStartLoc(), last.getEndLoc()}; - } + if (const auto *C = this->dyn_cast()) + return C->getSourceRange(); if (const auto *I = this->dyn_cast()) { return I->getSourceRange(); } @@ -85,7 +78,7 @@ bool ASTNode::isImplicit() const { return P->isImplicit(); if (const auto *T = this->dyn_cast()) return false; - if (const auto *C = this->dyn_cast()) + if (const auto *C = this->dyn_cast()) return false; if (const auto *I = this->dyn_cast()) return false; @@ -103,10 +96,9 @@ void ASTNode::walk(ASTWalker &Walker) { P->walk(Walker); else if (auto *T = this->dyn_cast()) T->walk(Walker); - else if (auto *C = this->dyn_cast()) { - for (auto &elt : *C) - elt.walk(Walker); - } else if (auto *I = this->dyn_cast()) { + else if (auto *C = this->dyn_cast()) + C->walk(Walker); + else if (auto *I = this->dyn_cast()) { if (auto *P = I->getPattern()) P->walk(Walker); @@ -127,9 +119,9 @@ void ASTNode::dump(raw_ostream &OS, unsigned Indent) const { P->dump(OS, Indent); else if (auto T = dyn_cast()) T->print(OS); - else if (auto C = dyn_cast()) { - OS.indent(Indent) << "(statement conditions)"; - } else if (auto *I = dyn_cast()) { + else if (auto *C = dyn_cast()) + OS.indent(Indent) << "(statement condition)"; + else if (auto *I = dyn_cast()) { OS.indent(Indent) << "(case label item)"; } else llvm_unreachable("unsupported AST node"); diff --git a/lib/Sema/CSClosure.cpp b/lib/Sema/CSClosure.cpp index 6bcd97622eb..6f9e59a1513 100644 --- a/lib/Sema/CSClosure.cpp +++ b/lib/Sema/CSClosure.cpp @@ -531,6 +531,15 @@ private: "Unsupported statement: Fallthrough"); } + void visitStmtCondition(LabeledConditionalStmt *S, + SmallVectorImpl &elements, + ConstraintLocator *locator) { + auto *condLocator = + cs.getConstraintLocator(locator, ConstraintLocator::Condition); + for (auto &condition : S->getCond()) + elements.push_back(makeElement(&condition, condLocator)); + } + void visitIfStmt(IfStmt *ifStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: If"); @@ -538,11 +547,7 @@ private: SmallVector elements; // Condition - { - auto *condLoc = - cs.getConstraintLocator(locator, ConstraintLocator::Condition); - elements.push_back(makeElement(ifStmt->getCondPointer(), condLoc)); - } + visitStmtCondition(ifStmt, elements, locator); // Then Branch { @@ -565,24 +570,24 @@ private: assert(isSupportedMultiStatementClosure() && "Unsupported statement: Guard"); - createConjunction(cs, - {makeElement(guardStmt->getCondPointer(), - cs.getConstraintLocator( - locator, ConstraintLocator::Condition)), - makeElement(guardStmt->getBody(), locator)}, - locator); + SmallVector elements; + + visitStmtCondition(guardStmt, elements, locator); + elements.push_back(makeElement(guardStmt->getBody(), locator)); + + createConjunction(cs, elements, locator); } void visitWhileStmt(WhileStmt *whileStmt) { assert(isSupportedMultiStatementClosure() && "Unsupported statement: While"); - createConjunction(cs, - {makeElement(whileStmt->getCondPointer(), - cs.getConstraintLocator( - locator, ConstraintLocator::Condition)), - makeElement(whileStmt->getBody(), locator)}, - locator); + SmallVector elements; + + visitStmtCondition(whileStmt, elements, locator); + elements.push_back(makeElement(whileStmt->getBody(), locator)); + + createConjunction(cs, elements, locator); } void visitDoStmt(DoStmt *doStmt) { @@ -970,8 +975,8 @@ ConstraintSystem::simplifyClosureBodyElementConstraint( return SolutionKind::Solved; } else if (auto *stmt = element.dyn_cast()) { generator.visit(stmt); - } else if (auto *cond = element.dyn_cast()) { - if (generateConstraints(*cond, closure)) + } else if (auto *cond = element.dyn_cast()) { + if (generateConstraints({*cond}, closure)) return SolutionKind::Error; } else if (auto *pattern = element.dyn_cast()) { generator.visitPattern(pattern, context); @@ -1571,7 +1576,7 @@ void ConjunctionElement::findReferencedVariables( TypeVariableRefFinder refFinder(cs, locator->getAnchor(), typeVars); - if (element.is() || element.is() || + if (element.is() || element.is() || element.is() || element.isStmt(StmtKind::Return)) element.walk(refFinder); } diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index 55238287308..22cbe9a94e2 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -6050,8 +6050,8 @@ SourceLoc constraints::getLoc(ASTNode anchor) { return S->getStartLoc(); } else if (auto *P = anchor.dyn_cast()) { return P->getLoc(); - } else if (auto *C = anchor.dyn_cast()) { - return C->front().getStartLoc(); + } else if (auto *C = anchor.dyn_cast()) { + return C->getStartLoc(); } else { auto *I = anchor.get(); return I->getStartLoc();