[AutoDiff upstream] Add SIL derivative function type calculation. (#29396)

Add `SILFunctionType::getAutoDiffDerivativeFunctionType`.

It computes the derivative `SILFunctionType` for an "original"
`SILFunctionType`, given:

- Differentiability parameter indices
- Differentiability result index
- Derivative function kind
- Derivative function generic signature (optional)
- Other auxiliary parameters

Add doc comments explaining typing rules, preconditions, and other details.

Partially resolves TF-1124.
Unblocks upstreaming other SIL differentiable programming infrastructure.
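
As a rough illustration of the JVP and VJP typing rules documented in this commit, here is a toy C++ sketch that models types as plain strings. It is not the compiler implementation: `derivativeTypeString`, `DerivativeKind`, and `joinTypes` are hypothetical names introduced for this example, and the sketch ignores generic signatures, type lowering, and reabstraction thunks.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy model only: types are plain strings, not `SILFunctionType` values.
// All names in this block are hypothetical and exist only for illustration.
enum class DerivativeKind { JVP, VJP };

// Joins {"A", "B"} into "A, B".
static std::string joinTypes(const std::vector<std::string> &types) {
  std::string out;
  for (std::size_t i = 0; i < types.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += types[i];
  }
  return out;
}

// Builds a textual derivative function type from an original function type
// `$(params...) -> (results...)`, following the documented typing rules.
std::string derivativeTypeString(const std::vector<std::string> &params,
                                 const std::vector<std::string> &results,
                                 const std::vector<std::size_t> &wrtParams,
                                 std::size_t wrtResult, DerivativeKind kind) {
  // Mirror the index preconditions: every index must name an existing
  // parameter or result.
  assert(wrtResult < results.size() && "result index out of range");

  // Tangent types of the "wrt" parameters, in index order.
  std::vector<std::string> paramTans;
  for (std::size_t index : wrtParams) {
    assert(index < params.size() && "parameter index out of range");
    paramTans.push_back(params[index] + ".Tan");
  }
  const std::string resultTan = results[wrtResult] + ".Tan";

  // JVP returns a differential: (param tangents) -> result tangent.
  // VJP returns a pullback: (result tangent) -> (param tangents).
  const std::string linearMap =
      kind == DerivativeKind::JVP
          ? "(" + joinTypes(paramTans) + ") -> " + resultTan
          : "(" + resultTan + ") -> (" + joinTypes(paramTans) + ")";

  // Both kinds take the original parameters and return the original
  // results followed by the linear map.
  return "$(" + joinTypes(params) + ") -> (" + joinTypes(results) + ", " +
         linearMap + ")";
}
```

For parameters `T0, T1`, result `R0`, and both parameters selected, the JVP case of this sketch yields `$(T0, T1) -> (R0, (T0.Tan, T1.Tan) -> R0.Tan)`, matching the shape given in the doc comment.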
Dan Zheng
2020-01-24 10:09:29 -08:00
committed by GitHub
parent 7b611fcb1a
commit 2d08a3f7e2
2 changed files with 274 additions and 0 deletions


@@ -4390,6 +4390,89 @@ public:
const clang::FunctionType *getClangFunctionType() const;
/// Returns the type of the derivative function for the given parameter
/// indices, result index, derivative function kind, derivative function
/// generic signature (optional), and other auxiliary parameters.
///
/// Preconditions:
/// - Parameters corresponding to parameter indices must conform to
/// `Differentiable`.
/// - The result corresponding to the result index must conform to
/// `Differentiable`.
///
/// Typing rules, given:
/// - Original function type: $(T0, T1, ...) -> (R0, R1, ...)
///
/// Terminology:
/// - The derivative of a value of a `Differentiable`-conforming type has the
/// type's `TangentVector` associated type. `TangentVector` is abbreviated as
/// `Tan` below.
/// - "wrt" stands for "with respect to". The "wrt" parameters are the
/// parameters indicated by the parameter indices.
/// - The "wrt" result is the result indicated by the result index.
///
/// JVP derivative type:
/// - Takes original parameters.
/// - Returns original results, followed by a differential function, which
/// takes "wrt" parameter derivatives and returns a "wrt" result derivative.
///
///   $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan)
///                  ^~~~~~   ^~~~~~~~~~~~~~~~~~     ^~~~~
///         original results | derivatives wrt params | derivative wrt result
///
/// VJP derivative type:
/// - Takes original parameters.
/// - Returns original results, followed by a pullback function, which
/// takes a "wrt" result derivative and returns "wrt" parameter derivatives.
///
///   $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...))
///                  ^~~~~~   ^~~~~      ^~~~~~~~~~~~~~~~~~
///         original results | derivative wrt result | derivatives wrt params
///
/// A "constrained derivative generic signature" is computed from
/// `derivativeFunctionGenericSignature`, if specified. Otherwise, it is
/// computed from the original generic signature. A "constrained derivative
/// generic signature" requires all "wrt" parameters to conform to
/// `Differentiable`; this is important for correctness.
///
/// This "constrained derivative generic signature" is used for
/// parameter/result type lowering. It is used as the actual generic signature
/// of the derivative function type iff the original function type has a
/// generic signature and not all generic parameters are bound to concrete
/// types. Otherwise, no derivative generic signature is used.
///
/// Other properties of the original function type are copied exactly:
/// `ExtInfo`, coroutine kind, callee convention, yields, optional error
/// result, witness method conformance, etc.
///
/// Special cases:
/// - Reabstraction thunks have special derivative type calculation. The
/// original function-typed last parameter is transformed into a
/// `@differentiable` function-typed parameter in the derivative type. This
/// is necessary for the differentiation transform to support reabstraction
/// thunk differentiation because the function argument is opaque and cannot
/// be differentiated. Instead, the argument is made `@differentiable` and
/// reabstraction thunk JVP/VJP callers are responsible for passing a
/// `@differentiable` function.
/// - TODO(TF-1036): Investigate more efficient reabstraction thunk
/// derivative approaches. The last argument can simply be a
/// corresponding derivative function, instead of a `@differentiable`
/// function - this is more direct. It may be possible to implement
/// reabstraction thunk derivatives using "reabstraction thunks for
/// the original function's derivative", avoiding extra code generation.
///
/// Caveats:
/// - Multiple result indices may eventually be supported instead of a single
/// result index. At the SIL level, this would enable differentiating wrt
/// multiple function results. At the Swift level, this would enable
/// differentiating wrt multiple tuple elements for tuple-returning functions.
CanSILFunctionType getAutoDiffDerivativeFunctionType(
IndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature derivativeFunctionGenericSignature = nullptr,
bool isReabstractionThunk = false);
ExtInfo getExtInfo() const {
return ExtInfo(Bits.SILFunctionType.ExtInfoBits, getClangFunctionType());
}