Lift temporary cross-file derivative registration restriction.
`@derivative` attribute type-checking simplications coming soon: TF-1099.
Original function and derivative function must have same access level, with one
exception: public original functions may have internal `@usableFromInline`
derivatives.
Canonicalizes `differentiable_function` instructions by filling in missing
derivative function operands.
Derivative function emission rules, based on the original function value:
- `function_ref`: look up differentiability witness with the exact or a minimal
superset derivative configuration. Emit a `differentiability_witness_function`
for the derivative function.
- `witness_method`: emit a `witness_method` with the minimal superset derivative
configuration for the derivative function.
- `class_method`: emit a `class_method` with the minimal superset derivative
configuration for the derivative function.
If an *actual* emitted derivative function has a superset derivative
configuration versus the *desired* derivative configuration, create a "subset
parameters thunk" to thunk the actual derivative to the desired type.
For `differentiable_function` instructions formed from curry thunk applications:
clone the curry thunk (with type `(Self) -> (T, ...) -> U`) and create a new
version with type `(Self) -> @differentiable (T, ...) -> U`.
Progress towards TF-1211.
The differentiation transform does the following:
- Canonicalizes differentiability witnesses by filling in missing derivative
function entries.
- Canonicalizes `differentiable_function` instructions by filling in missing
derivative function operands.
- If necessary, performs automatic differentiation: generating derivative
functions for original functions.
- When encountering non-differentiability code, produces a diagnostic and
errors out.
Partially resolves TF-1211: add the main canonicalization loop.
To incrementally stage changes, derivative functions are currently created
with empty bodies that fatal error with a nice message.
Derivative emitters will be upstreamed separately.
Previously, two conditions were necessary to enable differentiable programming:
- Using the `-enable-experimental-differentiable-programming` frontend flag.
- Importing the `_Differentiation` module.
Importing the `_Differentiation` module is the true condition because it
contains the required compiler-known `Differentiable` protocol. The frontend
flag is redundant and cumbersome.
Now, the frontend flag is removed.
Importing `_Differentiation` is the only condition.
* Add all [differential operators](https://github.com/apple/swift/blob/master/docs/DifferentiableProgramming.md#list-of-differential-operators).
* Add `withoutDerivative(at:)`, used for efficiently stopping the derivative propagation at a value and causing the derivative at the value to be zero.
* Add utility `differentiableFunction(from:)`, used for creating a `@differentiable` function from an original function and a derivative function.
Mostly work done by @marcrasi and @dan-zheng.
Partially resolves TF-843.
TODO:
* Add `AnyDerivative`.
* Add `Array.differentiableMap(_:)` and `differentiableReduce(_:_:)`.
Add mangling scheme for `@differentiable` and `@differentiable(linear)` function
types. Mangling support is important for debug information, among other things.
Update docs and add tests.
Resolves TF-948.
The `@transpose` attribute registers a function as the transpose of another
function-like declaration: a `func`, `init`, `subscript`, or `var` computed
property declaration.
The `@transpose` attribute also has an optional `wrt:` clause specifying the
linearity parameters, i.e. the parameters that are transposed with respect to.
The linearity parameters must conform to the `Differentiable` protocol and
satisfy `Self == TangentVector`.
If the `wrt:` clause is unspecified, the linearity parameters are inferred to be
all parameters that conform to `Differentiable` and that satisfy
`Self == TangentVector`.
`@transpose` attribute type-checking verifies that the type of the transpose
function declaration is consistent with the type of the referenced original
declaration and the linearity parameters.
Resolves TF-830.
Add `AdditiveArithmetic` derived conformances for structs and classes, gated by
the `-enable-experimental-differentiable-programming` flag.
Structs and classes whose stored properties all conform to `Differentiable` can
derive `Differentiable`:
- `associatedtype TangentVector: Differentiable & AdditiveArithmetic`
- Member `TangentVector` structs are synthesized whose stored properties are
all `var` stored properties that conform to `Differentiable` and that are
not `@noDerivative`.
- `mutating func move(along: TangentVector)`
The `@noDerivative` attribute may be declared on stored properties to opt out of
inclusion in synthesized `TangentVector` structs.
Some stored properties cannot be used in `TangentVector` struct synthesis and
are implicitly marked as `@noDerivative`, with a warning:
- `let` stored properties.
- These cannot be updated by `mutating func move(along: TangentVector)`.
- Non-`Differentiable`-conforming stored properties.
`@noDerivative` also implies `@_semantics("autodiff.nonvarying")`, which is
relevant for differentiable activity analysis.
Add type-checking and SILGen tests.
Resolves TF-845.
Add the `@differentiable` function conversion pipeline:
- New expressions that convert between `@differentiable`,
`@differentiable(linear)`, and non-`@differentiable` functions:
- `DifferentiableFunction`
- `LinearFunction`
- `DifferentiableFunctionExtractOriginal`
- `LinearFunctionExtractOriginal`
- `LinearToDifferentiableFunction`
- All the AST handling (e.g. printing) necessary for those expressions.
- SILGen for those expressions.
- CSApply code that inserts these expressions to implicitly convert between
the various function types.
- Sema tests for the implicit conversions.
- SILGen tests for the SILGen of these expressions.
Resolves TF-833.
Add type checking for `@differentiable` function types:
- Check that parameters and results conform to `Differentiable`.
- Implicitly conform parameters and results whose types are generic parameters
to `Differentiable`.
- Upstream most of the differentiable_func_type_type_checking.swift test from
`tensorflow` branch. A few function conversion tests have not been added
because they depend on the `@differentiable` function conversion pipeline.
Diagnose gracefully when the `Differentiable` protocol is unavailable because
`_Differentiation` has not been imported.
Resolves TF-823 and TF-1219.
Add `linear_function` and `linear_function_extract` instructions.
`linear_function` creates a `@differentiable(linear)` function-typed value from
an original function operand and a transpose function operand (optional).
`linear_function_extract` extracts either the original or transpose function
value from a `@differentiable(linear)` function.
Resolves TF-1142 and TF-1143.
Generate `differentiable_function` and `differentiable_function_extract` in
derivative witness table/vtable thunks.
`differentiation_function` is later canonicalized by the differentiation
transform.
Add SIL FileCheck tests.
Add `AdditiveArithmetic` derived conformances for structs, gated by the
`-enable-experimential-additive-arithmetic-derivation` flag.
Structs whose stored properties all conform to `AdditiveArithmetic` can derive
`AdditiveArithmetic`:
- `static var zero: Self`
- `static func +(lhs: Self, rhs: Self) -> Self`
- `static func -(lhs: Self, rhs: Self) -> Self`
- An "effective memberwise initializer":
- Either a synthesized memberwise initializer or a user-defined initializer
with the same type.
Effective memberwise initializers are used only by derived conformances for
`Self`-returning protocol requirements like `AdditiveArithmetic.+`, which
require memberwise initialization.
Resolves TF-844.
Unblocks TF-845: upstream `Differentiable` derived conformances.
Previously, all witnesses of a `@differentiable` protocol requirement were
required to have the same attribute (or one with superset parameter indices).
However, this leads to many annotations on witnesses and is not ideal for
usability. `@differentiable` attributes are really only significant on
public witnesses, so that they are clearly `@differentiable` at a glance (in
source code, interface files, and API documentation), without looking through
protocol conformance hierarchies.
Now, only *public* witnesses of `@differentiable` protocol requirements are
required to have the same attribute (or one with superset parameter indices).
For less-visible witnesses, an implicit `@differentiable` attribute is created
with the same configuration as the requirement's.
Resolves TF-1117.
Upstreams #29771 from tensorflow branch.
Define type signatures and SILGen for the following builtins:
```
/// Applies the {jvp|vjp} of `f` to `arg1`, ..., `argN`.
func applyDerivative_arityN_{jvp|vjp}(f, arg1, ..., argN) -> jvp/vjp return type
/// Applies the transpose of `f` to `arg`.
func applyTranspose_arityN(f, arg) -> transpose return type
/// Makes a differentiable function from the given `original`, `jvp`, and
/// `vjp` functions.
func differentiableFunction_arityN(original, jvp, vjp)
/// Makes a linear function from the given `original` and `transpose` functions.
func linearFunction_arityN(original, transpose)
```
Add SILGen FileCheck tests for all builtins.
Remove logic for parsing and diagnosing `jvp:` and `vjp:` arguments for
`@differentiable` attribute. No logic remains for handling those arguments.
Follow-up to TF-1001.
Serialize "is linear?" flag, differentiability parameter indices, and
differentiability generic signature.
Deserialization has some ad-hoc logic for setting the original declaration and
parameter indices for `@differentiable` attributes because
`DeclDeserializer::deserializeDeclAttributes` does not have access to the
original declaration.
Resolves TF-836.
Delete `@differentiable` attribute `jvp:` and `vjp:` arguments for derivative
registration. `@derivative` attribute is now the canonical way to register
derivatives.
Resolves TF-1001.
* Move muliti-functionality SIL tests to test/AutoDiff/SIL.
* Add `with_derivative` clause to `differentiable_function` instructions.
Otherwise, IRGen for test/AutoDiff/SIL/differentiable_function_inst.sil fails on
tensorflow branch because the differentiation transform cannot differentiate
external functions.
Add test/AutoDiff/lit.local.cfg: run tests only when `differentiable_programming`
is enabled in lit. With this, individual tests no longer need
`REQUIRES: differentiable_programming`.
Move multi-functionality SIL tests from test/AutoDiff/SIL/Serialization to
test/AutoDiff/SIL.
Garden test filenames.
Add `differentiable_function` and `differentiable_function_extract`
instructions.
`differentiable_function` creates a `@differentiable` function-typed
value from an original function operand and derivative function operands
(optional).
`differentiable_function_extract` extracts either the original or
derivative function value from a `@differentiable` function.
The differentiation transform canonicalizes `differentiable_function`
instructions, filling in derivative function operands if missing.
Resolves TF-1139 and TF-1140.
`@differentiable` attribute on protocol requirements and non-final class
members now produces derivative function entries in witness tables and vtables.
This enables `witness_method` and `class_method` differentiation.
Existing type-checking rules:
- Witness declarations of `@differentiable` protocol requirements must have a
`@differentiable` attribute with the same configuration (or a configuration
with superset parameter indices).
- Witness table derivative function entries are SILGen'd for `@differentiable`
witness declarations.
- Class vtable derivative function entries are SILGen'd for non-final
`@differentiable` class members.
- These derivative entries can be overridden or inherited, just like other
vtable entries.
Resolves TF-1212.
`@differentiable` attribute on protocol requirements and non-final class members
will produce derivative function entries in witness tables and vtables.
This patch adds an optional derivative function configuration
(`AutoDiffDerivativeFunctionIdentifier`) to `SILDeclRef` to represent these
derivative function entries.
Derivative function configurations consist of:
- A derivative function kind (JVP or VJP).
- Differentiability parameter indices.
Resolves TF-1209.
Enables TF-1212: upstream derivative function entries in witness tables/vtables.
Generate SIL differentiability witnesses from `@differentiable` and
`@derivative` declaration attributes.
Add SILGen utilities for:
- Emiting differentiability witnesses.
- Creating derivative function thunks, which are used as entries in
differentiability witnesses.
When users register a custom derivative function, it is necessary to create a
thunk with the expected derivative type computed from the original function's
type. This is important for consistent typing and consistent differentiability
witness entry mangling.
See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details.
Resolves TF-1138.
Semantically, an `inout` parameter is both a parameter and a result.
`@differentiable` and `@derivative` attributes now support original functions
with one "semantic result": either a formal result or an `inout` parameter.
Derivative typing rules for functions with `inout` parameters are now defined.
The differential/pullback type of a function with `inout` differentiability
parameters also has `inout` parameters. This is ideal for performance.
Differential typing rules:
- Case 1: original function has no `inout` parameters.
- Original: `(T0, T1, ...) -> R`
- Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
- Case 2: original function has a non-wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Differential: `(T0.Tan, ...) -> T1.Tan`
- Case 3: original function has a wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
Pullback typing rules:
- Case 1: original function has no `inout` parameters.
- Original: `(T0, T1, ...) -> R`
- Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
- Case 2: original function has a non-wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Pullback: `(T1.Tan) -> (T0.Tan, ...)`
- Case 3: original function has a wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
Resolves TF-1164.
Remove meaningless `assert(false)` assertion from
`hasOverridingDifferentiableAttribute` in `TypeCheckDeclOverride.cpp`.
The assertion served no purpose and can safely be removed.
Resolves TF-1167. Add test.
Diagnose `@derivative` and `@transpose` attributes that are missing the
required comma before the `wrt:` clause:
```
@derivative(of: foo wrt: x)
@transpose(of: bar wrt: (x, y))
```
Previously, this was undiagnosed.
Resolves TF-1168.
- Support `@differentiable` and `@derivative` attributes for original
initializers in final classes. Reject original initializers in non-final
classes.
- Synchronize tests.