[Sema] Diagnose generic parameter contextual inference ambiguity between function call result and closure argument

This commit is contained in:
Luciano Almeida
2021-06-12 23:19:33 -03:00
parent 528764c5fc
commit c02f30f5c1
2 changed files with 143 additions and 2 deletions

View File

@@ -3709,6 +3709,141 @@ static bool diagnoseAmbiguity(
return diagnosed;
}
using Fix = std::pair<const Solution *, const ConstraintFix *>;
// Attempts to diagnose function call ambiguities of types inferred for a result
// generic parameter from contextual type and a closure argument that
// conflicting infer a different type for the same argument. Example:
// func callit<T>(_ f: () -> T) -> T {
// f()
// }
//
// func context() -> Int {
// callit {
// print("hello")
// }
// }
// Where generic argument `T` can be inferred both as `Int` from contextual
// result and `Void` from the closure argument result.
static bool
diagnoseContextualFunctionCallGenericAmbiguity(ConstraintSystem &cs,
ArrayRef<Fix> contextualFixes,
ArrayRef<Fix> allFixes) {
if (contextualFixes.empty())
return false;
auto contextualFix = contextualFixes.front();
if (!std::all_of(contextualFixes.begin() + 1, contextualFixes.end(),
[&contextualFix](Fix fix) {
return fix.second->getLocator() ==
contextualFix.second->getLocator();
}))
return false;
auto fixLocator = contextualFix.second->getLocator();
auto contextualAnchor = fixLocator->getAnchor();
auto *AE = getAsExpr<ApplyExpr>(contextualAnchor);
// All contextual failures anchored on the same function call.
if (!AE)
return false;
auto fnLocator = cs.getConstraintLocator(AE->getSemanticFn());
auto overload = contextualFix.first->getOverloadChoiceIfAvailable(fnLocator);
if (!overload)
return false;
auto applyFnType = overload->openedType->castTo<FunctionType>();
auto resultTypeVar = applyFnType->getResult()->getAs<TypeVariableType>();
if (!resultTypeVar)
return false;
auto *GP = resultTypeVar->getImpl().getGenericParameter();
if (!GP)
return false;
auto typeParamResultInvolvesTypeVar =
[&applyFnType](unsigned paramIdx, TypeVariableType *typeVar) {
auto param = applyFnType->getParams()[paramIdx];
auto paramType = param.getParameterType()->castTo<FunctionType>();
bool contains = false;
paramType->getResult().visit([&](Type ty) {
if (ty->isEqual(typeVar))
contains = true;
});
return contains;
};
llvm::SmallVector<ClosureExpr *, 4> closureArguments;
// A single closure argument.
if (auto *closure =
getAsExpr<ClosureExpr>(AE->getArg()->getSemanticsProvidingExpr())) {
if (typeParamResultInvolvesTypeVar(/*paramIdx=*/0, resultTypeVar))
closureArguments.push_back(closure);
} else if (auto *argTuple = getAsExpr<TupleExpr>(AE->getArg())) {
for (auto i : indices(argTuple->getElements())) {
auto arg = argTuple->getElements()[i];
auto *closure = getAsExpr<ClosureExpr>(arg);
if (closure &&
typeParamResultInvolvesTypeVar(/*paramIdx=*/i, resultTypeVar)) {
closureArguments.push_back(closure);
}
}
}
// If no closure result's involves the generic parameter, just bail because we
// won't find a conflict.
if (closureArguments.empty())
return false;
// At least one closure where result type involves the generic parameter.
// So let's try to collect the set of fixed types for the generic parameter
// from all the closure contextual fix/solutions and if there are more than
// one fixed type diagnose it.
llvm::SmallSetVector<Type, 4> genericParamInferredTypes;
for (auto &fix : contextualFixes)
genericParamInferredTypes.insert(fix.first->getFixedType(resultTypeVar));
if (llvm::all_of(allFixes, [&](Fix fix) {
auto fixLocator = fix.second->getLocator();
if (fixLocator->isForContextualType())
return true;
if (!(fix.second->getKind() == FixKind::ContextualMismatch ||
fix.second->getKind() == FixKind::AllowTupleTypeMismatch))
return false;
auto anchor = fixLocator->getAnchor();
if (!(anchor == contextualAnchor ||
fixLocator->isLastElement<LocatorPathElt::ClosureResult>() ||
fixLocator->isLastElement<LocatorPathElt::ClosureBody>()))
return false;
genericParamInferredTypes.insert(
fix.first->getFixedType(resultTypeVar));
return true;
})) {
if (genericParamInferredTypes.size() <= 1)
return false;
auto &DE = cs.getASTContext().Diags;
llvm::SmallString<64> arguments;
llvm::raw_svector_ostream OS(arguments);
interleave(
genericParamInferredTypes,
[&](Type argType) { OS << "'" << argType << "'"; },
[&OS] { OS << " vs. "; });
DE.diagnose(AE->getLoc(),
diag::conflicting_inferred_generic_parameter_result_and_closure,
GP, OS.str());
return true;
}
return false;
}
bool ConstraintSystem::diagnoseAmbiguityWithFixes(
SmallVectorImpl<Solution> &solutions) {
if (solutions.empty())
@@ -3761,8 +3896,6 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
// d. Diagnose remaining (uniqued based on kind + locator) fixes
// iff they appear in all of the solutions.
using Fix = std::pair<const Solution *, const ConstraintFix *>;
llvm::SmallSetVector<Fix, 4> fixes;
for (auto &solution : solutions) {
for (auto *fix : solution.Fixes)
@@ -3837,6 +3970,10 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
}
}
if (!diagnosed && diagnoseContextualFunctionCallGenericAmbiguity(
*this, contextualFixes, fixes.getArrayRef()))
return true;
return diagnosed;
}