mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Merge pull request #76743 from swiftlang/coro-pa-context
Fix partial apply forwarder emission for coroutines that are methods of structs with type parameters
This commit is contained in:
@@ -1120,11 +1120,11 @@ public:
|
|||||||
virtual void addDynamicFunctionContext(Explosion &explosion) = 0;
|
virtual void addDynamicFunctionContext(Explosion &explosion) = 0;
|
||||||
virtual void addDynamicFunctionPointer(Explosion &explosion) = 0;
|
virtual void addDynamicFunctionPointer(Explosion &explosion) = 0;
|
||||||
|
|
||||||
virtual void addSelf(Explosion &explosion) { addArgument(explosion); }
|
void addSelf(Explosion &explosion) { addArgument(explosion); }
|
||||||
virtual void addWitnessSelfMetadata(llvm::Value *value) {
|
void addWitnessSelfMetadata(llvm::Value *value) {
|
||||||
addArgument(value);
|
addArgument(value);
|
||||||
}
|
}
|
||||||
virtual void addWitnessSelfWitnessTable(llvm::Value *value) {
|
void addWitnessSelfWitnessTable(llvm::Value *value) {
|
||||||
addArgument(value);
|
addArgument(value);
|
||||||
}
|
}
|
||||||
virtual void forwardErrorResult() = 0;
|
virtual void forwardErrorResult() = 0;
|
||||||
@@ -1438,12 +1438,6 @@ class CoroPartialApplicationForwarderEmission
|
|||||||
: public PartialApplicationForwarderEmission {
|
: public PartialApplicationForwarderEmission {
|
||||||
using super = PartialApplicationForwarderEmission;
|
using super = PartialApplicationForwarderEmission;
|
||||||
|
|
||||||
private:
|
|
||||||
llvm::Value *Self;
|
|
||||||
llvm::Value *FirstData;
|
|
||||||
llvm::Value *SecondData;
|
|
||||||
WitnessMetadata Witness;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
CoroPartialApplicationForwarderEmission(
|
CoroPartialApplicationForwarderEmission(
|
||||||
IRGenModule &IGM, IRGenFunction &subIGF, llvm::Function *fwd,
|
IRGenModule &IGM, IRGenFunction &subIGF, llvm::Function *fwd,
|
||||||
@@ -1454,8 +1448,7 @@ public:
|
|||||||
ArrayRef<ParameterConvention> conventions)
|
ArrayRef<ParameterConvention> conventions)
|
||||||
: PartialApplicationForwarderEmission(
|
: PartialApplicationForwarderEmission(
|
||||||
IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType,
|
IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType,
|
||||||
substType, outType, subs, layout, conventions),
|
substType, outType, subs, layout, conventions) {}
|
||||||
Self(nullptr), FirstData(nullptr), SecondData(nullptr) {}
|
|
||||||
|
|
||||||
void begin() override {
|
void begin() override {
|
||||||
auto unsubstType = substType->getUnsubstitutedType(IGM.getSILModule());
|
auto unsubstType = substType->getUnsubstitutedType(IGM.getSILModule());
|
||||||
@@ -1499,41 +1492,13 @@ public:
|
|||||||
void gatherArgumentsFromApply() override {
|
void gatherArgumentsFromApply() override {
|
||||||
super::gatherArgumentsFromApply(false);
|
super::gatherArgumentsFromApply(false);
|
||||||
}
|
}
|
||||||
llvm::Value *getDynamicFunctionPointer() override {
|
llvm::Value *getDynamicFunctionPointer() override { return args.takeLast(); }
|
||||||
llvm::Value *Ret = SecondData;
|
llvm::Value *getDynamicFunctionContext() override { return args.takeLast(); }
|
||||||
SecondData = nullptr;
|
|
||||||
return Ret;
|
|
||||||
}
|
|
||||||
llvm::Value *getDynamicFunctionContext() override {
|
|
||||||
llvm::Value *Ret = FirstData;
|
|
||||||
FirstData = nullptr;
|
|
||||||
return Ret;
|
|
||||||
}
|
|
||||||
void addDynamicFunctionContext(Explosion &explosion) override {
|
void addDynamicFunctionContext(Explosion &explosion) override {
|
||||||
assert(!Self && "context value overrides 'self'");
|
addArgument(explosion);
|
||||||
FirstData = explosion.claimNext();
|
|
||||||
}
|
}
|
||||||
void addDynamicFunctionPointer(Explosion &explosion) override {
|
void addDynamicFunctionPointer(Explosion &explosion) override {
|
||||||
SecondData = explosion.claimNext();
|
addArgument(explosion);
|
||||||
}
|
|
||||||
void addSelf(Explosion &explosion) override {
|
|
||||||
assert(!FirstData && "'self' overrides another context value");
|
|
||||||
if (!hasSelfContextParameter(origType)) {
|
|
||||||
// witness methods can be declared on types that are not classes. Pass
|
|
||||||
// such "self" argument as a plain argument.
|
|
||||||
addArgument(explosion);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Self = explosion.claimNext();
|
|
||||||
FirstData = Self;
|
|
||||||
}
|
|
||||||
|
|
||||||
void addWitnessSelfMetadata(llvm::Value *value) override {
|
|
||||||
Witness.SelfMetadata = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
void addWitnessSelfWitnessTable(llvm::Value *value) override {
|
|
||||||
Witness.SelfWitnessTable = value;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void forwardErrorResult() override {
|
void forwardErrorResult() override {
|
||||||
@@ -1554,13 +1519,26 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
Explosion callCoroutine(FunctionPointer &fnPtr) {
|
Explosion callCoroutine(FunctionPointer &fnPtr) {
|
||||||
Callee callee({origType, substType, subs}, fnPtr, FirstData, SecondData);
|
bool isWitnessMethodCallee = origType->getRepresentation() ==
|
||||||
|
SILFunctionTypeRepresentation::WitnessMethod;
|
||||||
|
|
||||||
|
WitnessMetadata witnessMetadata;
|
||||||
|
if (isWitnessMethodCallee) {
|
||||||
|
witnessMetadata.SelfWitnessTable = args.takeLast();
|
||||||
|
witnessMetadata.SelfMetadata = args.takeLast();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::Value *selfValue = nullptr;
|
||||||
|
if (calleeHasContext || hasSelfContextParameter(origType))
|
||||||
|
selfValue = args.takeLast();
|
||||||
|
|
||||||
|
Callee callee({origType, substType, subs}, fnPtr, selfValue);
|
||||||
|
|
||||||
std::unique_ptr<CallEmission> emitSuspend =
|
std::unique_ptr<CallEmission> emitSuspend =
|
||||||
getCallEmission(subIGF, Self, std::move(callee));
|
getCallEmission(subIGF, callee.getSwiftContext(), std::move(callee));
|
||||||
|
|
||||||
emitSuspend->begin();
|
emitSuspend->begin();
|
||||||
emitSuspend->setArgs(args, /*isOutlined=*/false, &Witness);
|
emitSuspend->setArgs(args, /*isOutlined=*/false, &witnessMetadata);
|
||||||
Explosion yieldedValues;
|
Explosion yieldedValues;
|
||||||
emitSuspend->emitToExplosion(yieldedValues, /*isOutlined=*/false);
|
emitSuspend->emitToExplosion(yieldedValues, /*isOutlined=*/false);
|
||||||
emitSuspend->end();
|
emitSuspend->end();
|
||||||
@@ -1966,12 +1944,7 @@ static llvm::Value *emitPartialApplicationForwarder(
|
|||||||
} else {
|
} else {
|
||||||
argValue = subIGF.Builder.CreateBitCast(rawData, expectedArgTy);
|
argValue = subIGF.Builder.CreateBitCast(rawData, expectedArgTy);
|
||||||
}
|
}
|
||||||
if (haveContextArgument) {
|
emission->addArgument(argValue);
|
||||||
Explosion e;
|
|
||||||
e.add(argValue);
|
|
||||||
emission->addDynamicFunctionContext(e);
|
|
||||||
} else
|
|
||||||
emission->addArgument(argValue);
|
|
||||||
|
|
||||||
// If there's a data pointer required, grab it and load out the
|
// If there's a data pointer required, grab it and load out the
|
||||||
// extra, previously-curried parameters.
|
// extra, previously-curried parameters.
|
||||||
|
|||||||
@@ -42,5 +42,39 @@ ModifyAccessorTests.test("SimpleModifyAccessor") {
|
|||||||
expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct))
|
expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ModifyAccessorTests.test("GenericModifyAccessor") {
|
||||||
|
struct S<T : Differentiable & SignedNumeric & Comparable>: Differentiable {
|
||||||
|
private var _x : T
|
||||||
|
|
||||||
|
func _endMutation() {}
|
||||||
|
|
||||||
|
var x: T {
|
||||||
|
get{_x}
|
||||||
|
set(newValue) { _x = newValue }
|
||||||
|
_modify {
|
||||||
|
defer { _endMutation() }
|
||||||
|
if (x > -x) {
|
||||||
|
yield &_x
|
||||||
|
} else {
|
||||||
|
yield &_x
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
init(_ x : T) {
|
||||||
|
self._x = x
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func modify_struct(_ x : Float) -> Float {
|
||||||
|
var s = S<Float>(x)
|
||||||
|
s.x *= s.x
|
||||||
|
return s.x
|
||||||
|
}
|
||||||
|
|
||||||
|
expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
runAllTests()
|
runAllTests()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user