Properly substitute coroutines

This commit is contained in:
Anton Korobeynikov
2024-08-14 23:45:58 -07:00
parent 6c504bb9fa
commit 366b552286
8 changed files with 109 additions and 43 deletions

View File

@@ -3969,6 +3969,9 @@ public:
/// Return the function type setting sendable to \p newValue.
AnyFunctionType *withSendable(bool newValue) const;
/// Return the function type without yields (and coroutine flag)
AnyFunctionType *getWithoutYields() const;
/// True if the parameter declaration it is attached to is guaranteed
/// to not persist the closure for longer than the duration of the call.
bool isNoEscape() const {

View File

@@ -1524,13 +1524,9 @@ public:
/// type, return the abstraction pattern for one of its argument types.
AbstractionPattern getParameterizedProtocolArgType(unsigned i) const;
/// Given that the value being abstracted is a yield result type,
/// return the abstraction pattern for corresponding type.
AbstractionPattern getYieldResultType() const;
/// Given that the value being abstracted is a function, return the
/// abstraction pattern for its result type.
AbstractionPattern getFunctionResultType() const;
AbstractionPattern getFunctionResultType(bool withoutYields = false) const;
/// Given that the value being abstracted is a function, return the
/// abstraction pattern for its thrown error type.

View File

@@ -11538,7 +11538,7 @@ Type FuncDecl::getResultInterfaceTypeWithoutYields() const {
Type eltTy = elt.getType();
if (eltTy->is<YieldResultType>())
continue;
elements.push_back(eltTy);
elements.push_back(elt);
}
// Handle vanishing tuples -- flatten to produce the

View File

@@ -4752,6 +4752,39 @@ AnyFunctionType *AnyFunctionType::withSendable(bool newValue) const {
return withExtInfo(info);
}
AnyFunctionType *AnyFunctionType::getWithoutYields() const {
auto resultType = getResult();
if (auto *tupleResTy = resultType->getAs<TupleType>()) {
// Strip @yield results on the first level of tuple
SmallVector<TupleTypeElt, 4> elements;
for (const auto &elt : tupleResTy->getElements()) {
Type eltTy = elt.getType();
if (eltTy->is<YieldResultType>())
continue;
elements.push_back(elt);
}
// Handle vanishing tuples -- flatten to produce the
// normal coroutine result type
if (elements.size() == 1 && isCoroutine())
resultType = elements[0].getType();
else
resultType = TupleType::get(elements, getASTContext());
} else if (resultType->is<YieldResultType>()) {
resultType = TupleType::getEmpty(getASTContext());
}
auto noCoroExtInfo = getExtInfo().intoBuilder()
.withCoroutine(false)
.build();
if (isa<FunctionType>(this))
return FunctionType::get(getParams(), resultType, noCoroExtInfo);
assert(isa<GenericFunctionType>(this));
return GenericFunctionType::get(getOptGenericSignature(), getParams(),
resultType, noCoroExtInfo);
}
std::optional<Type> AnyFunctionType::getEffectiveThrownErrorType() const {
// A non-throwing function... has no thrown interface type.
if (!isThrowing())

View File

@@ -1040,13 +1040,6 @@ AbstractionPattern::getParameterizedProtocolArgType(unsigned argIndex) const {
cast<ParameterizedProtocolType>(getType()).getArgs()[argIndex]);
}
AbstractionPattern AbstractionPattern::getYieldResultType() const {
assert(getKind() == Kind::Type);
return AbstractionPattern(getGenericSubstitutions(),
getGenericSignature(),
cast<YieldResultType>(getType()).getResultType());
}
AbstractionPattern AbstractionPattern::removingMoveOnlyWrapper() const {
switch (getKind()) {
case Kind::Invalid:
@@ -1180,11 +1173,15 @@ AbstractionPattern::getCXXMethodSelfPattern(CanType selfType) const {
getGenericSignatureForFunctionComponent(), selfType);
}
static CanType getResultType(CanType type) {
return cast<AnyFunctionType>(type).getResult();
static CanType getResultType(CanType type, bool withoutYields) {
auto aft = cast<AnyFunctionType>(type);
if (withoutYields)
aft = CanAnyFunctionType(aft->getWithoutYields());
return aft.getResult();
}
AbstractionPattern AbstractionPattern::getFunctionResultType() const {
AbstractionPattern AbstractionPattern::getFunctionResultType(bool withoutYields) const {
switch (getKind()) {
case Kind::Invalid:
llvm_unreachable("querying invalid abstraction pattern!");
@@ -1198,7 +1195,7 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
return AbstractionPattern::getOpaque();
return AbstractionPattern(getGenericSubstitutions(),
getGenericSignatureForFunctionComponent(),
getResultType(getType()));
getResultType(getType(), withoutYields));
case Kind::Discard:
llvm_unreachable("don't need to discard function abstractions yet");
case Kind::ClangType:
@@ -1207,33 +1204,34 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
auto clangFunctionType = getClangFunctionType(getClangType());
return AbstractionPattern(getGenericSubstitutions(),
getGenericSignatureForFunctionComponent(),
getResultType(getType()),
getResultType(getType(), withoutYields),
clangFunctionType->getReturnType().getTypePtr());
}
case Kind::CXXMethodType:
case Kind::PartialCurriedCXXMethodType:
return AbstractionPattern(getGenericSubstitutions(),
getGenericSignatureForFunctionComponent(),
getResultType(getType()),
getResultType(getType(), withoutYields),
getCXXMethod()->getReturnType().getTypePtr());
case Kind::CurriedObjCMethodType:
return getPartialCurriedObjCMethod(
getGenericSubstitutions(),
getGenericSignatureForFunctionComponent(),
getResultType(getType()),
getResultType(getType(), withoutYields),
getObjCMethod(),
getEncodedForeignInfo());
case Kind::CurriedCFunctionAsMethodType:
return getPartialCurriedCFunctionAsMethod(
getGenericSubstitutions(),
getGenericSignatureForFunctionComponent(),
getResultType(getType()),
getResultType(getType(), withoutYields),
getClangType(),
getImportAsMemberStatus());
case Kind::CurriedCXXMethodType:
return getPartialCurriedCXXMethod(getGenericSubstitutions(),
getGenericSignatureForFunctionComponent(),
getResultType(getType()), getCXXMethod(),
getResultType(getType(), withoutYields),
getCXXMethod(),
getImportAsMemberStatus());
case Kind::PartialCurriedObjCMethodType:
case Kind::ObjCMethodType: {
@@ -1290,7 +1288,8 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
return AbstractionPattern(getGenericSubstitutions(),
getGenericSignatureForFunctionComponent(),
getResultType(getType()), clangResultType);
getResultType(getType(), withoutYields),
clangResultType);
}
default:
@@ -1300,14 +1299,15 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
return AbstractionPattern::getObjCCompletionHandlerArgumentsType(
getGenericSubstitutions(),
getGenericSignatureForFunctionComponent(),
getResultType(getType()), callbackParamTy,
getResultType(getType(), withoutYields),
callbackParamTy,
getEncodedForeignInfo());
}
}
return AbstractionPattern(getGenericSubstitutions(),
getGenericSignatureForFunctionComponent(),
getResultType(getType()),
getResultType(getType(), withoutYields),
getObjCMethod()->getReturnType().getTypePtr());
}
case Kind::OpaqueFunction:
@@ -2753,13 +2753,6 @@ public:
llvm_unreachable("shouldn't encounter pack element by itself");
}
CanType visitYieldResultType(CanYieldResultType yield,
AbstractionPattern pattern) {
auto resultType = visit(yield.getResultType(), pattern.getYieldResultType());
return YieldResultType::get(resultType, yield->isInOut())
->getCanonicalType();
}
CanType handlePackExpansion(AbstractionPattern origExpansion,
CanType candidateSubstType) {
// When we're within a pack expansion, pack references matching that
@@ -2936,10 +2929,9 @@ public:
addParam(param.getOrigFlags(), expansionType);
}
});
if (yieldType) {
if (yieldType)
substYieldType = visit(yieldType, yieldPattern);
}
CanType newErrorType;
@@ -2949,8 +2941,8 @@ public:
newErrorType = visit(errorType, errorPattern);
}
auto newResultTy = visit(func.getResult(),
pattern.getFunctionResultType());
auto newResultTy = visit(func->getWithoutYields()->getResult()->getCanonicalType(),
pattern.getFunctionResultType(/* withoutYields */ true));
std::optional<FunctionType::ExtInfo> extInfo;
if (func->hasExtInfo())
@@ -2962,6 +2954,10 @@ public:
extInfo = extInfo->withThrows(true, newErrorType);
}
// Yields were substituted separately
if (extInfo)
extInfo = extInfo->withCoroutine(false);
return CanFunctionType::get(FunctionType::CanParamArrayRef(newParams),
newResultTy, extInfo);
}

View File

@@ -114,9 +114,9 @@ extension ConcreteWithInt : ProtoWithAssoc {
}
// CHECK-LABEL: sil_vtable ConcreteWithInt {
// CHECK: #Generic.generic!modify: <T> (Generic<T>) -> () -> () : @$s3mod15ConcreteWithIntC7genericSivMAA7GenericCADxvMTV [override]
// CHECK: #Generic.genericFunction!modify: <T> (Generic<T>) -> () -> () : @$s3mod15ConcreteWithIntC15genericFunctionSiycvMAA7GenericCADxycvMTV [override]
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> (U) -> () : @$s3mod15ConcreteWithIntC16returningGenericSix_tcluiMAA0F0CADxqd___tcluiMTV [override]
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> (U) -> () : @$s3mod15ConcreteWithIntC19returningOwnGenericxx_tcluiM [override]
// CHECK: #Generic.complexTuple!modify: <T> (Generic<T>) -> () -> () : @$s3mod15ConcreteWithIntC12complexTupleSiSg_SDySSSiGtvMAA7GenericCADxSg_SDySSxGtvMTV [override]
// CHECK: #Generic.generic!modify: <T> (Generic<T>) -> @yield_once () -> inout @yields T : @$s3mod15ConcreteWithIntC7genericSivMAA7GenericCADxvMTV [override]
// CHECK: #Generic.genericFunction!modify: <T> (Generic<T>) -> @yield_once () -> inout @yields () -> T : @$s3mod15ConcreteWithIntC15genericFunctionSiycvMAA7GenericCADxycvMTV [override]
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> @yield_once (U) -> inout @yields T : @$s3mod15ConcreteWithIntC16returningGenericSix_tcluiMAA0F0CADxqd___tcluiMTV [override]
// CHECK: #Generic.subscript!modify: <T><U> (Generic<T>) -> @yield_once (U) -> inout @yields U : @$s3mod15ConcreteWithIntC19returningOwnGenericxx_tcluiM [override]
// CHECK: #Generic.complexTuple!modify: <T> (Generic<T>) -> @yield_once () -> inout @yields (T?, [String : T]) : @$s3mod15ConcreteWithIntC12complexTupleSiSg_SDySSSiGtvMAA7GenericCADxSg_SDySSxGtvMTV [override]
// CHECK: }

View File

@@ -2357,4 +2357,4 @@
],
"json_format_version": 8
}
}
}

View File

@@ -130,6 +130,44 @@ Func withTaskGroup(of:returning:body:) has mangled name changing from '_Concurre
Func pthread_main_np() is a new API without '@available'
Accessor AsyncCompactMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncDropFirstSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncDropFirstSequence.Iterator.count.Modify() has return type change from () to inout @yields Swift.Int
Accessor AsyncDropWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncDropWhileSequence.Iterator.predicate.Modify() has return type change from () to inout @yields ((τ_0_0.Element) async -> Swift.Bool)?
Accessor AsyncFilterSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncFlatMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncFlatMapSequence.Iterator.currentIterator.Modify() has return type change from () to inout @yields τ_0_1.AsyncIterator?
Accessor AsyncFlatMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
Accessor AsyncMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncPrefixSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncPrefixSequence.Iterator.remaining.Modify() has return type change from () to inout @yields Swift.Int
Accessor AsyncPrefixWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncPrefixWhileSequence.Iterator.predicateHasFailed.Modify() has return type change from () to inout @yields Swift.Bool
Accessor AsyncStream.Continuation.onTermination.Modify() has return type change from () to inout @yields ((_Concurrency.AsyncStream<τ_0_0>.Continuation.Termination) -> ())?
Accessor AsyncThrowingCompactMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncThrowingCompactMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
Accessor AsyncThrowingDropWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncThrowingDropWhileSequence.Iterator.doneDropping.Modify() has return type change from () to inout @yields Swift.Bool
Accessor AsyncThrowingDropWhileSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
Accessor AsyncThrowingFilterSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncThrowingFilterSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
Accessor AsyncThrowingFlatMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncThrowingFlatMapSequence.Iterator.currentIterator.Modify() has return type change from () to inout @yields τ_0_1.AsyncIterator?
Accessor AsyncThrowingFlatMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
Accessor AsyncThrowingMapSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncThrowingMapSequence.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
Accessor AsyncThrowingPrefixWhileSequence.Iterator.baseIterator.Modify() has return type change from () to inout @yields τ_0_0.AsyncIterator
Accessor AsyncThrowingPrefixWhileSequence.Iterator.predicateHasFailed.Modify() has return type change from () to inout @yields Swift.Bool
Accessor AsyncThrowingStream.Continuation.onTermination.Modify() has return type change from () to inout @yields ((_Concurrency.AsyncThrowingStream<τ_0_0, τ_0_1>.Continuation.Termination) -> ())?
Accessor TaskGroup.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
Accessor TaskGroup.Iterator.group.Modify() has return type change from () to inout @yields _Concurrency.TaskGroup<τ_0_0>
Accessor TaskPriority.rawValue.Modify() has return type change from () to inout @yields Swift.UInt8
Accessor ThrowingTaskGroup.Iterator.finished.Modify() has return type change from () to inout @yields Swift.Bool
Accessor ThrowingTaskGroup.Iterator.group.Modify() has return type change from () to inout @yields _Concurrency.ThrowingTaskGroup<τ_0_0, τ_0_1>
Accessor UnownedSerialExecutor.executor.Modify() has return type change from () to inout @yields Builtin.Executor
Accessor UnsafeContinuation.context.Modify() has return type change from () to inout @yields Builtin.RawUnsafeContinuation
// *** DO NOT DISABLE OR XFAIL THIS TEST. *** (See comment above.)