diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 51ef0a00b9b..3c3d0dff4a6 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -251,7 +251,8 @@ void ConstraintSystem::applySolution(const Solution &solution) { // Add the node types back. for (auto &nodeType : solution.addedNodeTypes) { - setType(nodeType.first, nodeType.second); + if (!hasType(nodeType.first)) + setType(nodeType.first, nodeType.second); } // Register the conformances checked along the way to arrive to solution. diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h index f127bfc530f..05185b5fa4c 100644 --- a/lib/Sema/ConstraintSystem.h +++ b/lib/Sema/ConstraintSystem.h @@ -571,7 +571,9 @@ struct Score { }; /// An AST node that can gain type information while solving. -using TypedNode = llvm::PointerUnion3; +using TypedNode = + llvm::PointerUnion3; /// Display a score. llvm::raw_ostream &operator<<(llvm::raw_ostream &out, const Score &score); @@ -1756,12 +1758,12 @@ public: assert(type && "Expected non-null type"); // Record the type. - if (auto expr = node.dyn_cast()) { + if (auto expr = node.dyn_cast()) { ExprTypes[expr] = type.getPointer(); - } else if (auto typeLoc = node.dyn_cast()) { + } else if (auto typeLoc = node.dyn_cast()) { TypeLocTypes[typeLoc] = type.getPointer(); } else { - auto param = node.get(); + auto param = node.get(); ParamTypes[param] = type.getPointer(); } @@ -1775,26 +1777,18 @@ public: /// map is used throughout the expression type checker in order to /// avoid mutating expressions until we know we have successfully /// type-checked them. - void setType(Expr *E, Type T) { - setType(TypedNode(E), T); - } - void setType(TypeLoc &L, Type T) { setType(TypedNode(&L), T); } - void setType(ParamDecl *P, Type T) { - setType(TypedNode(P), T); - } - /// Erase the type for the given node. void eraseType(TypedNode node) { - if (auto expr = node.dyn_cast()) { + if (auto expr = node.dyn_cast()) { ExprTypes.erase(expr); - } else if (auto typeLoc = node.dyn_cast()) { + } else if (auto typeLoc = node.dyn_cast()) { TypeLocTypes.erase(typeLoc); } else { - auto param = node.get(); + auto param = node.get(); ParamTypes.erase(param); } } @@ -1812,12 +1806,20 @@ public: } bool hasType(const TypeLoc &L) const { - return TypeLocTypes.find(&L) != TypeLocTypes.end(); + return hasType(TypedNode(&L)); } - bool hasType(const ParamDecl *P) const { - assert(P != nullptr && "Expected non-null parameter!"); - return ParamTypes.find(P) != ParamTypes.end(); + /// Check to see if we have a type for a node. + bool hasType(TypedNode node) const { + assert(!node.isNull() && "Expected non-null node"); + if (auto expr = node.dyn_cast()) { + return ExprTypes.find(expr) != ExprTypes.end(); + } else if (auto typeLoc = node.dyn_cast()) { + return TypeLocTypes.find(typeLoc) != TypeLocTypes.end(); + } else { + auto param = node.get(); + return ParamTypes.find(param) != ParamTypes.end(); + } } bool hasType(const KeyPathExpr *KP, unsigned I) const { diff --git a/test/Constraints/function_builder_diags.swift b/test/Constraints/function_builder_diags.swift index 715dff972bf..a629f54b687 100644 --- a/test/Constraints/function_builder_diags.swift +++ b/test/Constraints/function_builder_diags.swift @@ -117,5 +117,12 @@ func testOverloading(name: String) { _ = overloadedTuplify(true) { b in b ? "Hello, \(name)" : "Goodbye" 42 + overloadedTuplify(false) { + $0 ? "Hello, \(name)" : "Goodbye" + 42 + if b { + "Hello, \(name)" + } + } } }