[Concurrency] Allow overload 'async' with non-async and disambiguate uses.

Allow an 'async' function to overload a non-'async' one, e.g.,

    func performOperation(_: String) throws -> String { ... }
    func performOperation(_: String) async throws -> String { ... }

Extend the scoring system in the type checker to penalize cases where
code in an asynchronous context (e.g., an `async` function or closure)
references an asychronous declaration or vice-versa, so that
asynchronous code prefers the 'async' functions and synchronous code
prefers the non-'async' functions. This allows the above overloading
to be a legitimate approach to introducing asynchronous functionality
to existing (blocking) APIs and letting code migrate over.
This commit is contained in:
Doug Gregor
2020-09-08 16:51:10 -07:00
parent 6eebf614ea
commit b5759c9fd9
9 changed files with 330 additions and 219 deletions

View File

@@ -2151,6 +2151,238 @@ std::pair<Type, bool> ConstraintSystem::adjustTypeOfOverloadReference(
llvm_unreachable("Unhandled OverloadChoiceKind in switch.");
}
/// Whether the declaration is considered 'async'.
static bool isDeclAsync(ValueDecl *value) {
if (auto func = dyn_cast<AbstractFunctionDecl>(value))
return func->isAsyncContext();
return false;
}
/// Walk a closure AST to determine its effects.
///
/// \returns a function's extended info describing the effects, as
/// determined syntactically.
FunctionType::ExtInfo ConstraintSystem::closureEffects(ClosureExpr *expr) {
auto known = closureEffectsCache.find(expr);
if (known != closureEffectsCache.end())
return known->second;
// A walker that looks for 'try' and 'throw' expressions
// that aren't nested within closures, nested declarations,
// or exhaustive catches.
class FindInnerThrows : public ASTWalker {
ConstraintSystem &CS;
DeclContext *DC;
bool FoundThrow = false;
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
// If we've found a 'try', record it and terminate the traversal.
if (isa<TryExpr>(expr)) {
FoundThrow = true;
return { false, nullptr };
}
// Don't walk into a 'try!' or 'try?'.
if (isa<ForceTryExpr>(expr) || isa<OptionalTryExpr>(expr)) {
return { false, expr };
}
// Do not recurse into other closures.
if (isa<ClosureExpr>(expr))
return { false, expr };
return { true, expr };
}
bool walkToDeclPre(Decl *decl) override {
// Do not walk into function or type declarations.
if (!isa<PatternBindingDecl>(decl))
return false;
return true;
}
bool isSyntacticallyExhaustive(DoCatchStmt *stmt) {
for (auto catchClause : stmt->getCatches()) {
for (auto &LabelItem : catchClause->getMutableCaseLabelItems()) {
if (isSyntacticallyExhaustive(catchClause->getStartLoc(),
LabelItem))
return true;
}
}
return false;
}
bool isSyntacticallyExhaustive(SourceLoc CatchLoc,
CaseLabelItem &LabelItem) {
// If it's obviously non-exhaustive, great.
if (LabelItem.getGuardExpr())
return false;
// If we can show that it's exhaustive without full
// type-checking, great.
if (LabelItem.isSyntacticallyExhaustive())
return true;
// Okay, resolve the pattern.
Pattern *pattern = LabelItem.getPattern();
if (!LabelItem.isPatternResolved()) {
pattern = TypeChecker::resolvePattern(pattern, CS.DC,
/*isStmtCondition*/false);
if (!pattern) return false;
// Save that aside while we explore the type.
LabelItem.setPattern(pattern, /*resolved=*/true);
}
// Require the pattern to have a particular shape: a number
// of is-patterns applied to an irrefutable pattern.
pattern = pattern->getSemanticsProvidingPattern();
while (auto isp = dyn_cast<IsPattern>(pattern)) {
const Type castType = TypeResolution::forContextual(
CS.DC, TypeResolverContext::InExpression,
/*unboundTyOpener*/ nullptr)
.resolveType(isp->getCastTypeRepr());
if (castType->hasError()) {
return false;
}
if (!isp->hasSubPattern()) {
pattern = nullptr;
break;
} else {
pattern = isp->getSubPattern()->getSemanticsProvidingPattern();
}
}
if (pattern && pattern->isRefutablePattern()) {
return false;
}
// Okay, now it should be safe to coerce the pattern.
// Pull the top-level pattern back out.
pattern = LabelItem.getPattern();
Type exnType = CS.getASTContext().getErrorDecl()->getDeclaredInterfaceType();
if (!exnType)
return false;
auto contextualPattern =
ContextualPattern::forRawPattern(pattern, DC);
pattern = TypeChecker::coercePatternToType(
contextualPattern, exnType, TypeResolverContext::InExpression);
if (!pattern)
return false;
LabelItem.setPattern(pattern, /*resolved=*/true);
return LabelItem.isSyntacticallyExhaustive();
}
std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override {
// If we've found a 'throw', record it and terminate the traversal.
if (isa<ThrowStmt>(stmt)) {
FoundThrow = true;
return { false, nullptr };
}
// Handle do/catch differently.
if (auto doCatch = dyn_cast<DoCatchStmt>(stmt)) {
// Only walk into the 'do' clause of a do/catch statement
// if the catch isn't syntactically exhaustive.
if (!isSyntacticallyExhaustive(doCatch)) {
if (!doCatch->getBody()->walk(*this))
return { false, nullptr };
}
// Walk into all the catch clauses.
for (auto catchClause : doCatch->getCatches()) {
if (!catchClause->walk(*this))
return { false, nullptr };
}
// We've already walked all the children we care about.
return { false, stmt };
}
return { true, stmt };
}
public:
FindInnerThrows(ConstraintSystem &cs, DeclContext *dc)
: CS(cs), DC(dc) {}
bool foundThrow() { return FoundThrow; }
};
// A walker that looks for 'async' and 'await' expressions
// that aren't nested within closures or nested declarations.
class FindInnerAsync : public ASTWalker {
bool FoundAsync = false;
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
// If we've found an 'await', record it and terminate the traversal.
if (isa<AwaitExpr>(expr)) {
FoundAsync = true;
return { false, nullptr };
}
// Do not recurse into other closures.
if (isa<ClosureExpr>(expr))
return { false, expr };
return { true, expr };
}
bool walkToDeclPre(Decl *decl) override {
// Do not walk into function or type declarations.
if (!isa<PatternBindingDecl>(decl))
return false;
return true;
}
public:
bool foundAsync() { return FoundAsync; }
};
// If either 'throws' or 'async' was explicitly specified, use that
// set of effects.
bool throws = expr->getThrowsLoc().isValid();
bool async = expr->getAsyncLoc().isValid();
if (throws || async) {
return ASTExtInfoBuilder()
.withThrows(throws)
.withAsync(async)
.build();
}
// Scan the body to determine the effects.
auto body = expr->getBody();
if (!body)
return FunctionType::ExtInfo();
auto throwFinder = FindInnerThrows(*this, expr);
body->walk(throwFinder);
auto asyncFinder = FindInnerAsync();
body->walk(asyncFinder);
auto result = ASTExtInfoBuilder()
.withThrows(throwFinder.foundThrow())
.withAsync(asyncFinder.foundAsync())
.build();
closureEffectsCache[expr] = result;
return result;
}
bool ConstraintSystem::isAsynchronousContext(DeclContext *dc) {
if (auto func = dyn_cast<AbstractFunctionDecl>(dc))
return isDeclAsync(func);
if (auto closure = dyn_cast<ClosureExpr>(dc))
return closureEffects(closure).isAsync();
return false;
}
void ConstraintSystem::bindOverloadType(
const SelectedOverload &overload, Type boundType,
ConstraintLocator *locator, DeclContext *useDC,
@@ -2475,6 +2707,11 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator,
assert(!refType->hasTypeParameter() && "Cannot have a dependent type here");
if (auto *decl = choice.getDeclOrNull()) {
// If we're choosing an asynchronous declaration within a synchronous
// context, or vice-versa, increase the async/async mismatch score.
if (isAsynchronousContext(useDC) != isDeclAsync(decl))
increaseScore(SK_AsyncSyncMismatch);
// If we're binding to an init member, the 'throws' need to line up
// between the bound and reference types.
if (auto CD = dyn_cast<ConstructorDecl>(decl)) {