mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
[AutoDiff] First cut of coroutines differentiation (#71461)
This PR implements first set of changes required to support autodiff for coroutines. It mostly targeted to `_modify` accessors in standard library (and beyond), but overall implementation is quite generic. There are some specifics of implementation and known limitations: - Only `@yield_once` coroutines are naturally supported - VJP is a coroutine itself: it yields the results *and* returns a pullback closure as a normal return. This allows us to capture values produced in resume part of a coroutine (this is required for defers and other cleanups / commits) - Pullback is a coroutine, we assume that coroutine cannot abort and therefore we execute the original coroutine in reverse from return via yield and then back to the entry - It seems there is no semantically sane way to support `_read` coroutines (as we will need to "accept" adjoints via yields), therefore only coroutines with inout yields are supported (`_modify` accessors). Pullbacks of such coroutines take adjoint buffer as input argument, yield this buffer (to accumulate adjoint values in the caller) and finally return the adjoints indirectly. - Coroutines (as opposed to normal functions) are not first-class values: there is no AST type for them, one cannot e.g. store them into tuples, etc. So, everywhere where AST type is required, we have to hack around. - As there is no AST type for coroutines, there is no way one could register custom derivative for coroutines. So far only compiler-produced derivatives are supported - There are lots of common things wrt normal function apply's, but still there are subtle but important differences. I tried to organize the code to enable code reuse, still it was not always possible, so some code duplication could be seen - The order of how pullback closures are produced in VJP is a bit different: for normal apply's VJP produces both value and pullback closure via a single nested VJP apply. This is not so anymore with coroutine VJP's: yielded values are produced at `begin_apply` site and pullback closure is available only from `end_apply`, so we need to track the order in which pullbacks are produced (and arrange consumption of the values accordingly – effectively delay them) - On the way some complementary changes were required in e.g. mangler / demangler This patch covers the generation of derivatives up to SIL level, however, it is not enough as codegen of `partial_apply` of a coroutine is completely broken. The fix for this will be submitted separately as it is not directly autodiff-related. --------- Co-authored-by: Andrew Savonichev <andrew.savonichev@gmail.com> Co-authored-by: Richard Wei <rxwei@apple.com>
This commit is contained in:
committed by
GitHub
parent
b3b2f37262
commit
c7a216058f
@@ -81,6 +81,8 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
|
||||
s << val << '\n';
|
||||
});
|
||||
// Outputs are indirect result buffers and return values, count `m`.
|
||||
// For the purposes of differentiation, we consider yields to be results as
|
||||
// well
|
||||
collectAllFormalResultsInTypeOrder(function, outputValues);
|
||||
LLVM_DEBUG({
|
||||
auto &s = getADDebugStream();
|
||||
@@ -312,14 +314,17 @@ void DifferentiableActivityInfo::setUsefulAndPropagateToOperands(
|
||||
for (auto incomingValue : incomingValues)
|
||||
setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex);
|
||||
return;
|
||||
} else if (bbArg->isTerminatorResult()) {
|
||||
}
|
||||
|
||||
if (bbArg->isTerminatorResult()) {
|
||||
if (TryApplyInst *tai = dyn_cast<TryApplyInst>(bbArg->getTerminatorForResult())) {
|
||||
propagateUseful(tai, dependentVariableIndex);
|
||||
return;
|
||||
} else
|
||||
llvm::report_fatal_error("unknown terminator with result");
|
||||
} else
|
||||
llvm::report_fatal_error("do not know how to handle this incoming bb argument");
|
||||
}
|
||||
llvm::report_fatal_error("unknown terminator with result");
|
||||
}
|
||||
|
||||
llvm::report_fatal_error("do not know how to handle this incoming bb argument");
|
||||
}
|
||||
|
||||
auto *inst = value->getDefiningInstruction();
|
||||
|
||||
Reference in New Issue
Block a user