mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
[AutoDiff] First cut of coroutines differentiation (#71461)
This PR implements first set of changes required to support autodiff for coroutines. It mostly targeted to `_modify` accessors in standard library (and beyond), but overall implementation is quite generic. There are some specifics of implementation and known limitations: - Only `@yield_once` coroutines are naturally supported - VJP is a coroutine itself: it yields the results *and* returns a pullback closure as a normal return. This allows us to capture values produced in resume part of a coroutine (this is required for defers and other cleanups / commits) - Pullback is a coroutine, we assume that coroutine cannot abort and therefore we execute the original coroutine in reverse from return via yield and then back to the entry - It seems there is no semantically sane way to support `_read` coroutines (as we will need to "accept" adjoints via yields), therefore only coroutines with inout yields are supported (`_modify` accessors). Pullbacks of such coroutines take adjoint buffer as input argument, yield this buffer (to accumulate adjoint values in the caller) and finally return the adjoints indirectly. - Coroutines (as opposed to normal functions) are not first-class values: there is no AST type for them, one cannot e.g. store them into tuples, etc. So, everywhere where AST type is required, we have to hack around. - As there is no AST type for coroutines, there is no way one could register custom derivative for coroutines. So far only compiler-produced derivatives are supported - There are lots of common things wrt normal function apply's, but still there are subtle but important differences. I tried to organize the code to enable code reuse, still it was not always possible, so some code duplication could be seen - The order of how pullback closures are produced in VJP is a bit different: for normal apply's VJP produces both value and pullback closure via a single nested VJP apply. This is not so anymore with coroutine VJP's: yielded values are produced at `begin_apply` site and pullback closure is available only from `end_apply`, so we need to track the order in which pullbacks are produced (and arrange consumption of the values accordingly – effectively delay them) - On the way some complementary changes were required in e.g. mangler / demangler This patch covers the generation of derivatives up to SIL level, however, it is not enough as codegen of `partial_apply` of a coroutine is completely broken. The fix for this will be submitted separately as it is not directly autodiff-related. --------- Co-authored-by: Andrew Savonichev <andrew.savonichev@gmail.com> Co-authored-by: Richard Wei <rxwei@apple.com>
This commit is contained in:
committed by
GitHub
parent
b3b2f37262
commit
c7a216058f
@@ -17,6 +17,7 @@
|
||||
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H
|
||||
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H
|
||||
|
||||
#include "swift/SIL/ApplySite.h"
|
||||
#include "swift/SILOptimizer/Differentiation/Common.h"
|
||||
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
|
||||
|
||||
@@ -51,6 +52,12 @@ struct NestedApplyInfo {
|
||||
/// The original pullback type before reabstraction. `None` if the pullback
|
||||
/// type is not reabstracted.
|
||||
std::optional<CanSILFunctionType> originalPullbackType;
|
||||
/// Index of `apply` pullback in nested pullback call
|
||||
unsigned pullbackIdx = -1U;
|
||||
/// Pullback value itself that is memoized in some cases (e.g. pullback is
|
||||
/// called by `begin_apply`, but should be destroyed after `end_apply`).
|
||||
SILValue pullback = SILValue();
|
||||
SILValue beginApplyToken = SILValue();
|
||||
};
|
||||
|
||||
/// Per-module contextual information for the Differentiation pass.
|
||||
@@ -97,7 +104,7 @@ private:
|
||||
|
||||
/// Mapping from original `apply` instructions to their corresponding
|
||||
/// `NestedApplyInfo`s.
|
||||
llvm::DenseMap<ApplyInst *, NestedApplyInfo> nestedApplyInfo;
|
||||
llvm::DenseMap<FullApplySite, NestedApplyInfo> nestedApplyInfo;
|
||||
|
||||
/// List of generated functions (JVPs, VJPs, pullbacks, and thunks).
|
||||
/// Saved for deletion during cleanup.
|
||||
@@ -185,7 +192,7 @@ public:
|
||||
invokers.insert({witness, DifferentiationInvoker(witness)});
|
||||
}
|
||||
|
||||
llvm::DenseMap<ApplyInst *, NestedApplyInfo> &getNestedApplyInfo() {
|
||||
llvm::DenseMap<FullApplySite, NestedApplyInfo> &getNestedApplyInfo() {
|
||||
return nestedApplyInfo;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "swift/AST/DiagnosticsSIL.h"
|
||||
#include "swift/AST/Expr.h"
|
||||
#include "swift/AST/SemanticAttrs.h"
|
||||
#include "swift/SIL/ApplySite.h"
|
||||
#include "swift/SIL/SILDifferentiabilityWitness.h"
|
||||
#include "swift/SIL/SILFunction.h"
|
||||
#include "swift/SIL/Projection.h"
|
||||
@@ -112,7 +113,7 @@ void collectAllDirectResultsInTypeOrder(SILFunction &function,
|
||||
/// Given a function call site, gathers all of its actual results (both direct
|
||||
/// and indirect) in an order defined by its result type.
|
||||
void collectAllActualResultsInTypeOrder(
|
||||
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
|
||||
FullApplySite fai, ArrayRef<SILValue> extractedDirectResults,
|
||||
SmallVectorImpl<SILValue> &results);
|
||||
|
||||
/// For an `apply` instruction with active results, compute:
|
||||
@@ -120,7 +121,7 @@ void collectAllActualResultsInTypeOrder(
|
||||
/// - The set of minimal parameter and result indices for differentiating the
|
||||
/// `apply` instruction.
|
||||
void collectMinimalIndicesForFunctionCall(
|
||||
ApplyInst *ai, const AutoDiffConfig &parentConfig,
|
||||
FullApplySite fai, const AutoDiffConfig &parentConfig,
|
||||
const DifferentiableActivityInfo &activityInfo,
|
||||
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> ¶mIndices,
|
||||
SmallVectorImpl<unsigned> &resultIndices);
|
||||
|
||||
@@ -77,9 +77,9 @@ private:
|
||||
/// For differentials: these are successor enums.
|
||||
llvm::DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;
|
||||
|
||||
/// Mapping from `apply` instructions in the original function to the
|
||||
/// Mapping from `apply` / `begin_apply` instructions in the original function to the
|
||||
/// corresponding linear map tuple type index.
|
||||
llvm::DenseMap<ApplyInst *, unsigned> linearMapIndexMap;
|
||||
llvm::DenseMap<FullApplySite, unsigned> linearMapIndexMap;
|
||||
|
||||
/// Mapping from predecessor-successor basic block pairs in the original
|
||||
/// function to the corresponding branching trace enum case.
|
||||
@@ -112,9 +112,9 @@ private:
|
||||
void populateBranchingTraceDecl(SILBasicBlock *originalBB,
|
||||
SILLoopInfo *loopInfo);
|
||||
|
||||
/// Given an `apply` instruction, conditionally gets a linear map tuple field
|
||||
/// AST type for its linear map function if it is active.
|
||||
Type getLinearMapType(ADContext &context, ApplyInst *ai);
|
||||
/// Given an `apply` / `begin_apply` instruction, conditionally gets a linear
|
||||
/// map tuple field AST type for its linear map function if it is active.
|
||||
Type getLinearMapType(ADContext &context, FullApplySite fai);
|
||||
|
||||
/// Generates linear map struct and branching enum declarations for the given
|
||||
/// function. Linear map structs are populated with linear map fields and a
|
||||
@@ -180,18 +180,18 @@ public:
|
||||
}
|
||||
|
||||
/// Finds the linear map index in the pullback tuple for the given
|
||||
/// `apply` instruction in the original function.
|
||||
unsigned lookUpLinearMapIndex(ApplyInst *ai) const {
|
||||
assert(ai->getFunction() == original);
|
||||
auto lookup = linearMapIndexMap.find(ai);
|
||||
/// `apply` / `begin_apply` instruction in the original function.
|
||||
unsigned lookUpLinearMapIndex(FullApplySite fas) const {
|
||||
assert(fas->getFunction() == original);
|
||||
auto lookup = linearMapIndexMap.find(fas);
|
||||
assert(lookup != linearMapIndexMap.end() &&
|
||||
"No linear map field corresponding to the given `apply`");
|
||||
return lookup->getSecond();
|
||||
}
|
||||
|
||||
Type lookUpLinearMapType(ApplyInst *ai) const {
|
||||
unsigned idx = lookUpLinearMapIndex(ai);
|
||||
return getLinearMapTupleType(ai->getParentBlock())->getElement(idx).getType();
|
||||
Type lookUpLinearMapType(FullApplySite fas) const {
|
||||
unsigned idx = lookUpLinearMapIndex(fas);
|
||||
return getLinearMapTupleType(fas->getParent())->getElement(idx).getType();
|
||||
}
|
||||
|
||||
bool hasHeapAllocatedContext() const {
|
||||
|
||||
@@ -56,6 +56,11 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
|
||||
CanSILFunctionType fromType,
|
||||
CanSILFunctionType toType);
|
||||
|
||||
SILValue reabstractCoroutine(
|
||||
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
|
||||
SILValue fn, CanSILFunctionType toType,
|
||||
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions);
|
||||
|
||||
/// Reabstracts the given function-typed value `fn` to the target type `toType`.
|
||||
/// Remaps substitutions using `remapSubstitutions`.
|
||||
SILValue reabstractFunction(
|
||||
|
||||
Reference in New Issue
Block a user