From 87d86f3545bc62bc77c8ada12caacdcecc2f299d Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Thu, 9 Apr 2020 11:02:56 -0700 Subject: [PATCH] [Constraint solver] Migrate for-each statement checking into SolutionApplicationTarget. Pull the entirety of type checking for for-each statement headers (i.e., not the body) into the constraint system, using the normal SolutionApplicationTarget-based constraint generation and application facilities. Most of this was already handled in the constraint solver (although the `where` filtering condition was not), so this is a smaller change than it looks like. --- lib/Sema/CSApply.cpp | 177 ++++++++++++++++++++- lib/Sema/CSGen.cpp | 116 ++++++++++++++ lib/Sema/ConstraintSystem.cpp | 33 +++- lib/Sema/ConstraintSystem.h | 65 ++++++-- lib/Sema/TypeCheckConstraints.cpp | 249 ++---------------------------- 5 files changed, 384 insertions(+), 256 deletions(-) diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 7520fb1f8ae..a9348cadb6c 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -7886,7 +7886,7 @@ bool ConstraintSystem::applySolutionFixes(const Solution &solution) { /// Apply the given solution to the initialization target. /// -/// \returns the resulting initialiation expression. +/// \returns the resulting initialization expression. static Optional applySolutionToInitialization( Solution &solution, SolutionApplicationTarget target, Expr *initializer) { @@ -7950,7 +7950,7 @@ static Optional applySolutionToInitialization( finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false); // Apply the solution to the pattern as well. - auto contextualPattern = target.getInitializationContextualPattern(); + auto contextualPattern = target.getContextualPattern(); if (auto coercedPattern = TypeChecker::coercePatternToType( contextualPattern, finalPatternType, options)) { resultTarget.setPattern(coercedPattern); @@ -7961,6 +7961,139 @@ static Optional applySolutionToInitialization( return resultTarget; } +/// Apply the given solution to the for-each statement target. +/// +/// \returns the resulting initialization expression. +static Optional applySolutionToForEachStmt( + Solution &solution, SolutionApplicationTarget target, Expr *sequence) { + auto resultTarget = target; + auto &forEachStmtInfo = resultTarget.getForEachStmtInfo(); + + // Simplify the various types. + forEachStmtInfo.elementType = + solution.simplifyType(forEachStmtInfo.elementType); + forEachStmtInfo.iteratorType = + solution.simplifyType(forEachStmtInfo.iteratorType); + forEachStmtInfo.initType = + solution.simplifyType(forEachStmtInfo.initType); + forEachStmtInfo.sequenceType = + solution.simplifyType(forEachStmtInfo.sequenceType); + + // Coerce the sequence to the sequence type. + auto &cs = solution.getConstraintSystem(); + auto locator = cs.getConstraintLocator(target.getAsExpr()); + sequence = solution.coerceToType( + sequence, forEachStmtInfo.sequenceType, locator); + if (!sequence) + return None; + + resultTarget.setExpr(sequence); + + // Get the conformance of the sequence type to the Sequence protocol. + auto stmt = forEachStmtInfo.stmt; + auto sequenceProto = TypeChecker::getProtocol( + cs.getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence); + auto contextualLocator = cs.getConstraintLocator( + target.getAsExpr(), LocatorPathElt::ContextualType()); + auto sequenceConformance = solution.resolveConformance( + contextualLocator, sequenceProto); + assert(!sequenceConformance.isInvalid() && + "Couldn't find sequence conformance"); + + // Coerce the pattern to the element type. + TypeResolutionOptions options(TypeResolverContext::ForEachStmt); + options |= TypeResolutionFlags::OverrideType; + + // Apply the solution to the pattern as well. + auto contextualPattern = target.getContextualPattern(); + if (auto coercedPattern = TypeChecker::coercePatternToType( + contextualPattern, forEachStmtInfo.initType, options)) { + resultTarget.setPattern(coercedPattern); + } else { + return None; + } + + // Apply the solution to the filtering condition, if there is one. + auto dc = target.getDeclContext(); + if (forEachStmtInfo.whereExpr) { + auto *boolDecl = dc->getASTContext().getBoolDecl(); + assert(boolDecl); + Type boolType = boolDecl->getDeclaredType(); + assert(boolType); + + SolutionApplicationTarget whereTarget( + forEachStmtInfo.whereExpr, dc, CTP_Condition, boolType, + /*isDiscarded=*/false); + auto newWhereTarget = cs.applySolution(solution, whereTarget); + if (!newWhereTarget) + return None; + + forEachStmtInfo.whereExpr = newWhereTarget->getAsExpr(); + } + + // Invoke iterator() to get an iterator from the sequence. + ASTContext &ctx = cs.getASTContext(); + VarDecl *iterator; + Type nextResultType = OptionalType::get(forEachStmtInfo.elementType); + { + // Create a local variable to capture the iterator. + std::string name; + if (auto np = dyn_cast_or_null(stmt->getPattern())) + name = "$"+np->getBoundName().str().str(); + name += "$generator"; + + iterator = new (ctx) VarDecl( + /*IsStatic*/ false, VarDecl::Introducer::Var, + /*IsCaptureList*/ false, stmt->getInLoc(), + ctx.getIdentifier(name), dc); + iterator->setInterfaceType( + forEachStmtInfo.iteratorType->mapTypeOutOfContext()); + iterator->setImplicit(); + + auto genPat = new (ctx) NamedPattern(iterator); + genPat->setImplicit(); + + // TODO: test/DebugInfo/iteration.swift requires this extra info to + // be around. + PatternBindingDecl::createImplicit( + ctx, StaticSpellingKind::None, genPat, + new (ctx) OpaqueValueExpr(stmt->getInLoc(), nextResultType), + dc, /*VarLoc*/ stmt->getForLoc()); + } + + // Create the iterator variable. + auto *varRef = TypeChecker::buildCheckedRefExpr( + iterator, dc, DeclNameLoc(stmt->getInLoc()), /*implicit*/ true); + + // Convert that Optional value to the type of the pattern. + auto optPatternType = OptionalType::get(forEachStmtInfo.initType); + if (!optPatternType->isEqual(nextResultType)) { + OpaqueValueExpr *elementExpr = + new (ctx) OpaqueValueExpr(stmt->getInLoc(), nextResultType, + /*isPlaceholder=*/true); + Expr *convertElementExpr = elementExpr; + if (TypeChecker::typeCheckExpression( + convertElementExpr, dc, + TypeLoc::withoutLoc(optPatternType), + CTP_CoerceOperand).isNull()) { + return None; + } + elementExpr->setIsPlaceholder(false); + stmt->setElementExpr(elementExpr); + stmt->setConvertElementExpr(convertElementExpr); + } + + // Write the result back into the AST. + stmt->setSequence(resultTarget.getAsExpr()); + stmt->setPattern(resultTarget.getContextualPattern().getPattern()); + stmt->setSequenceConformance(sequenceConformance); + stmt->setWhere(forEachStmtInfo.whereExpr); + stmt->setIteratorVar(iterator); + stmt->setIteratorVarRef(varRef); + + return resultTarget; +} + Optional ExprWalker::rewriteTarget(SolutionApplicationTarget target) { auto &solution = Rewriter.solution; @@ -7972,16 +8105,50 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) { if (!rewrittenExpr) return None; - /// Handle application for initializations. - if (target.getExprContextualTypePurpose() == CTP_Initialization) { + /// Handle special cases for expressions. + switch (target.getExprContextualTypePurpose()) { + case CTP_Initialization: { auto initResultTarget = applySolutionToInitialization( solution, target, rewrittenExpr); if (!initResultTarget) return None; result = *initResultTarget; - } else { + break; + } + + case CTP_ForEachStmt: { + auto forEachResultTarget = applySolutionToForEachStmt( + solution, target, rewrittenExpr); + if (!forEachResultTarget) + return None; + + result = *forEachResultTarget; + break; + } + + case CTP_Unused: + case CTP_ReturnStmt: + case swift::CTP_ReturnSingleExpr: + case swift::CTP_YieldByValue: + case swift::CTP_YieldByReference: + case swift::CTP_ThrowStmt: + case swift::CTP_EnumCaseRawValue: + case swift::CTP_DefaultParameter: + case swift::CTP_AutoclosureDefaultParameter: + case swift::CTP_CalleeResult: + case swift::CTP_CallArgument: + case swift::CTP_ClosureResult: + case swift::CTP_ArrayElement: + case swift::CTP_DictionaryKey: + case swift::CTP_DictionaryValue: + case swift::CTP_CoerceOperand: + case swift::CTP_AssignSource: + case swift::CTP_SubscriptAssignSource: + case swift::CTP_Condition: + case swift::CTP_CannotFail: result.setExpr(rewrittenExpr); + break; } } else if (auto stmtCondition = target.getAsStmtCondition()) { for (auto &condElement : *stmtCondition) { diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 7191f143ba3..c25861178b5 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -4128,6 +4128,112 @@ static bool generateInitPatternConstraints( return false; } +/// Generate constraints for a for-each statement. +static Optional +generateForEachStmtConstraints( + ConstraintSystem &cs, SolutionApplicationTarget target, Expr *sequence) { + auto forEachStmtInfo = target.getForEachStmtInfo(); + ForEachStmt *stmt = forEachStmtInfo.stmt; + + auto locator = cs.getConstraintLocator(sequence); + auto contextualLocator = + cs.getConstraintLocator(sequence, LocatorPathElt::ContextualType()); + + // The expression type must conform to the Sequence protocol. + auto sequenceProto = TypeChecker::getProtocol( + cs.getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence); + if (!sequenceProto) { + return None; + } + + Type sequenceType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape); + cs.addConstraint(ConstraintKind::Conversion, cs.getType(sequence), + sequenceType, locator); + cs.addConstraint(ConstraintKind::ConformsTo, sequenceType, + sequenceProto->getDeclaredType(), contextualLocator); + + // Check the element pattern. + ASTContext &ctx = cs.getASTContext(); + auto dc = target.getDeclContext(); + Pattern *pattern = TypeChecker::resolvePattern(stmt->getPattern(), dc, + /*isStmtCondition*/false); + if (!pattern) + return None; + + auto contextualPattern = + ContextualPattern::forRawPattern(pattern, dc); + Type patternType = TypeChecker::typeCheckPattern(contextualPattern); + if (patternType->hasError()) { + return None; + } + + // Collect constraints from the element pattern. + auto elementLocator = cs.getConstraintLocator( + contextualLocator, ConstraintLocator::SequenceElementType); + Type initType = cs.generateConstraints( + pattern, contextualLocator, target.shouldBindPatternVarsOneWay(), + nullptr, 0); + if (!initType) + return None; + + // Add a conversion constraint between the element type of the sequence + // and the type of the element pattern. + auto elementAssocType = + sequenceProto->getAssociatedType(cs.getASTContext().Id_Element); + Type elementType = DependentMemberType::get(sequenceType, elementAssocType); + cs.addConstraint(ConstraintKind::Conversion, elementType, initType, + elementLocator); + + // Determine the iterator type. + auto iteratorAssocType = + sequenceProto->getAssociatedType(cs.getASTContext().Id_Iterator); + Type iteratorType = DependentMemberType::get(sequenceType, iteratorAssocType); + + // The iterator type must conform to IteratorProtocol. + ProtocolDecl *iteratorProto = TypeChecker::getProtocol( + cs.getASTContext(), stmt->getForLoc(), + KnownProtocolKind::IteratorProtocol); + if (!iteratorProto) + return None; + + // Reference the makeIterator witness. + FuncDecl *makeIterator = ctx.getSequenceMakeIterator(); + Type makeIteratorType = + cs.createTypeVariable(locator, TVO_CanBindToNoEscape); + cs.addValueWitnessConstraint( + LValueType::get(sequenceType), makeIterator, + makeIteratorType, dc, FunctionRefKind::Compound, + contextualLocator); + + // Generate constraints for the "where" expression, if there is one. + if (forEachStmtInfo.whereExpr) { + auto *boolDecl = dc->getASTContext().getBoolDecl(); + if (!boolDecl) + return None; + + Type boolType = boolDecl->getDeclaredType(); + if (!boolType) + return None; + + SolutionApplicationTarget whereTarget( + forEachStmtInfo.whereExpr, dc, CTP_Condition, boolType, + /*isDiscarded=*/false); + if (cs.generateConstraints(whereTarget, FreeTypeVariableBinding::Disallow)) + return None; + + forEachStmtInfo.whereExpr = whereTarget.getAsExpr(); + } + + // Populate all of the information for a for-each loop. + forEachStmtInfo.elementType = elementType; + forEachStmtInfo.iteratorType = iteratorType; + forEachStmtInfo.initType = initType; + forEachStmtInfo.sequenceType = sequenceType; + target.setPattern(pattern); + target.getForEachStmtInfo() = forEachStmtInfo; + return target; +} + bool ConstraintSystem::generateConstraints( SolutionApplicationTarget &target, FreeTypeVariableBinding allowFreeTypeVariables) { @@ -4186,6 +4292,16 @@ bool ConstraintSystem::generateConstraints( return true; } + // For a for-each statement, generate constraints for the pattern, where + // clause, and sequence traversal. + if (target.getExprContextualTypePurpose() == CTP_ForEachStmt) { + auto resultTarget = generateForEachStmtConstraints(*this, target, expr); + if (!resultTarget) + return true; + + target = *resultTarget; + } + if (getASTContext().TypeCheckerOpts.DebugConstraintSolver) { auto &log = getASTContext().TypeCheckerDebug->getStream(); log << "---Initial constraints for the given expression---\n"; diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index b2ad9fd9073..f2a4fa76290 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -4161,8 +4161,8 @@ SolutionApplicationTarget::SolutionApplicationTarget( expression.wrappedVar = nullptr; expression.isDiscarded = isDiscarded; expression.bindPatternVarsOneWay = false; - expression.patternBinding = nullptr; - expression.patternBindingIndex = 0; + expression.initialization.patternBinding = nullptr; + expression.initialization.patternBindingIndex = 0; } void SolutionApplicationTarget::maybeApplyPropertyWrapper() { @@ -4259,18 +4259,35 @@ SolutionApplicationTarget SolutionApplicationTarget::forInitialization( auto result = forInitialization( initializer, dc, patternType, patternBinding->getPattern(patternBindingIndex), bindPatternVarsOneWay); - result.expression.patternBinding = patternBinding; - result.expression.patternBindingIndex = patternBindingIndex; + result.expression.initialization.patternBinding = patternBinding; + result.expression.initialization.patternBindingIndex = patternBindingIndex; return result; } +SolutionApplicationTarget SolutionApplicationTarget::forForEachStmt( + ForEachStmt *stmt, ProtocolDecl *sequenceProto, DeclContext *dc, + bool bindPatternVarsOneWay) { + SolutionApplicationTarget target( + stmt->getSequence(), dc, CTP_ForEachStmt, + sequenceProto->getDeclaredType(), /*isDiscarded=*/false); + target.expression.pattern = stmt->getPattern(); + target.expression.bindPatternVarsOneWay = + bindPatternVarsOneWay || (stmt->getWhere() != nullptr); + target.expression.forEachStmt.stmt = stmt; + target.expression.forEachStmt.whereExpr = stmt->getWhere(); + return target; +} + ContextualPattern -SolutionApplicationTarget::getInitializationContextualPattern() const { +SolutionApplicationTarget::getContextualPattern() const { assert(kind == Kind::expression); - assert(expression.contextualPurpose == CTP_Initialization); - if (expression.patternBinding) { + assert(expression.contextualPurpose == CTP_Initialization || + expression.contextualPurpose == CTP_ForEachStmt); + if (expression.contextualPurpose == CTP_Initialization && + expression.initialization.patternBinding) { return ContextualPattern::forPatternBindingDecl( - expression.patternBinding, expression.patternBindingIndex); + expression.initialization.patternBinding, + expression.initialization.patternBindingIndex); } return ContextualPattern::forRawPattern(expression.pattern, expression.dc); diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h index 4b09bfc48c5..3921bffa4c6 100644 --- a/lib/Sema/ConstraintSystem.h +++ b/lib/Sema/ConstraintSystem.h @@ -812,6 +812,27 @@ struct CaseLabelItemInfo { Expr *guardExpr; }; +/// Describes information about a for-each loop that needs to be tracked +/// within the constraint system. +struct ForEachStmtInfo { + ForEachStmt *stmt; + + /// The type of the sequence. + Type sequenceType; + + /// The type of the iterator. + Type iteratorType; + + /// The type of an element in the sequence. + Type elementType; + + /// The type of the pattern that matches the elements. + Type initType; + + /// The "where" expression, if there is one. + Expr *whereExpr; +}; + /// Key to the constraint solver's mapping from AST nodes to their corresponding /// solution application targets. using SolutionApplicationTargetsKey = @@ -1209,11 +1230,17 @@ class SolutionApplicationTarget { /// fresh type variables via one-way constraints. bool bindPatternVarsOneWay; - /// The pattern binding declaration for an initialization, if any. - PatternBindingDecl *patternBinding; + union { + struct { + /// The pattern binding declaration for an initialization, if any. + PatternBindingDecl *patternBinding; - /// The index into the pattern binding declaration, if any. - unsigned patternBindingIndex; + /// The index into the pattern binding declaration, if any. + unsigned patternBindingIndex; + } initialization; + + ForEachStmtInfo forEachStmt; + }; } expression; struct { @@ -1282,6 +1309,11 @@ public: PatternBindingDecl *patternBinding, unsigned patternBindingIndex, bool bindPatternVarsOneWay); + /// Form a target for a for-each loop. + static SolutionApplicationTarget forForEachStmt( + ForEachStmt *stmt, ProtocolDecl *sequenceProto, DeclContext *dc, + bool bindPatternVarsOneWay); + Expr *getAsExpr() const { switch (kind) { case Kind::expression: @@ -1367,7 +1399,7 @@ public: } /// For a pattern initialization target, retrieve the contextual pattern. - ContextualPattern getInitializationContextualPattern() const; + ContextualPattern getContextualPattern() const; /// Whether this is an initialization for an Optional.Some pattern. bool isOptionalSomePatternInit() const { @@ -1379,9 +1411,7 @@ public: /// Whether to bind the types of any variables within the pattern via /// one-way constraints. bool shouldBindPatternVarsOneWay() const { - return kind == Kind::expression && - expression.contextualPurpose == CTP_Initialization && - expression.bindPatternVarsOneWay; + return kind == Kind::expression && expression.bindPatternVarsOneWay; } /// Retrieve the wrapped variable when initializing a pattern with a @@ -1395,13 +1425,25 @@ public: PatternBindingDecl *getInitializationPatternBindingDecl() const { assert(kind == Kind::expression); assert(expression.contextualPurpose == CTP_Initialization); - return expression.patternBinding; + return expression.initialization.patternBinding; } unsigned getInitializationPatternBindingIndex() const { assert(kind == Kind::expression); assert(expression.contextualPurpose == CTP_Initialization); - return expression.patternBindingIndex; + return expression.initialization.patternBindingIndex; + } + + const ForEachStmtInfo &getForEachStmtInfo() const { + assert(kind == Kind::expression); + assert(expression.contextualPurpose == CTP_ForEachStmt); + return expression.forEachStmt; + } + + ForEachStmtInfo &getForEachStmtInfo() { + assert(kind == Kind::expression); + assert(expression.contextualPurpose == CTP_ForEachStmt); + return expression.forEachStmt; } /// Whether this context infers an opaque return type. @@ -1422,7 +1464,8 @@ public: void setPattern(Pattern *pattern) { assert(kind == Kind::expression); - assert(expression.contextualPurpose == CTP_Initialization); + assert(expression.contextualPurpose == CTP_Initialization || + expression.contextualPurpose == CTP_ForEachStmt); expression.pattern = pattern; } diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index 4e1943a2e08..6c243260672 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -2468,244 +2468,29 @@ bool TypeChecker::typeCheckPatternBinding(PatternBindingDecl *PBD, } bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) { - /// Type checking listener for for-each binding. - class BindingListener : public ExprTypeCheckListener { - /// The for-each statement. - ForEachStmt *Stmt; - - /// The declaration context in which this for-each statement resides. - DeclContext *DC; - - /// The locator we're using. - ConstraintLocator *Locator; - - /// The contextual locator we're using. - ConstraintLocator *ContextualLocator; - - /// The Sequence protocol. - ProtocolDecl *SequenceProto; - - /// The IteratorProtocol. - ProtocolDecl *IteratorProto; - - /// The type of the initializer. - Type InitType; - - /// The type of the sequence. - Type SequenceType; - - /// The conformance of the sequence type to the Sequence protocol. - ProtocolConformanceRef SequenceConformance; - - /// The type of the element. - Type ElementType; - - /// The type of the iterator. - Type IteratorType; - - public: - explicit BindingListener(ForEachStmt *stmt, DeclContext *dc) - : Stmt(stmt), DC(dc) { } - - bool builtConstraints(ConstraintSystem &cs, Expr *expr) override { - // Save the locator we're using for the expression. - Locator = cs.getConstraintLocator(expr); - ContextualLocator = - cs.getConstraintLocator(expr, LocatorPathElt::ContextualType()); - - // The expression type must conform to the Sequence protocol. - SequenceProto = TypeChecker::getProtocol( - cs.getASTContext(), Stmt->getForLoc(), KnownProtocolKind::Sequence); - if (!SequenceProto) { - return true; - } - - SequenceType = cs.createTypeVariable(Locator, TVO_CanBindToNoEscape); - cs.addConstraint(ConstraintKind::Conversion, cs.getType(expr), - SequenceType, Locator); - cs.addConstraint(ConstraintKind::ConformsTo, SequenceType, - SequenceProto->getDeclaredType(), ContextualLocator); - - auto elementLocator = cs.getConstraintLocator( - ContextualLocator, ConstraintLocator::SequenceElementType); - - // Check the element pattern. - ASTContext &ctx = cs.getASTContext(); - if (auto *P = TypeChecker::resolvePattern(Stmt->getPattern(), DC, - /*isStmtCondition*/false)) { - Stmt->setPattern(P); - } else { - Stmt->getPattern()->setType(ErrorType::get(ctx)); - return true; - } - - auto contextualPattern = - ContextualPattern::forRawPattern(Stmt->getPattern(), DC); - Type patternType = TypeChecker::typeCheckPattern(contextualPattern); - if (patternType->hasError()) { - // FIXME: Handle errors better. - Stmt->getPattern()->setType(ErrorType::get(ctx)); - return true; - } - - // Collect constraints from the element pattern. - auto pattern = Stmt->getPattern(); - InitType = cs.generateConstraints( - pattern, elementLocator, /*bindPatternVarsOneWay=*/false, - /*patternBinding=*/nullptr, /*patternBindingIndex=*/0); - if (!InitType) - return true; - - // Add a conversion constraint between the element type of the sequence - // and the type of the element pattern. - auto elementAssocType = - SequenceProto->getAssociatedType(cs.getASTContext().Id_Element); - ElementType = DependentMemberType::get(SequenceType, elementAssocType); - cs.addConstraint(ConstraintKind::Conversion, ElementType, InitType, - elementLocator); - - // Determine the iterator type. - auto iteratorAssocType = - SequenceProto->getAssociatedType(cs.getASTContext().Id_Iterator); - IteratorType = DependentMemberType::get(SequenceType, iteratorAssocType); - - // The iterator type must conform to IteratorProtocol. - IteratorProto = TypeChecker::getProtocol( - cs.getASTContext(), Stmt->getForLoc(), - KnownProtocolKind::IteratorProtocol); - if (!IteratorProto) { - return true; - } - - // Reference the makeIterator witness. - FuncDecl *makeIterator = ctx.getSequenceMakeIterator(); - Type makeIteratorType = - cs.createTypeVariable(Locator, TVO_CanBindToNoEscape); - cs.addValueWitnessConstraint( - LValueType::get(SequenceType), makeIterator, - makeIteratorType, DC, FunctionRefKind::Compound, - ContextualLocator); - - Stmt->setSequence(expr); - return false; - } - - Expr *appliedSolution(Solution &solution, Expr *expr) override { - // Figure out what types the constraints decided on. - auto &cs = solution.getConstraintSystem(); - ASTContext &ctx = cs.getASTContext(); - InitType = solution.simplifyType(InitType); - SequenceType = solution.simplifyType(SequenceType); - ElementType = solution.simplifyType(ElementType); - IteratorType = solution.simplifyType(IteratorType); - - // If the type doesn't conform to Sequence we'll get its element type - // bound to `UnresolvedType` since fixes are allowed. - if (InitType->is()) - return nullptr; - - cs.cacheExprTypes(expr); - Stmt->setSequence(expr); - solution.setExprTypes(expr); - - // Apply the solution to the iteration pattern as well. - Pattern *pattern = Stmt->getPattern(); - TypeResolutionOptions options(TypeResolverContext::ForEachStmt); - options |= TypeResolutionFlags::OverrideType; - auto contextualPattern = ContextualPattern::forRawPattern(pattern, DC); - pattern = TypeChecker::coercePatternToType(contextualPattern, - InitType, options); - if (!pattern) - return nullptr; - Stmt->setPattern(pattern); - - // Get the conformance of the sequence type to the Sequence protocol. - SequenceConformance = solution.resolveConformance( - ContextualLocator, SequenceProto); - assert(!SequenceConformance.isInvalid() && - "Couldn't find sequence conformance"); - Stmt->setSequenceConformance(SequenceConformance); - - // Check the filtering condition. - // FIXME: This should be pulled into the constraint system itself. - if (auto *Where = Stmt->getWhere()) { - if (!TypeChecker::typeCheckCondition(Where, DC)) - Stmt->setWhere(Where); - } - - // Invoke iterator() to get an iterator from the sequence. - VarDecl *iterator; - Type nextResultType = OptionalType::get(ElementType); - { - // Create a local variable to capture the iterator. - std::string name; - if (auto np = dyn_cast_or_null(Stmt->getPattern())) - name = "$"+np->getBoundName().str().str(); - name += "$generator"; - - iterator = new (ctx) VarDecl( - /*IsStatic*/ false, VarDecl::Introducer::Var, - /*IsCaptureList*/ false, Stmt->getInLoc(), - ctx.getIdentifier(name), DC); - iterator->setInterfaceType(IteratorType->mapTypeOutOfContext()); - iterator->setImplicit(); - Stmt->setIteratorVar(iterator); - - auto genPat = new (ctx) NamedPattern(iterator); - genPat->setImplicit(); - - // TODO: test/DebugInfo/iteration.swift requires this extra info to - // be around. - PatternBindingDecl::createImplicit( - ctx, StaticSpellingKind::None, genPat, - new (ctx) OpaqueValueExpr(Stmt->getInLoc(), nextResultType), - DC, /*VarLoc*/ Stmt->getForLoc()); - } - - // Create the iterator variable. - auto *varRef = TypeChecker::buildCheckedRefExpr( - iterator, DC, DeclNameLoc(Stmt->getInLoc()), /*implicit*/ true); - if (varRef) - Stmt->setIteratorVarRef(varRef); - - // Convert that Optional value to the type of the pattern. - auto optPatternType = OptionalType::get(Stmt->getPattern()->getType()); - if (!optPatternType->isEqual(nextResultType)) { - OpaqueValueExpr *elementExpr = - new (ctx) OpaqueValueExpr(Stmt->getInLoc(), nextResultType, - /*isPlaceholder=*/true); - Expr *convertElementExpr = elementExpr; - if (TypeChecker::typeCheckExpression( - convertElementExpr, DC, - TypeLoc::withoutLoc(optPatternType), - CTP_CoerceOperand).isNull()) { - return nullptr; - } - elementExpr->setIsPlaceholder(false); - Stmt->setElementExpr(elementExpr); - Stmt->setConvertElementExpr(convertElementExpr); - } - - return expr; - } - }; - - BindingListener listener(stmt, dc); - Expr *seq = stmt->getSequence(); - assert(seq && "type-checking an uninitialized for-each statement?"); - auto sequenceProto = TypeChecker::getProtocol( dc->getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence); if (!sequenceProto) return true; - // Type-check the for-each loop sequence and element pattern. - auto resultTy = TypeChecker::typeCheckExpression( - seq, dc, TypeLoc::withoutLoc(sequenceProto->getDeclaredType()), - CTP_ForEachStmt, None, &listener); - if (!resultTy) + // Precheck the sequence. + Expr *sequence = stmt->getSequence(); + if (ConstraintSystem::preCheckExpression(sequence, dc)) return true; - return false; + stmt->setSequence(sequence); + + // Precheck the filtering condition. + if (Expr *whereExpr = stmt->getWhere()) { + if (ConstraintSystem::preCheckExpression(whereExpr, dc)) + return true; + + stmt->setWhere(whereExpr); + } + + auto target = SolutionApplicationTarget::forForEachStmt( + stmt, sequenceProto, dc, /*bindPatternVarsOneWay=*/false); + bool unresolvedTypeExprs = false; + return !typeCheckExpression(target, unresolvedTypeExprs); } bool TypeChecker::typeCheckCondition(Expr *&expr, DeclContext *dc) {