mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
[AutoDiff upstream] Add SIL derivative function type calculation. (#29396)
Add `SILFunctionType::getAutoDiffDerivativeFunctionType`.

It computes the derivative `SILFunctionType` for an "original"
`SILFunctionType`, given:
- Differentiability parameter indices
- Differentiability result index
- Derivative function kind
- Derivative function generic signature (optional)
- Other auxiliary parameters

Add doc comments explaining typing rules, preconditions, and other details.

Partially resolves TF-1124.

Unblocks upstreaming other SIL differentiable programming infrastructure.
@@ -4390,6 +4390,89 @@ public:
 
   const clang::FunctionType *getClangFunctionType() const;
 
+  /// Returns the type of the derivative function for the given parameter
+  /// indices, result index, derivative function kind, derivative function
+  /// generic signature (optional), and other auxiliary parameters.
+  ///
+  /// Preconditions:
+  /// - Parameters corresponding to parameter indices must conform to
+  ///   `Differentiable`.
+  /// - The result corresponding to the result index must conform to
+  ///   `Differentiable`.
+  ///
+  /// Typing rules, given:
+  /// - Original function type: $(T0, T1, ...) -> (R0, R1, ...)
+  ///
+  /// Terminology:
+  /// - The derivative of a `Differentiable`-conforming type has the
+  ///   `TangentVector` associated type. `TangentVector` is abbreviated as `Tan`
+  ///   below.
+  /// - "wrt" parameters refers to parameters indicated by the parameter
+  ///   indices.
+  /// - "wrt" result refers to the result indicated by the result index.
+  ///
+  /// JVP derivative type:
+  /// - Takes original parameters.
+  /// - Returns original results, followed by a differential function, which
+  ///   takes "wrt" parameter derivatives and returns a "wrt" result derivative.
+  ///
+  ///     $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan)
+  ///                    ^~~~~~~   ^~~~~~~~~~~~~~~~~~~     ^~~~~~
+  ///           original results | derivatives wrt params | derivative wrt result
+  ///
+  /// VJP derivative type:
+  /// - Takes original parameters.
+  /// - Returns original results, followed by a pullback function, which
+  ///   takes a "wrt" result derivative and returns "wrt" parameter derivatives.
+  ///
+  ///     $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...))
+  ///                    ^~~~~~~   ^~~~~~      ^~~~~~~~~~~~~~~~~~~
+  ///           original results | derivative wrt result | derivatives wrt params
+  ///
+  /// A "constrained derivative generic signature" is computed from
+  /// `derivativeFunctionGenericSignature`, if specified. Otherwise, it is
+  /// computed from the original generic signature. A "constrained derivative
+  /// generic signature" requires all "wrt" parameters to conform to
+  /// `Differentiable`; this is important for correctness.
+  ///
+  /// This "constrained derivative generic signature" is used for
+  /// parameter/result type lowering. It is used as the actual generic signature
+  /// of the derivative function type iff the original function type has a
+  /// generic signature and not all generic parameters are bound to concrete
+  /// types. Otherwise, no derivative generic signature is used.
+  ///
+  /// Other properties of the original function type are copied exactly:
+  /// `ExtInfo`, coroutine kind, callee convention, yields, optional error
+  /// result, witness method conformance, etc.
+  ///
+  /// Special cases:
+  /// - Reabstraction thunks have special derivative type calculation. The
+  ///   original function-typed last parameter is transformed into a
+  ///   `@differentiable` function-typed parameter in the derivative type. This
+  ///   is necessary for the differentiation transform to support reabstraction
+  ///   thunk differentiation because the function argument is opaque and cannot
+  ///   be differentiated. Instead, the argument is made `@differentiable` and
+  ///   reabstraction thunk JVP/VJP callers are responsible for passing a
+  ///   `@differentiable` function.
+  ///   - TODO(TF-1036): Investigate more efficient reabstraction thunk
+  ///     derivative approaches. The last argument can simply be a
+  ///     corresponding derivative function, instead of a `@differentiable`
+  ///     function - this is more direct. It may be possible to implement
+  ///     reabstraction thunk derivatives using "reabstraction thunks for
+  ///     the original function's derivative", avoiding extra code generation.
+  ///
+  /// Caveats:
+  /// - We may support multiple result indices instead of a single result index
+  ///   eventually. At the SIL level, this enables differentiating wrt multiple
+  ///   function results. At the Swift level, this enables differentiating wrt
+  ///   multiple tuple elements for tuple-returning functions.
+  CanSILFunctionType getAutoDiffDerivativeFunctionType(
+      IndexSubset *parameterIndices, unsigned resultIndex,
+      AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
+      LookupConformanceFn lookupConformance,
+      CanGenericSignature derivativeFunctionGenericSignature = nullptr,
+      bool isReabstractionThunk = false);
+
   ExtInfo getExtInfo() const {
     return ExtInfo(Bits.SILFunctionType.ExtInfoBits, getClangFunctionType());
   }
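
The JVP and VJP typing rules in the doc comment above can be illustrated with a minimal C++ sketch that builds the derivative type as a string. This is a toy model under stated assumptions, not the compiler's implementation: `DerivativeKind`, `join`, and `derivativeTypeString` are hypothetical names, types are plain strings, and generic signatures, type lowering, and reabstraction thunks are ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for AutoDiffDerivativeFunctionKind.
enum class DerivativeKind { JVP, VJP };

// Joins type names with ", ".
static std::string join(const std::vector<std::string> &parts) {
  std::string out;
  for (std::size_t i = 0; i < parts.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += parts[i];
  }
  return out;
}

// Toy version of the typing rule: keeps the original parameters and results,
// then appends a differential (JVP) or pullback (VJP) as the last result.
// `wrtParams` and `wrtResult` play the role of the parameter indices and the
// result index; `X.Tan` abbreviates `X.TangentVector`.
std::string derivativeTypeString(const std::vector<std::string> &params,
                                 const std::vector<std::string> &results,
                                 const std::vector<std::size_t> &wrtParams,
                                 std::size_t wrtResult, DerivativeKind kind) {
  assert(wrtResult < results.size() && "result index out of range");

  // Tangent types of the "wrt" parameters, in index order.
  std::vector<std::string> paramTans;
  for (std::size_t index : wrtParams) {
    assert(index < params.size() && "parameter index out of range");
    paramTans.push_back(params[index] + ".Tan");
  }
  const std::string resultTan = results[wrtResult] + ".Tan";

  // Differential: (param tangents) -> result tangent.
  // Pullback:     (result tangent) -> (param tangents).
  const std::string linearMap =
      kind == DerivativeKind::JVP
          ? "(" + join(paramTans) + ") -> " + resultTan
          : "(" + resultTan + ") -> (" + join(paramTans) + ")";

  std::vector<std::string> derivativeResults = results;
  derivativeResults.push_back(linearMap);
  return "$(" + join(params) + ") -> (" + join(derivativeResults) + ")";
}
```

Under these assumptions, `derivativeTypeString({"T0", "T1"}, {"R0"}, {0, 1}, 0, DerivativeKind::VJP)` yields `$(T0, T1) -> (R0, (R0.Tan) -> (T0.Tan, T1.Tan))`, which matches the VJP shape given in the comment.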