[AutoDiff] First cut of coroutines differentiation (#71461)

This PR implements the first set of changes required to support autodiff for coroutines. It mostly targets `_modify` accessors in the standard library (and beyond), but the overall implementation is quite generic.

There are some implementation specifics and known limitations:
 - Only `@yield_once` coroutines are naturally supported
 - The VJP is itself a coroutine: it yields the results *and* returns a pullback closure as a normal return value. This allows us to capture values produced in the resume part of a coroutine (this is required for defers and other cleanups / commits).
 - The pullback is a coroutine as well. We assume that a coroutine cannot abort, and therefore we execute the original coroutine in reverse: from the return, via the yield, and then back to the entry.
 - It seems there is no semantically sane way to support `_read` coroutines (as we would need to "accept" adjoints via yields), therefore only coroutines with inout yields (`_modify` accessors) are supported. Pullbacks of such coroutines take the adjoint buffer as an input argument, yield this buffer (to accumulate adjoint values in the caller), and finally return the adjoints indirectly.
 - Coroutines (as opposed to normal functions) are not first-class values: there is no AST type for them, one cannot e.g. store them in tuples, etc. So wherever an AST type is required, we have to hack around it.
 - As there is no AST type for coroutines, there is no way to register a custom derivative for a coroutine. So far only compiler-produced derivatives are supported.
 - There is a lot in common with normal function `apply`s, but there are subtle yet important differences. I tried to organize the code to enable reuse, but this was not always possible, so some code duplication remains.
 - The order in which pullback closures are produced in the VJP is a bit different: for a normal `apply`, the VJP produces both the value and the pullback closure via a single nested VJP apply. This is no longer the case with coroutine VJPs: yielded values are produced at the `begin_apply` site and the pullback closure is available only from `end_apply`, so we need to track the order in which pullbacks are produced (and arrange consumption of the values accordingly, effectively delaying them).
 - Along the way, some complementary changes were required, e.g. in the mangler / demangler.
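
The `begin_apply` / `end_apply` ordering and the shape of the pullback described above can be pictured with a small C++ toy model. This is only a sketch under stated assumptions, not the compiler's implementation (which works on SIL): `ModifyVJP`, `ModifyPullback`, `squareElementVJP` and `runExample` are hypothetical names, and a yield-once coroutine is modeled as a pair of `begin()` / `end()` calls on a plain struct. It shows the VJP yielding an inout element first and handing out its pullback only at the end, and the pullback yielding a slot of the adjoint buffer so that the caller can update it in place.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical model of the pullback of a `_modify`-style access. `begin`
// stands in for the yield (it hands the caller a slot of the adjoint buffer);
// `end` stands in for the part that runs after the yield.
struct ModifyPullback {
  std::size_t index;
  double &begin(std::vector<double> &adjoint) const { return adjoint[index]; }
  void end() const {} // nothing left to do in this toy accessor
};

// Hypothetical model of the VJP of a `_modify`-style access.
struct ModifyVJP {
  std::vector<double> &storage;
  std::size_t index;
  double &begin() { return storage[index]; }             // models begin_apply
  ModifyPullback end() { return ModifyPullback{index}; } // models end_apply
};

using ArrayPullback = std::function<void(std::vector<double> &)>;

// Toy VJP of the caller `a[i] = a[i] * a[i]`, written via the access above.
// The returned pullback rewrites the adjoint of the result into the adjoint
// of the input, in place.
ArrayPullback squareElementVJP(std::vector<double> &a, std::size_t i) {
  ModifyVJP access{a, i};
  double &elem = access.begin();   // the yielded value is available here...
  const double original = elem;
  elem = original * original;      // body executed between begin and end
  ModifyPullback accessPullback = access.end(); // ...the pullback only here
  return [accessPullback, original](std::vector<double> &adjoint) {
    double &adjElem = accessPullback.begin(adjoint); // pullback yields buffer
    adjElem = 2.0 * original * adjElem;              // d(x * x) = 2x dx
    accessPullback.end();
  };
}

// Small driver: differentiate the squaring of element 1 of {1, 3, 5}.
std::vector<double> runExample() {
  std::vector<double> a{1.0, 3.0, 5.0};
  ArrayPullback pullback = squareElementVJP(a, 1);
  assert(a[1] == 9.0); // the forward computation happened in place
  std::vector<double> adjoint{1.0, 1.0, 1.0};
  pullback(adjoint);
  return adjoint;
}
```

In this model `runExample()` leaves the adjoint as `{1, 6, 1}`: only the slot touched by the access is scaled, by `2 * 3`. The caller can use the yielded element before the access's pullback exists, which is why the consumption of pullback values has to be delayed.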

This patch covers the generation of derivatives up to the SIL level. However, this is not enough, as codegen of `partial_apply` of a coroutine is completely broken. The fix for this will be submitted separately, as it is not directly autodiff-related.

---------

Co-authored-by: Andrew Savonichev <andrew.savonichev@gmail.com>
Co-authored-by: Richard Wei <rxwei@apple.com>
Anton Korobeynikov
2024-04-04 17:24:55 -07:00
committed by GitHub
parent b3b2f37262
commit c7a216058f
30 changed files with 1045 additions and 303 deletions


@@ -17,6 +17,7 @@
 #include "swift/Basic/STLExtras.h"
 #define DEBUG_TYPE "differentiation"

+#include "swift/SIL/ApplySite.h"
 #include "swift/SILOptimizer/Differentiation/Common.h"
 #include "swift/AST/TypeCheckRequests.h"
 #include "swift/SILOptimizer/Differentiation/ADContext.h"
@@ -145,6 +146,20 @@ void collectAllFormalResultsInTypeOrder(SILFunction &function,
     auto *argument = function.getArgumentsWithoutIndirectResults()[i];
     results.push_back(argument);
   }
+  // Treat yields as semantic results. Note that we can only differentiate
+  // @yield_once with simple control flow, so we can assume that the function
+  // contains only a single `yield` instruction
+  auto yieldIt =
+    std::find_if(function.begin(), function.end(),
+                 [](const SILBasicBlock &BB) -> bool {
+                   const TermInst *TI = BB.getTerminator();
+                   return isa<YieldInst>(TI);
+                 });
+  if (yieldIt != function.end()) {
+    auto *yieldInst = cast<YieldInst>(yieldIt->getTerminator());
+    for (auto yield : yieldInst->getOperandValues())
+      results.push_back(yield);
+  }
 }

 void collectAllDirectResultsInTypeOrder(SILFunction &function,
@@ -161,30 +176,30 @@ void collectAllDirectResultsInTypeOrder(SILFunction &function,
 }

 void collectAllActualResultsInTypeOrder(
-    ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
+    FullApplySite fai, ArrayRef<SILValue> extractedDirectResults,
     SmallVectorImpl<SILValue> &results) {
-  auto calleeConvs = ai->getSubstCalleeConv();
+  auto calleeConvs = fai.getSubstCalleeConv();
   unsigned indResIdx = 0, dirResIdx = 0;
   for (auto &resInfo : calleeConvs.getResults()) {
     results.push_back(resInfo.isFormalDirect()
                           ? extractedDirectResults[dirResIdx++]
-                          : ai->getIndirectSILResults()[indResIdx++]);
+                          : fai.getIndirectSILResults()[indResIdx++]);
   }
 }

 void collectMinimalIndicesForFunctionCall(
-    ApplyInst *ai, const AutoDiffConfig &parentConfig,
+    FullApplySite ai, const AutoDiffConfig &parentConfig,
     const DifferentiableActivityInfo &activityInfo,
     SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
     SmallVectorImpl<unsigned> &resultIndices) {
-  auto calleeFnTy = ai->getSubstCalleeType();
-  auto calleeConvs = ai->getSubstCalleeConv();
+  auto calleeFnTy = ai.getSubstCalleeType();
+  auto calleeConvs = ai.getSubstCalleeConv();
   // Parameter indices are indices (in the callee type signature) of parameter
   // arguments that are varied or are arguments.
   // Record all parameter indices in type order.
   unsigned currentParamIdx = 0;
-  for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) {
+  for (auto applyArg : ai.getArgumentsWithoutIndirectResults()) {
     if (activityInfo.isActive(applyArg, parentConfig))
       paramIndices.push_back(currentParamIdx);
     ++currentParamIdx;
@@ -196,7 +211,7 @@ void collectMinimalIndicesForFunctionCall(
   forEachApplyDirectResult(ai, [&](SILValue directResult) {
     directResults.push_back(directResult);
   });
-  auto indirectResults = ai->getIndirectSILResults();
+  auto indirectResults = ai.getIndirectSILResults();
   // Record all results and result indices in type order.
   results.reserve(calleeFnTy->getNumResults());
   unsigned dirResIdx = 0;
@@ -225,10 +240,20 @@ void collectMinimalIndicesForFunctionCall(
     if (!param.isAutoDiffSemanticResult())
       continue;
     unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults();
-    results.push_back(ai->getArgument(idx));
+    results.push_back(ai.getArgument(idx));
     resultIndices.push_back(semanticResultParamResultIndex++);
   }
+
+  // Record all yields. While we do not have a way to represent direct yields
+  // (_read accessors) we run activity analysis for them. These will be
+  // diagnosed later.
+  if (BeginApplyInst *bai = dyn_cast<BeginApplyInst>(*ai)) {
+    for (const auto &yieldAndIdx : enumerate(calleeConvs.getYields())) {
+      results.push_back(bai->getYieldedValues()[yieldAndIdx.index()]);
+      resultIndices.push_back(semanticResultParamResultIndex++);
+    }
+  }

   // Make sure the function call has active results.
 #ifndef NDEBUG
   assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults());