[CS] Allow ExprPatterns to be type-checked in the solver

Previously we would wait until CSApply, which
would trigger their type-checking in
`coercePatternToType`. This caused a number of
bugs, and hampered solver-based completion, which
does not run CSApply. Instead, form a conjunction
of all the ExprPatterns present, which preserves
some of the previous isolation behavior (though
does not provide complete isolation).

We can then modify `coercePatternToType` to accept
a closure, which allows the solver to take over
rewriting the ExprPatterns it has already solved.

This then sets the stage for the complete removal
of `coercePatternToType`, and doing all pattern
type-checking in the solver.
This commit is contained in:
Hamish Knight
2023-06-02 17:59:02 +01:00
parent 21e787bae8
commit 7a137d6756
21 changed files with 495 additions and 142 deletions

View File

@@ -1500,6 +1500,10 @@ public:
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
caseLabelItems;
/// A map of expressions to the ExprPatterns that they are being solved as
/// a part of.
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
/// The set of parameters that have been inferred to be 'isolated'.
llvm::SmallVector<ParamDecl *, 2> isolatedParams;
@@ -1685,6 +1689,16 @@ public:
: nullptr;
}
/// Retrieve the solved ExprPattern that corresponds to provided
/// sub-expression.
NullablePtr<ExprPattern> getExprPatternFor(Expr *E) const {
auto result = exprPatterns.find(E);
if (result == exprPatterns.end())
return nullptr;
return result->second;
}
/// This method implements functionality of `Expr::isTypeReference`
/// with data provided by a given solution.
bool isTypeReference(Expr *E) const;
@@ -2148,6 +2162,10 @@ private:
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
caseLabelItems;
/// A map of expressions to the ExprPatterns that they are being solved as
/// a part of.
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
/// The set of parameters that have been inferred to be 'isolated'.
llvm::SmallSetVector<ParamDecl *, 2> isolatedParams;
@@ -2745,6 +2763,9 @@ public:
/// The length of \c caseLabelItems.
unsigned numCaseLabelItems;
/// The length of \c exprPatterns.
unsigned numExprPatterns;
/// The length of \c isolatedParams.
unsigned numIsolatedParams;
@@ -3166,6 +3187,15 @@ public:
caseLabelItems[item] = info;
}
/// Record a given ExprPattern as the parent of its sub-expression.
void setExprPatternFor(Expr *E, ExprPattern *EP) {
assert(E);
assert(EP);
auto inserted = exprPatterns.insert({E, EP}).second;
assert(inserted && "Mapping already defined?");
(void)inserted;
}
Optional<CaseLabelItemInfo> getCaseLabelItemInfo(
const CaseLabelItem *item) const {
auto known = caseLabelItems.find(item);
@@ -4315,6 +4345,11 @@ public:
/// \returns \c true if constraint generation failed, \c false otherwise
bool generateConstraints(SingleValueStmtExpr *E);
/// Generate constraints for an array of ExprPatterns, forming a conjunction
/// that solves each expression in turn.
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
ConstraintLocatorBuilder locator);
/// Generate constraints for the given (unchecked) expression.
///
/// \returns a possibly-sanitized expression, or null if an error occurred.

View File

@@ -81,7 +81,13 @@ Type swift::ide::getTypeForCompletion(const constraints::Solution &S,
/// \endcode
/// If the code completion expression occurs in such an AST, return the
/// declaration of the \c $match variable, otherwise return \c nullptr.
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, const Solution &S) {
if (auto EP = S.getExprPatternFor(E))
return EP.get()->getMatchVar();
// TODO: Once ExprPattern type-checking is fully moved into the solver,
// the below can be deleted.
auto &CS = S.getConstraintSystem();
auto &Context = CS.getASTContext();
auto *Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(E));
@@ -109,20 +115,21 @@ static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
}
Type swift::ide::getPatternMatchType(const constraints::Solution &S, Expr *E) {
if (auto MatchVar = getMatchVarIfInPatternMatch(E, S.getConstraintSystem())) {
Type MatchVarType;
// If the MatchVar has an explicit type, it's not part of the solution. But
// we can look it up in the constraint system directly.
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
MatchVarType = T;
} else {
MatchVarType = getTypeForCompletion(S, MatchVar);
}
if (MatchVarType) {
return MatchVarType;
}
}
return nullptr;
auto MatchVar = getMatchVarIfInPatternMatch(E, S);
if (!MatchVar)
return nullptr;
if (S.hasType(MatchVar))
return S.getResolvedType(MatchVar);
// If the ExprPattern wasn't solved as part of the constraint system, it's
// not part of the solution.
// TODO: This can be removed once ExprPattern type-checking is fully part
// of the constraint system.
if (auto T = S.getConstraintSystem().getVarType(MatchVar))
return T;
return getTypeForCompletion(S, MatchVar);
}
void swift::ide::getSolutionSpecificVarTypes(

View File

@@ -8705,6 +8705,9 @@ namespace {
return Action::SkipChildren();
}
NullablePtr<Pattern>
rewritePattern(Pattern *pattern, DeclContext *DC);
/// Rewrite the target, producing a new target.
Optional<SyntacticElementTarget>
rewriteTarget(SyntacticElementTarget target);
@@ -8951,12 +8954,68 @@ static Expr *wrapAsyncLetInitializer(
return resultInit;
}
static Pattern *rewriteExprPattern(const SyntacticElementTarget &matchTarget,
Type patternTy,
RewriteTargetFn rewriteTarget) {
auto *EP = matchTarget.getExprPattern();
// See if we can simplify to another kind of pattern.
if (auto simplified = TypeChecker::trySimplifyExprPattern(EP, patternTy))
return simplified.get();
auto resultTarget = rewriteTarget(matchTarget);
if (!resultTarget)
return nullptr;
EP->setMatchExpr(resultTarget->getAsExpr());
EP->getMatchVar()->setInterfaceType(patternTy->mapTypeOutOfContext());
EP->setType(patternTy);
return EP;
}
/// Attempt to rewrite either an ExprPattern, or a pattern that was solved as
/// an ExprPattern, e.g an EnumElementPattern that could not refer to an enum
/// case.
static Optional<Pattern *>
tryRewriteExprPattern(Pattern *P, Solution &solution, Type patternTy,
RewriteTargetFn rewriteTarget) {
// See if we have a match expression target.
auto matchTarget = solution.getTargetFor(P);
if (!matchTarget)
return None;
return rewriteExprPattern(*matchTarget, patternTy, rewriteTarget);
}
NullablePtr<Pattern> ExprWalker::rewritePattern(Pattern *pattern,
DeclContext *DC) {
auto &solution = Rewriter.solution;
// Figure out the pattern type.
auto patternTy = solution.getResolvedType(pattern);
patternTy = patternTy->reconstituteSugar(/*recursive=*/false);
// Coerce the pattern to its appropriate type.
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
patternOptions |= TypeResolutionFlags::OverrideType;
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
return ::tryRewriteExprPattern(
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
};
auto contextualPattern = ContextualPattern::forRawPattern(pattern, DC);
return TypeChecker::coercePatternToType(contextualPattern, patternTy,
patternOptions, tryRewritePattern);
}
/// Apply the given solution to the initialization target.
///
/// \returns the resulting initialization expression.
static Optional<SyntacticElementTarget>
applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
Expr *initializer) {
Expr *initializer,
RewriteTargetFn rewriteTarget) {
auto wrappedVar = target.getInitializationWrappedVar();
Type initType;
if (wrappedVar) {
@@ -9021,10 +9080,14 @@ applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
};
// Apply the solution to the pattern as well.
auto contextualPattern = target.getContextualPattern();
if (auto coercedPattern = TypeChecker::coercePatternToType(
contextualPattern, finalPatternType, options)) {
contextualPattern, finalPatternType, options, tryRewritePattern)) {
resultTarget.setPattern(coercedPattern);
} else {
return None;
@@ -9171,10 +9234,15 @@ static Optional<SyntacticElementTarget> applySolutionToForEachStmt(
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
options |= TypeResolutionFlags::OverrideType;
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
};
// Apply the solution to the pattern as well.
auto contextualPattern = target.getContextualPattern();
auto coercedPattern = TypeChecker::coercePatternToType(
contextualPattern, forEachStmtInfo.initType, options);
contextualPattern, forEachStmtInfo.initType, options,
tryRewritePattern);
if (!coercedPattern)
return None;
@@ -9262,7 +9330,8 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
switch (target.getExprContextualTypePurpose()) {
case CTP_Initialization: {
auto initResultTarget = applySolutionToInitialization(
solution, target, rewrittenExpr);
solution, target, rewrittenExpr,
[&](auto target) { return rewriteTarget(target); });
if (!initResultTarget)
return None;
@@ -9353,47 +9422,11 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
ConstraintSystem &cs = solution.getConstraintSystem();
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
// Figure out the pattern type.
Type patternType = solution.simplifyType(solution.getType(info.pattern));
patternType = patternType->reconstituteSugar(/*recursive=*/false);
// Check whether this enum element is resolved via ~= application.
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
if (auto target = cs.getTargetFor(enumElement)) {
auto *EP = target->getExprPattern();
auto enumType = solution.getResolvedType(EP);
auto *matchCall = target->getAsExpr();
auto *result = matchCall->walk(*this);
if (!result)
return None;
{
auto *matchVar = EP->getMatchVar();
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
}
EP->setMatchExpr(result);
EP->setType(enumType);
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
return target;
}
}
// Coerce the pattern to its appropriate type.
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
patternOptions |= TypeResolutionFlags::OverrideType;
auto contextualPattern =
ContextualPattern::forRawPattern(info.pattern,
target.getDeclContext());
if (auto coercedPattern = TypeChecker::coercePatternToType(
contextualPattern, patternType, patternOptions)) {
(*caseLabelItem)->setPattern(coercedPattern, /*resolved=*/true);
} else {
auto pattern = rewritePattern(info.pattern, target.getDeclContext());
if (!pattern)
return None;
}
(*caseLabelItem)->setPattern(pattern.get(), /*resolved=*/true);
// If there is a guard expression, coerce that.
if (auto *guardExpr = info.guardExpr) {
@@ -9461,8 +9494,13 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
options |= TypeResolutionFlags::OverrideType;
}
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
return ::tryRewriteExprPattern(
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
};
if (auto coercedPattern = TypeChecker::coercePatternToType(
contextualPattern, patternType, options)) {
contextualPattern, patternType, options, tryRewritePattern)) {
auto resultTarget = target;
resultTarget.setPattern(coercedPattern);
return resultTarget;

View File

@@ -2453,12 +2453,6 @@ namespace {
// function, to set the type of the pattern.
auto setType = [&](Type type) {
CS.setType(pattern, type);
if (auto PE = dyn_cast<ExprPattern>(pattern)) {
// Set the type of the pattern's sub-expression as well, so code
// completion can retrieve the expression's type in case it is a code
// completion token.
CS.setType(PE->getSubExpr(), type);
}
return type;
};
@@ -2883,15 +2877,12 @@ namespace {
return setType(patternType);
}
// Refutable patterns occur when checking the PatternBindingDecls in an
// if/let or while/let condition. They always require an initial value,
// so they always allow unspecified types.
case PatternKind::Expr:
// TODO: we could try harder here, e.g. for enum elements to provide the
// enum type.
return setType(
CS.createTypeVariable(CS.getConstraintLocator(locator),
TVO_CanBindToNoEscape | TVO_CanBindToHole));
case PatternKind::Expr: {
// We generate constraints for ExprPatterns in a separate pass. For
// now, just create a type variable.
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
TVO_CanBindToNoEscape));
}
}
llvm_unreachable("Unhandled pattern kind");
@@ -4760,8 +4751,20 @@ Type ConstraintSystem::generateConstraints(
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
unsigned patternIndex) {
ConstraintGenerator cg(*this, nullptr);
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
patternBinding, patternIndex);
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
patternBinding, patternIndex);
assert(ty);
// Gather the ExprPatterns, and form a conjunction for their expressions.
SmallVector<ExprPattern *, 4> exprPatterns;
pattern->forEachNode([&](Pattern *P) {
if (auto *EP = dyn_cast<ExprPattern>(P))
exprPatterns.push_back(EP);
});
if (!exprPatterns.empty())
generateConstraints(exprPatterns, getConstraintLocator(pattern));
return ty;
}
bool ConstraintSystem::generateConstraints(StmtCondition condition,

View File

@@ -195,6 +195,7 @@ Solution ConstraintSystem::finalize() {
solution.targets = targets;
solution.caseLabelItems = caseLabelItems;
solution.exprPatterns = exprPatterns;
solution.isolatedParams.append(isolatedParams.begin(), isolatedParams.end());
solution.preconcurrencyClosures.append(preconcurrencyClosures.begin(),
preconcurrencyClosures.end());
@@ -327,6 +328,9 @@ void ConstraintSystem::applySolution(const Solution &solution) {
isolatedParams.insert(param);
}
for (auto &pair : solution.exprPatterns)
exprPatterns.insert(pair);
for (auto closure : solution.preconcurrencyClosures) {
preconcurrencyClosures.insert(closure);
}
@@ -621,6 +625,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
numContextualTypes = cs.contextualTypes.size();
numTargets = cs.targets.size();
numCaseLabelItems = cs.caseLabelItems.size();
numExprPatterns = cs.exprPatterns.size();
numIsolatedParams = cs.isolatedParams.size();
numPreconcurrencyClosures = cs.preconcurrencyClosures.size();
numImplicitValueConversions = cs.ImplicitValueConversions.size();
@@ -737,6 +742,9 @@ ConstraintSystem::SolverScope::~SolverScope() {
// Remove any case label item infos.
truncate(cs.caseLabelItems, numCaseLabelItems);
// Remove any ExprPattern mappings.
truncate(cs.exprPatterns, numExprPatterns);
// Remove any isolated parameters.
truncate(cs.isolatedParams, numIsolatedParams);

View File

@@ -410,7 +410,7 @@ ElementInfo makeJoinElement(ConstraintSystem &cs, TypeJoinExpr *join,
struct SyntacticElementContext
: public llvm::PointerUnion<AbstractFunctionDecl *, AbstractClosureExpr *,
SingleValueStmtExpr *> {
SingleValueStmtExpr *, ExprPattern *> {
// Inherit the constructors from PointerUnion.
using PointerUnion::PointerUnion;
@@ -441,6 +441,10 @@ struct SyntacticElementContext
return context;
}
static SyntacticElementContext forExprPattern(ExprPattern *EP) {
return SyntacticElementContext{EP};
}
DeclContext *getAsDeclContext() const {
if (auto *fn = this->dyn_cast<AbstractFunctionDecl *>()) {
return fn;
@@ -448,6 +452,8 @@ struct SyntacticElementContext
return closure;
} else if (auto *SVE = dyn_cast<SingleValueStmtExpr *>()) {
return SVE->getDeclContext();
} else if (auto *EP = dyn_cast<ExprPattern *>()) {
return EP->getDeclContext();
} else {
llvm_unreachable("unsupported kind");
}
@@ -519,7 +525,32 @@ public:
ConstraintLocator *locator)
: cs(cs), context(context), locator(locator) {}
void visitPattern(Pattern *pattern, ContextualTypeInfo context) {
void visitExprPattern(ExprPattern *EP) {
auto target = SyntacticElementTarget::forExprPattern(EP);
if (cs.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
/*leaveClosureBodyUnchecked=*/false)) {
hadError = true;
return;
}
cs.setType(EP->getMatchVar(), cs.getType(EP));
if (cs.generateConstraints(target)) {
hadError = true;
return;
}
cs.setTargetFor(EP, target);
cs.setExprPatternFor(EP->getSubExpr(), EP);
}
void visitPattern(Pattern *pattern, ContextualTypeInfo contextInfo) {
if (context.is<ExprPattern *>()) {
// This is for an ExprPattern conjunction, go ahead and generate
// constraints for the match expression.
visitExprPattern(cast<ExprPattern>(pattern));
return;
}
auto parentElement =
locator->getLastElementAs<LocatorPathElt::SyntacticElement>();
@@ -535,7 +566,7 @@ public:
}
if (isa<CaseStmt>(stmt)) {
visitCaseItemPattern(pattern, context);
visitCaseItemPattern(pattern, contextInfo);
return;
}
}
@@ -1438,6 +1469,24 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
return generator.hadError;
}
void ConstraintSystem::generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
ConstraintLocatorBuilder locator) {
// Form a conjunction of ExprPattern elements, isolated from the rest of the
// pattern.
SmallVector<ElementInfo> elements;
SmallVector<TypeVariableType *, 2> referencedTypeVars;
for (auto *EP : exprPatterns) {
auto ty = getType(EP)->castTo<TypeVariableType>();
referencedTypeVars.push_back(ty);
ContextualTypeInfo context(ty, CTP_ExprPattern);
elements.push_back(makeElement(EP, getConstraintLocator(EP), context));
}
auto *loc = getConstraintLocator(locator);
createConjunction(*this, elements, loc, /*isIsolated*/ true,
referencedTypeVars);
}
bool ConstraintSystem::isInResultBuilderContext(ClosureExpr *closure) const {
if (!closure->hasSingleExpressionBody()) {
auto *DC = closure->getParent();
@@ -1488,6 +1537,8 @@ ConstraintSystem::simplifySyntacticElementConstraint(
context = SyntacticElementContext::forFunction(fn);
} else if (auto *SVE = getAsExpr<SingleValueStmtExpr>(anchor)) {
context = SyntacticElementContext::forSingleValueStmtExpr(SVE);
} else if (auto *EP = getAsPattern<ExprPattern>(anchor)) {
context = SyntacticElementContext::forExprPattern(EP);
} else {
return SolutionKind::Error;
}

View File

@@ -1034,15 +1034,61 @@ void repairTupleOrAssociatedValuePatternIfApplicable(
enumCase->getName());
}
NullablePtr<Pattern> TypeChecker::trySimplifyExprPattern(ExprPattern *EP,
Type patternTy) {
auto *subExpr = EP->getSubExpr();
auto &ctx = EP->getDeclContext()->getASTContext();
if (patternTy->isBool()) {
// The type is Bool.
// Check if the pattern is a Bool literal
auto *semanticSubExpr = subExpr->getSemanticsProvidingExpr();
if (auto *BLE = dyn_cast<BooleanLiteralExpr>(semanticSubExpr)) {
auto *BP = new (ctx) BoolPattern(BLE->getLoc(), BLE->getValue());
BP->setType(patternTy);
return BP;
}
}
// case nil is equivalent to .none when switching on Optionals.
if (auto *NLE = dyn_cast<NilLiteralExpr>(EP->getSubExpr())) {
if (patternTy->getOptionalObjectType()) {
auto *NoneEnumElement = ctx.getOptionalNoneDecl();
auto *BaseTE = TypeExpr::createImplicit(patternTy, ctx);
auto *EEP = new (ctx)
EnumElementPattern(BaseTE, NLE->getLoc(), DeclNameLoc(NLE->getLoc()),
NoneEnumElement->createNameRef(), NoneEnumElement,
nullptr, EP->getDeclContext());
EEP->setType(patternTy);
return EEP;
} else {
// ...but for non-optional types it can never match! Diagnose it.
ctx.Diags
.diagnose(NLE->getLoc(), diag::value_type_comparison_with_nil_illegal,
patternTy)
.warnUntilSwiftVersion(6);
if (ctx.isSwiftVersionAtLeast(6))
return nullptr;
}
}
return nullptr;
}
/// Perform top-down type coercion on the given pattern.
Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
Type type,
TypeResolutionOptions options) {
Pattern *TypeChecker::coercePatternToType(
ContextualPattern pattern, Type type, TypeResolutionOptions options,
llvm::function_ref<Optional<Pattern *>(Pattern *, Type)>
tryRewritePattern) {
auto P = pattern.getPattern();
auto dc = pattern.getDeclContext();
auto &Context = dc->getASTContext();
auto &diags = Context.Diags;
// See if we can rewrite this using the constraint system.
if (auto result = tryRewritePattern(P, type))
return *result;
options = applyContextualPatternOptions(options, pattern);
auto subOptions = options;
subOptions.setContext(None);
@@ -1061,8 +1107,8 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
if (tupleType->getNumElements() == 1) {
auto element = tupleType->getElement(0);
sub = coercePatternToType(
pattern.forSubPattern(sub, /*retainTopLevel=*/true), element.getType(),
subOptions);
pattern.forSubPattern(sub, /*retainTopLevel=*/true),
element.getType(), subOptions, tryRewritePattern);
if (!sub)
return nullptr;
TuplePatternElt elt(element.getName(), SourceLoc(), sub);
@@ -1077,7 +1123,8 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
}
sub = coercePatternToType(
pattern.forSubPattern(sub, /*retainTopLevel=*/false), type, subOptions);
pattern.forSubPattern(sub, /*retainTopLevel=*/false), type, subOptions,
tryRewritePattern);
if (!sub)
return nullptr;
@@ -1090,7 +1137,8 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
Pattern *sub = VP->getSubPattern();
sub = coercePatternToType(
pattern.forSubPattern(sub, /*retainTopLevel=*/false), type, subOptions);
pattern.forSubPattern(sub, /*retainTopLevel=*/false), type, subOptions,
tryRewritePattern);
if (!sub)
return nullptr;
VP->setSubPattern(sub);
@@ -1123,7 +1171,8 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
Pattern *sub = TP->getSubPattern();
sub = coercePatternToType(
pattern.forSubPattern(sub, /*retainTopLevel=*/false), type,
subOptions | TypeResolutionFlags::FromNonInferredPattern);
subOptions | TypeResolutionFlags::FromNonInferredPattern,
tryRewritePattern);
if (!sub)
return nullptr;
@@ -1212,9 +1261,9 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
auto decayToParen = [&]() -> Pattern * {
assert(canDecayToParen);
Pattern *sub = TP->getElement(0).getPattern();
sub = TypeChecker::coercePatternToType(
sub = coercePatternToType(
pattern.forSubPattern(sub, /*retainTopLevel=*/false), type,
subOptions);
subOptions, tryRewritePattern);
if (!sub)
return nullptr;
@@ -1271,7 +1320,7 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
auto sub = coercePatternToType(
pattern.forSubPattern(elt.getPattern(), /*retainTopLevel=*/false),
CoercionType, subOptions);
CoercionType, subOptions, tryRewritePattern);
if (!sub)
return nullptr;
@@ -1291,37 +1340,9 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
assert(cast<ExprPattern>(P)->isResolved()
&& "coercing unresolved expr pattern!");
auto *EP = cast<ExprPattern>(P);
if (type->isBool()) {
// The type is Bool.
// Check if the pattern is a Bool literal
if (auto *BLE = dyn_cast<BooleanLiteralExpr>(
EP->getSubExpr()->getSemanticsProvidingExpr())) {
P = new (Context) BoolPattern(BLE->getLoc(), BLE->getValue());
P->setType(type);
return P;
}
}
// case nil is equivalent to .none when switching on Optionals.
if (auto *NLE = dyn_cast<NilLiteralExpr>(EP->getSubExpr())) {
if (type->getOptionalObjectType()) {
auto *NoneEnumElement = Context.getOptionalNoneDecl();
auto *BaseTE = TypeExpr::createImplicit(type, Context);
P = new (Context) EnumElementPattern(
BaseTE, NLE->getLoc(), DeclNameLoc(NLE->getLoc()),
NoneEnumElement->createNameRef(), NoneEnumElement, nullptr, dc);
return TypeChecker::coercePatternToType(
pattern.forSubPattern(P, /*retainTopLevel=*/true), type, options);
} else {
// ...but for non-optional types it can never match! Diagnose it.
diags.diagnose(NLE->getLoc(),
diag::value_type_comparison_with_nil_illegal, type)
.warnUntilSwiftVersion(6);
if (type->getASTContext().isSwiftVersionAtLeast(6))
return nullptr;
}
}
if (auto P = trySimplifyExprPattern(EP, type))
return P.get();
if (TypeChecker::typeCheckExprPattern(EP, dc, type))
return nullptr;
@@ -1370,7 +1391,8 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
P = sub;
return coercePatternToType(
pattern.forSubPattern(P, /*retainTopLevel=*/true), type, options);
pattern.forSubPattern(P, /*retainTopLevel=*/true), type, options,
tryRewritePattern);
}
CheckedCastKind castKind = TypeChecker::typeCheckCheckedCast(
@@ -1419,7 +1441,8 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
sub = coercePatternToType(
pattern.forSubPattern(sub, /*retainTopLevel=*/false),
IP->getCastType(),
subOptions | TypeResolutionFlags::FromNonInferredPattern);
subOptions | TypeResolutionFlags::FromNonInferredPattern,
tryRewritePattern);
if (!sub)
return nullptr;
@@ -1457,7 +1480,7 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
EEP->getEndLoc());
return coercePatternToType(
pattern.forSubPattern(P, /*retainTopLevel=*/true), type,
options);
options, tryRewritePattern);
} else {
diags.diagnose(EEP->getLoc(),
diag::enum_element_pattern_member_not_found,
@@ -1472,7 +1495,7 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
Context, EEP->getUnresolvedOriginalExpr(), dc);
return coercePatternToType(
pattern.forSubPattern(P, /*retainTopLevel=*/true), type,
options);
options, tryRewritePattern);
}
}
}
@@ -1595,7 +1618,7 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
sub = coercePatternToType(
pattern.forSubPattern(sub, /*retainTopLevel=*/false), elementType,
newSubOptions);
newSubOptions, tryRewritePattern);
if (!sub)
return nullptr;
@@ -1630,7 +1653,7 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
newSubOptions |= TypeResolutionFlags::FromNonInferredPattern;
sub = coercePatternToType(
pattern.forSubPattern(sub, /*retainTopLevel=*/false), elementType,
newSubOptions);
newSubOptions, tryRewritePattern);
if (!sub)
return nullptr;
EEP->setSubPattern(sub);
@@ -1674,7 +1697,7 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern,
newSubOptions |= TypeResolutionFlags::FromNonInferredPattern;
sub = coercePatternToType(
pattern.forSubPattern(sub, /*retainTopLevel=*/false), elementType,
newSubOptions);
newSubOptions, tryRewritePattern);
if (!sub)
return nullptr;

View File

@@ -733,15 +733,26 @@ Pattern *resolvePattern(Pattern *P, DeclContext *dc, bool isStmtCondition);
/// unbound generic types.
Type typeCheckPattern(ContextualPattern pattern);
/// Attempt to simplify an ExprPattern into a BoolPattern or
/// OptionalSomePattern. Returns \c nullptr if the pattern could not be
/// simplified.
NullablePtr<Pattern> trySimplifyExprPattern(ExprPattern *EP, Type patternTy);
/// Coerce a pattern to the given type.
///
/// \param pattern The contextual pattern.
/// \param type the type to coerce the pattern to.
/// \param options Options that control the coercion.
/// \param tryRewritePattern A function that attempts to externally rewrite
/// the given pattern. This is used by the constraint system to take over
/// rewriting for ExprPatterns.
///
/// \returns the coerced pattern, or nullptr if the coercion failed.
Pattern *coercePatternToType(ContextualPattern pattern, Type type,
TypeResolutionOptions options);
Pattern *coercePatternToType(
ContextualPattern pattern, Type type, TypeResolutionOptions options,
llvm::function_ref<Optional<Pattern *>(Pattern *, Type)> tryRewritePattern =
[](Pattern *, Type) { return None; });
bool typeCheckExprPattern(ExprPattern *EP, DeclContext *DC, Type type);
/// Coerce the specified parameter list of a ClosureExpr to the specified

View File

@@ -568,3 +568,86 @@ struct TestIUOMatchOp {
struct TestRecursiveVarRef<T> {
lazy var e: () -> Int = {e}()
}
func testMultiStmtClosureExprPattern(_ x: Int) {
if case { (); return x }() = x {}
}
func testExprPatternIsolation() {
// We type-check ExprPatterns separately, so these are illegal.
if case 0 = nil {} // expected-error {{'nil' requires a contextual type}}
let _ = {
if case 0 = nil {} // expected-error {{'nil' requires a contextual type}}
}
for case 0 in nil {} // expected-error {{'nil' requires a contextual type}}
for case 0 in [nil] {}
// expected-error@-1 {{type 'Any' cannot conform to 'Equatable'}}
// expected-note@-2 {{only concrete types such as structs, enums and classes can conform to protocols}}
// expected-note@-3 {{requirement from conditional conformance of 'Any?' to 'Equatable'}}
// Though we will try Double for an integer literal...
let d: Double = 0
if case d = 0 {}
let _ = {
if case d = 0 {}
}
for case d in [0] {}
// But not Float
let f: Float = 0
if case f = 0 {} // expected-error {{expression pattern of type 'Float' cannot match values of type 'Int'}}
let _ = {
if case f = 0 {} // expected-error {{expression pattern of type 'Float' cannot match values of type 'Int'}}
}
for case f in [0] {} // expected-error {{expression pattern of type 'Float' cannot match values of type 'Int'}}
enum MultiPayload<T: Equatable>: Equatable {
case e(T, T)
static func f(_ x: T, _ y: T) -> Self { .e(x, y) }
}
enum E: Equatable {
case a, b
static var c: E { .a }
static var d: E { .b }
}
func produceMultiPayload<T>() -> MultiPayload<T> { fatalError() }
// We type-check ExprPatterns left to right, so only one of these works.
if case .e(0.0, 0) = produceMultiPayload() {}
if case .e(0, 0.0) = produceMultiPayload() {} // expected-error {{expression pattern of type 'Double' cannot match values of type 'Int'}}
for case .e(0.0, 0) in [produceMultiPayload()] {}
for case .e(0, 0.0) in [produceMultiPayload()] {} // expected-error {{expression pattern of type 'Double' cannot match values of type 'Int'}}
// Same, because although it's a top-level ExprPattern, we don't resolve
// that until during solving.
if case .f(0.0, 0) = produceMultiPayload() {}
if case .f(0, 0.0) = produceMultiPayload() {} // expected-error {{expression pattern of type 'Double' cannot match values of type 'Int'}}
if case .e(5, nil) = produceMultiPayload() {} // expected-warning {{type 'Int' is not optional, value can never be nil; this is an error in Swift 6}}
// FIXME: Bad error (https://github.com/apple/swift/issues/64279)
if case .e(nil, 0) = produceMultiPayload() {}
// expected-error@-1 {{expression pattern of type 'String' cannot match values of type 'Substring'}}
// expected-note@-2 {{overloads for '~=' exist with these partially matching parameter lists}}
if case .e(5, nil) = produceMultiPayload() as MultiPayload<Int?> {}
if case .e(nil, 0) = produceMultiPayload() as MultiPayload<Int?> {}
// Enum patterns are solved together.
if case .e(E.a, .b) = produceMultiPayload() {}
if case .e(.a, E.b) = produceMultiPayload() {}
// These also work because they start life as EnumPatterns.
if case .e(E.c, .d) = produceMultiPayload() {}
if case .e(.c, E.d) = produceMultiPayload() {}
for case .e(E.c, .d) in [produceMultiPayload()] {}
for case .e(.c, E.d) in [produceMultiPayload()] {}
// Silly, but allowed.
if case 0: Int? = 0 {} // expected-warning {{non-optional expression of type 'Int' used in a check for optionals}}
var opt: Int?
if case opt = 0 {}
}

View File

@@ -12,7 +12,6 @@ func test(value: MyEnum) {
switch value {
case .first(true):
// expected-error@-1 {{expression pattern of type 'Bool' cannot match values of type 'String'}}
// expected-note@-2 {{overloads for '~=' exist with these partially matching parameter lists: (Substring, String)}}
break
default:
break

View File

@@ -0,0 +1,18 @@
// RUN: %target-typecheck-verify-swift
// rdar://105782480
enum MyEnum {
case second(Int?)
}
func takeClosure(_ x: () -> Void) {}
func foo(value: MyEnum) {
takeClosure {
switch value {
case .second(let drag).invalid:
// expected-error@-1 {{value of type 'MyEnum' has no member 'invalid'}}
break
}
}
}

View File

@@ -0,0 +1,11 @@
// RUN: %target-typecheck-verify-swift
enum E: Error { case e }
// rdar://106598067 Make sure we don't crash.
// FIXME: Bad diagnostic (the issue is that it should be written 'as', not 'as?')
let fn = {
// expected-error@-1 {{unable to infer closure type in the current context}}
do {} catch let x as? E {}
// expected-warning@-1 {{'catch' block is unreachable because no errors are thrown in 'do' block}}
}

View File

@@ -0,0 +1,15 @@
// RUN: %target-typecheck-verify-swift
// rdar://109419240 Make sure we don't crash
enum E { // expected-note {{'E' declared here}}
case e(Int)
}
func foo(_ arr: [E]) -> Int {
return arr.reduce(0) { (total, elem) -> Int in
switch elem {
case let e(x): // expected-error {{cannot find 'e' in scope; did you mean 'E'?}}
return total + x
}
}
}

View File

@@ -659,8 +659,6 @@ struct MyView {
}
@TupleBuilder var invalidCaseWithoutDot: some P {
// expected-error@-1 {{return type of property 'invalidCaseWithoutDot' requires that 'Either<Int, Int>' conform to 'P'}}
// expected-note@-2 {{opaque return type declared here}}
switch Optional.some(1) {
case none: 42 // expected-error {{cannot find 'none' in scope}}
case .some(let x):

View File

@@ -15,7 +15,6 @@ func foo(_ x: String) -> Int {
if .random() {
switch x {
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
0
default:
1
@@ -29,7 +28,6 @@ func bar(_ x: String) -> Int {
case 0:
switch x {
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
0
default:
1
@@ -44,7 +42,6 @@ func baz(_ x: String) -> Int {
do {
switch x {
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
0
default:
1
@@ -57,7 +54,6 @@ func qux(_ x: String) -> Int {
for _ in 0 ... 0 {
switch x {
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
0
default:
1

View File

@@ -710,3 +710,38 @@ let _: DispatchTime = .#^UNRESOLVED_FUNCTION_CALL^#now() + 0.2
// UNRESOLVED_FUNCTION_CALL: Begin completions, 2 items
// UNRESOLVED_FUNCTION_CALL-DAG: Decl[StaticMethod]/CurrNominal/Flair[ExprSpecific]/TypeRelation[Convertible]: now()[#DispatchTime#];
// UNRESOLVED_FUNCTION_CALL-DAG: Decl[Constructor]/CurrNominal/TypeRelation[Convertible]: init()[#DispatchTime#];
func id<T>(_ x: T) -> T { x }
func testNestedExprPatternCompletion(_ x: SomeEnum1) {
// Multi-statement closures have different type-checking code paths,
// so we need to test both.
let fn = {
switch x {
case id(.#^UNRESOLVED_NESTED1^#):
// UNRESOLVED_NESTED1: Begin completions, 3 items
// UNRESOLVED_NESTED1: Decl[EnumElement]/CurrNominal/Flair[ExprSpecific]/TypeRelation[Convertible]: South[#SomeEnum1#]; name=South
// UNRESOLVED_NESTED1: Decl[EnumElement]/CurrNominal/Flair[ExprSpecific]/TypeRelation[Convertible]: North[#SomeEnum1#]; name=North
// UNRESOLVED_NESTED1: Decl[InstanceMethod]/CurrNominal/TypeRelation[Invalid]: hash({#(self): SomeEnum1#})[#(into: inout Hasher) -> Void#]; name=hash(:)
break
}
if case id(.#^UNRESOLVED_NESTED2^#) = x {}
// UNRESOLVED_NESTED2: Begin completions, 3 items
// UNRESOLVED_NESTED2: Decl[EnumElement]/CurrNominal/Flair[ExprSpecific]/TypeRelation[Convertible]: South[#SomeEnum1#]; name=South
// UNRESOLVED_NESTED2: Decl[EnumElement]/CurrNominal/Flair[ExprSpecific]/TypeRelation[Convertible]: North[#SomeEnum1#]; name=North
// UNRESOLVED_NESTED2: Decl[InstanceMethod]/CurrNominal/TypeRelation[Invalid]: hash({#(self): SomeEnum1#})[#(into: inout Hasher) -> Void#]; name=hash(:)
}
switch x {
case id(.#^UNRESOLVED_NESTED3^#):
// UNRESOLVED_NESTED3: Begin completions, 3 items
// UNRESOLVED_NESTED3: Decl[EnumElement]/CurrNominal/Flair[ExprSpecific]/TypeRelation[Convertible]: South[#SomeEnum1#]; name=South
// UNRESOLVED_NESTED3: Decl[EnumElement]/CurrNominal/Flair[ExprSpecific]/TypeRelation[Convertible]: North[#SomeEnum1#]; name=North
// UNRESOLVED_NESTED3: Decl[InstanceMethod]/CurrNominal/TypeRelation[Invalid]: hash({#(self): SomeEnum1#})[#(into: inout Hasher) -> Void#]; name=hash(:)
break
}
if case id(.#^UNRESOLVED_NESTED4^#) = x {}
// UNRESOLVED_NESTED4: Begin completions, 3 items
// UNRESOLVED_NESTED4: Decl[EnumElement]/CurrNominal/Flair[ExprSpecific]/TypeRelation[Convertible]: South[#SomeEnum1#]; name=South
// UNRESOLVED_NESTED4: Decl[EnumElement]/CurrNominal/Flair[ExprSpecific]/TypeRelation[Convertible]: North[#SomeEnum1#]; name=North
// UNRESOLVED_NESTED4: Decl[InstanceMethod]/CurrNominal/TypeRelation[Invalid]: hash({#(self): SomeEnum1#})[#(into: inout Hasher) -> Void#]; name=hash(:)
}

View File

@@ -605,3 +605,14 @@ struct TestLValues {
opt![keyPath: kp] = switch Bool.random() { case true: 1 case false: throw Err() }
}
}
func exprPatternInClosure() {
let f: (Int) -> Void = { i in
switch i {
case i:
()
default:
()
}
}
}

View File

@@ -48,7 +48,6 @@ func testAmbiguousStringComparisons(s: String) {
// Shouldn't suggest 'as' in a pattern-matching context, as opposed to all these other situations
if case nsString = "" {} // expected-error{{expression pattern of type 'NSString' cannot match values of type 'String'}}
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists: (Substring, String)}}
}
func testStringDeprecation(hello: String) {

View File

@@ -169,7 +169,8 @@ func thirteen() {
thirteen_helper { (a) in // expected-error {{invalid conversion from throwing function of type '(Thirteen) throws -> ()' to non-throwing function type '(Thirteen) -> ()'}}
do {
try thrower()
} catch a {
// FIXME: Bad diagnostic (https://github.com/apple/swift/issues/63459)
} catch a { // expected-error {{binary operator '~=' cannot be applied to two 'any Error' operands}}
}
}
}

View File

@@ -10,7 +10,6 @@ struct ContentView: View {
var body: some View {
switch currentPage {
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists: (Substring, String)}}
Text("1")
default:
Text("default")

View File

@@ -9,12 +9,24 @@ let _: () -> Void = {
let _: () -> Void = {
for case (0)? in [a] {}
// expected-error@-1 {{pattern cannot match values of type 'Any?'}}
if case (0, 0) = a {}
// expected-error@-1 {{cannot convert value of type 'Any?' to specified type '(_, _)}}
}
let _: () -> Void = {
for case (0)? in [a] {}
// expected-error@-1 {{pattern cannot match values of type 'Any?'}}
for case (0, 0) in [a] {}
// expected-error@-1 {{cannot convert value of type 'Any?' to expected element type '(_, _)'}}
}
let _: () -> Void = {
if case (0, 0) = a {}
// expected-error@-1 {{cannot convert value of type 'Any?' to specified type '(Int, Int)'}}
for case (0)? in [a] {}
}
let _: () -> Void = {
for case (0, 0) in [a] {}
// expected-error@-1 {{cannot convert value of type 'Any?' to expected element type '(Int, Int)'}}
for case (0)? in [a] {}
}