mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
[AutoDiff] First cut of coroutines differentiation (#71461)
This PR implements first set of changes required to support autodiff for coroutines. It mostly targeted to `_modify` accessors in standard library (and beyond), but overall implementation is quite generic. There are some specifics of implementation and known limitations: - Only `@yield_once` coroutines are naturally supported - VJP is a coroutine itself: it yields the results *and* returns a pullback closure as a normal return. This allows us to capture values produced in resume part of a coroutine (this is required for defers and other cleanups / commits) - Pullback is a coroutine, we assume that coroutine cannot abort and therefore we execute the original coroutine in reverse from return via yield and then back to the entry - It seems there is no semantically sane way to support `_read` coroutines (as we will need to "accept" adjoints via yields), therefore only coroutines with inout yields are supported (`_modify` accessors). Pullbacks of such coroutines take adjoint buffer as input argument, yield this buffer (to accumulate adjoint values in the caller) and finally return the adjoints indirectly. - Coroutines (as opposed to normal functions) are not first-class values: there is no AST type for them, one cannot e.g. store them into tuples, etc. So, everywhere where AST type is required, we have to hack around. - As there is no AST type for coroutines, there is no way one could register custom derivative for coroutines. So far only compiler-produced derivatives are supported - There are lots of common things wrt normal function apply's, but still there are subtle but important differences. I tried to organize the code to enable code reuse, still it was not always possible, so some code duplication could be seen - The order of how pullback closures are produced in VJP is a bit different: for normal apply's VJP produces both value and pullback closure via a single nested VJP apply. This is not so anymore with coroutine VJP's: yielded values are produced at `begin_apply` site and pullback closure is available only from `end_apply`, so we need to track the order in which pullbacks are produced (and arrange consumption of the values accordingly – effectively delay them) - On the way some complementary changes were required in e.g. mangler / demangler This patch covers the generation of derivatives up to SIL level, however, it is not enough as codegen of `partial_apply` of a coroutine is completely broken. The fix for this will be submitted separately as it is not directly autodiff-related. --------- Co-authored-by: Andrew Savonichev <andrew.savonichev@gmail.com> Co-authored-by: Richard Wei <rxwei@apple.com>
This commit is contained in:
committed by
GitHub
parent
b3b2f37262
commit
c7a216058f
@@ -17,6 +17,7 @@
|
||||
#include "swift/Basic/STLExtras.h"
|
||||
#define DEBUG_TYPE "differentiation"
|
||||
|
||||
#include "swift/SIL/ApplySite.h"
|
||||
#include "swift/SILOptimizer/Differentiation/Common.h"
|
||||
#include "swift/AST/TypeCheckRequests.h"
|
||||
#include "swift/SILOptimizer/Differentiation/ADContext.h"
|
||||
@@ -145,6 +146,20 @@ void collectAllFormalResultsInTypeOrder(SILFunction &function,
|
||||
auto *argument = function.getArgumentsWithoutIndirectResults()[i];
|
||||
results.push_back(argument);
|
||||
}
|
||||
// Treat yields as semantic results. Note that we can only differentiate
|
||||
// @yield_once with simple control flow, so we can assume that the function
|
||||
// contains only a single `yield` instruction
|
||||
auto yieldIt =
|
||||
std::find_if(function.begin(), function.end(),
|
||||
[](const SILBasicBlock &BB) -> bool {
|
||||
const TermInst *TI = BB.getTerminator();
|
||||
return isa<YieldInst>(TI);
|
||||
});
|
||||
if (yieldIt != function.end()) {
|
||||
auto *yieldInst = cast<YieldInst>(yieldIt->getTerminator());
|
||||
for (auto yield : yieldInst->getOperandValues())
|
||||
results.push_back(yield);
|
||||
}
|
||||
}
|
||||
|
||||
void collectAllDirectResultsInTypeOrder(SILFunction &function,
|
||||
@@ -161,30 +176,30 @@ void collectAllDirectResultsInTypeOrder(SILFunction &function,
|
||||
}
|
||||
|
||||
void collectAllActualResultsInTypeOrder(
|
||||
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
|
||||
FullApplySite fai, ArrayRef<SILValue> extractedDirectResults,
|
||||
SmallVectorImpl<SILValue> &results) {
|
||||
auto calleeConvs = ai->getSubstCalleeConv();
|
||||
auto calleeConvs = fai.getSubstCalleeConv();
|
||||
unsigned indResIdx = 0, dirResIdx = 0;
|
||||
for (auto &resInfo : calleeConvs.getResults()) {
|
||||
results.push_back(resInfo.isFormalDirect()
|
||||
? extractedDirectResults[dirResIdx++]
|
||||
: ai->getIndirectSILResults()[indResIdx++]);
|
||||
: fai.getIndirectSILResults()[indResIdx++]);
|
||||
}
|
||||
}
|
||||
|
||||
void collectMinimalIndicesForFunctionCall(
|
||||
ApplyInst *ai, const AutoDiffConfig &parentConfig,
|
||||
FullApplySite ai, const AutoDiffConfig &parentConfig,
|
||||
const DifferentiableActivityInfo &activityInfo,
|
||||
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> ¶mIndices,
|
||||
SmallVectorImpl<unsigned> &resultIndices) {
|
||||
auto calleeFnTy = ai->getSubstCalleeType();
|
||||
auto calleeConvs = ai->getSubstCalleeConv();
|
||||
auto calleeFnTy = ai.getSubstCalleeType();
|
||||
auto calleeConvs = ai.getSubstCalleeConv();
|
||||
|
||||
// Parameter indices are indices (in the callee type signature) of parameter
|
||||
// arguments that are varied or are arguments.
|
||||
// Record all parameter indices in type order.
|
||||
unsigned currentParamIdx = 0;
|
||||
for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) {
|
||||
for (auto applyArg : ai.getArgumentsWithoutIndirectResults()) {
|
||||
if (activityInfo.isActive(applyArg, parentConfig))
|
||||
paramIndices.push_back(currentParamIdx);
|
||||
++currentParamIdx;
|
||||
@@ -196,7 +211,7 @@ void collectMinimalIndicesForFunctionCall(
|
||||
forEachApplyDirectResult(ai, [&](SILValue directResult) {
|
||||
directResults.push_back(directResult);
|
||||
});
|
||||
auto indirectResults = ai->getIndirectSILResults();
|
||||
auto indirectResults = ai.getIndirectSILResults();
|
||||
// Record all results and result indices in type order.
|
||||
results.reserve(calleeFnTy->getNumResults());
|
||||
unsigned dirResIdx = 0;
|
||||
@@ -225,10 +240,20 @@ void collectMinimalIndicesForFunctionCall(
|
||||
if (!param.isAutoDiffSemanticResult())
|
||||
continue;
|
||||
unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults();
|
||||
results.push_back(ai->getArgument(idx));
|
||||
results.push_back(ai.getArgument(idx));
|
||||
resultIndices.push_back(semanticResultParamResultIndex++);
|
||||
}
|
||||
|
||||
// Record all yields. While we do not have a way to represent direct yields
|
||||
// (_read accessors) we run activity analysis for them. These will be
|
||||
// diagnosed later.
|
||||
if (BeginApplyInst *bai = dyn_cast<BeginApplyInst>(*ai)) {
|
||||
for (const auto &yieldAndIdx : enumerate(calleeConvs.getYields())) {
|
||||
results.push_back(bai->getYieldedValues()[yieldAndIdx.index()]);
|
||||
resultIndices.push_back(semanticResultParamResultIndex++);
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure the function call has active results.
|
||||
#ifndef NDEBUG
|
||||
assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults());
|
||||
|
||||
Reference in New Issue
Block a user