Merge pull request #76743 from swiftlang/coro-pa-context

Fix partial apply forwarder emission for coroutines that are methods of structs with type parameters
This commit is contained in:
Dario Rexin
2024-11-21 03:05:25 -08:00
committed by GitHub
2 changed files with 59 additions and 52 deletions

View File

@@ -1120,11 +1120,11 @@ public:
virtual void addDynamicFunctionContext(Explosion &explosion) = 0; virtual void addDynamicFunctionContext(Explosion &explosion) = 0;
virtual void addDynamicFunctionPointer(Explosion &explosion) = 0; virtual void addDynamicFunctionPointer(Explosion &explosion) = 0;
virtual void addSelf(Explosion &explosion) { addArgument(explosion); } void addSelf(Explosion &explosion) { addArgument(explosion); }
virtual void addWitnessSelfMetadata(llvm::Value *value) { void addWitnessSelfMetadata(llvm::Value *value) {
addArgument(value); addArgument(value);
} }
virtual void addWitnessSelfWitnessTable(llvm::Value *value) { void addWitnessSelfWitnessTable(llvm::Value *value) {
addArgument(value); addArgument(value);
} }
virtual void forwardErrorResult() = 0; virtual void forwardErrorResult() = 0;
@@ -1438,12 +1438,6 @@ class CoroPartialApplicationForwarderEmission
: public PartialApplicationForwarderEmission { : public PartialApplicationForwarderEmission {
using super = PartialApplicationForwarderEmission; using super = PartialApplicationForwarderEmission;
private:
llvm::Value *Self;
llvm::Value *FirstData;
llvm::Value *SecondData;
WitnessMetadata Witness;
public: public:
CoroPartialApplicationForwarderEmission( CoroPartialApplicationForwarderEmission(
IRGenModule &IGM, IRGenFunction &subIGF, llvm::Function *fwd, IRGenModule &IGM, IRGenFunction &subIGF, llvm::Function *fwd,
@@ -1454,8 +1448,7 @@ public:
ArrayRef<ParameterConvention> conventions) ArrayRef<ParameterConvention> conventions)
: PartialApplicationForwarderEmission( : PartialApplicationForwarderEmission(
IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType, IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType,
substType, outType, subs, layout, conventions), substType, outType, subs, layout, conventions) {}
Self(nullptr), FirstData(nullptr), SecondData(nullptr) {}
void begin() override { void begin() override {
auto unsubstType = substType->getUnsubstitutedType(IGM.getSILModule()); auto unsubstType = substType->getUnsubstitutedType(IGM.getSILModule());
@@ -1499,41 +1492,13 @@ public:
void gatherArgumentsFromApply() override { void gatherArgumentsFromApply() override {
super::gatherArgumentsFromApply(false); super::gatherArgumentsFromApply(false);
} }
llvm::Value *getDynamicFunctionPointer() override { llvm::Value *getDynamicFunctionPointer() override { return args.takeLast(); }
llvm::Value *Ret = SecondData; llvm::Value *getDynamicFunctionContext() override { return args.takeLast(); }
SecondData = nullptr;
return Ret;
}
llvm::Value *getDynamicFunctionContext() override {
llvm::Value *Ret = FirstData;
FirstData = nullptr;
return Ret;
}
void addDynamicFunctionContext(Explosion &explosion) override { void addDynamicFunctionContext(Explosion &explosion) override {
assert(!Self && "context value overrides 'self'"); addArgument(explosion);
FirstData = explosion.claimNext();
} }
void addDynamicFunctionPointer(Explosion &explosion) override { void addDynamicFunctionPointer(Explosion &explosion) override {
SecondData = explosion.claimNext(); addArgument(explosion);
}
void addSelf(Explosion &explosion) override {
assert(!FirstData && "'self' overrides another context value");
if (!hasSelfContextParameter(origType)) {
// witness methods can be declared on types that are not classes. Pass
// such "self" argument as a plain argument.
addArgument(explosion);
return;
}
Self = explosion.claimNext();
FirstData = Self;
}
void addWitnessSelfMetadata(llvm::Value *value) override {
Witness.SelfMetadata = value;
}
void addWitnessSelfWitnessTable(llvm::Value *value) override {
Witness.SelfWitnessTable = value;
} }
void forwardErrorResult() override { void forwardErrorResult() override {
@@ -1554,13 +1519,26 @@ public:
} }
Explosion callCoroutine(FunctionPointer &fnPtr) { Explosion callCoroutine(FunctionPointer &fnPtr) {
Callee callee({origType, substType, subs}, fnPtr, FirstData, SecondData); bool isWitnessMethodCallee = origType->getRepresentation() ==
SILFunctionTypeRepresentation::WitnessMethod;
WitnessMetadata witnessMetadata;
if (isWitnessMethodCallee) {
witnessMetadata.SelfWitnessTable = args.takeLast();
witnessMetadata.SelfMetadata = args.takeLast();
}
llvm::Value *selfValue = nullptr;
if (calleeHasContext || hasSelfContextParameter(origType))
selfValue = args.takeLast();
Callee callee({origType, substType, subs}, fnPtr, selfValue);
std::unique_ptr<CallEmission> emitSuspend = std::unique_ptr<CallEmission> emitSuspend =
getCallEmission(subIGF, Self, std::move(callee)); getCallEmission(subIGF, callee.getSwiftContext(), std::move(callee));
emitSuspend->begin(); emitSuspend->begin();
emitSuspend->setArgs(args, /*isOutlined=*/false, &Witness); emitSuspend->setArgs(args, /*isOutlined=*/false, &witnessMetadata);
Explosion yieldedValues; Explosion yieldedValues;
emitSuspend->emitToExplosion(yieldedValues, /*isOutlined=*/false); emitSuspend->emitToExplosion(yieldedValues, /*isOutlined=*/false);
emitSuspend->end(); emitSuspend->end();
@@ -1966,12 +1944,7 @@ static llvm::Value *emitPartialApplicationForwarder(
} else { } else {
argValue = subIGF.Builder.CreateBitCast(rawData, expectedArgTy); argValue = subIGF.Builder.CreateBitCast(rawData, expectedArgTy);
} }
if (haveContextArgument) { emission->addArgument(argValue);
Explosion e;
e.add(argValue);
emission->addDynamicFunctionContext(e);
} else
emission->addArgument(argValue);
// If there's a data pointer required, grab it and load out the // If there's a data pointer required, grab it and load out the
// extra, previously-curried parameters. // extra, previously-curried parameters.

View File

@@ -42,5 +42,39 @@ ModifyAccessorTests.test("SimpleModifyAccessor") {
expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct)) expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct))
} }
ModifyAccessorTests.test("GenericModifyAccessor") {
struct S<T : Differentiable & SignedNumeric & Comparable>: Differentiable {
private var _x : T
func _endMutation() {}
var x: T {
get{_x}
set(newValue) { _x = newValue }
_modify {
defer { _endMutation() }
if (x > -x) {
yield &_x
} else {
yield &_x
}
}
}
init(_ x : T) {
self._x = x
}
}
func modify_struct(_ x : Float) -> Float {
var s = S<Float>(x)
s.x *= s.x
return s.x
}
expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct))
}
runAllTests() runAllTests()