Add test/AutoDiff/lit.local.cfg: run tests only when `differentiable_programming`
is enabled in lit. With this, individual tests no longer need
`REQUIRES: differentiable_programming`.
Move multi-functionality SIL tests from test/AutoDiff/SIL/Serialization to
test/AutoDiff/SIL.
Garden test filenames.
Add `differentiable_function` and `differentiable_function_extract`
instructions.
`differentiable_function` creates a `@differentiable` function-typed
value from an original function operand and derivative function operands
(optional).
`differentiable_function_extract` extracts either the original or
derivative function value from a `@differentiable` function.
The differentiation transform canonicalizes `differentiable_function`
instructions, filling in derivative function operands if missing.
Resolves TF-1139 and TF-1140.
`@differentiable` attribute on protocol requirements and non-final class
members now produces derivative function entries in witness tables and vtables.
This enables `witness_method` and `class_method` differentiation.
Existing type-checking rules:
- Witness declarations of `@differentiable` protocol requirements must have a
`@differentiable` attribute with the same configuration (or a configuration
with superset parameter indices).
- Witness table derivative function entries are SILGen'd for `@differentiable`
witness declarations.
- Class vtable derivative function entries are SILGen'd for non-final
`@differentiable` class members.
- These derivative entries can be overridden or inherited, just like other
vtable entries.
Resolves TF-1212.
`@differentiable` attribute on protocol requirements and non-final class members
will produce derivative function entries in witness tables and vtables.
This patch adds an optional derivative function configuration
(`AutoDiffDerivativeFunctionIdentifier`) to `SILDeclRef` to represent these
derivative function entries.
Derivative function configurations consist of:
- A derivative function kind (JVP or VJP).
- Differentiability parameter indices.
Resolves TF-1209.
Enables TF-1212: upstream derivative function entries in witness tables/vtables.
Generate SIL differentiability witnesses from `@differentiable` and
`@derivative` declaration attributes.
Add SILGen utilities for:
- Emiting differentiability witnesses.
- Creating derivative function thunks, which are used as entries in
differentiability witnesses.
When users register a custom derivative function, it is necessary to create a
thunk with the expected derivative type computed from the original function's
type. This is important for consistent typing and consistent differentiability
witness entry mangling.
See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details.
Resolves TF-1138.
Semantically, an `inout` parameter is both a parameter and a result.
`@differentiable` and `@derivative` attributes now support original functions
with one "semantic result": either a formal result or an `inout` parameter.
Derivative typing rules for functions with `inout` parameters are now defined.
The differential/pullback type of a function with `inout` differentiability
parameters also has `inout` parameters. This is ideal for performance.
Differential typing rules:
- Case 1: original function has no `inout` parameters.
- Original: `(T0, T1, ...) -> R`
- Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
- Case 2: original function has a non-wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Differential: `(T0.Tan, ...) -> T1.Tan`
- Case 3: original function has a wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
Pullback typing rules:
- Case 1: original function has no `inout` parameters.
- Original: `(T0, T1, ...) -> R`
- Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
- Case 2: original function has a non-wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Pullback: `(T1.Tan) -> (T0.Tan, ...)`
- Case 3: original function has a wrt `inout` parameter.
- Original: `(T0, inout T1, ...) -> Void`
- Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
Resolves TF-1164.
Remove meaningless `assert(false)` assertion from
`hasOverridingDifferentiableAttribute` in `TypeCheckDeclOverride.cpp`.
The assertion served no purpose and can safely be removed.
Resolves TF-1167. Add test.
Diagnose `@derivative` and `@transpose` attributes that are missing the
required comma before the `wrt:` clause:
```
@derivative(of: foo wrt: x)
@transpose(of: bar wrt: (x, y))
```
Previously, this was undiagnosed.
Resolves TF-1168.
- Support `@differentiable` and `@derivative` attributes for original
initializers in final classes. Reject original initializers in non-final
classes.
- Synchronize tests.
Attempt to look up original function before checking whether the `value:` result
conforms to `Differentiable`.
This improves diagnostics: "original function not found" should be diagnosed as
early as possible.
Previously, `@derivative` attribute type-checking produced a confusing error
referencing unbound types `T` and `U`:
```
'@derivative(of:)' attribute requires function to return a two-element tuple of
type '(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or
'(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'
```
Now, the error is less confusing:
```
'@derivative(of:)' attribute requires function to return a two-element tuple;
first element must have label 'value:' and second element must have label
'pullback:' or 'differential:'
```
This is necessary because the `Differentiable` protocol exists in stdlib core
on `tensorflow` branch but in the `_Differentiation` module on `master` branch.
The robust solution is to add auto-import `_Differentiation` logic to `tensorflow`.
The `differentiability_witness_function` instruction looks up a
differentiability witness function (JVP, VJP, or transpose) for a referenced
function via SIL differentiability witnesses.
Add round-trip parsing/serialization and IRGen tests.
Notes:
- Differentiability witnesses for linear functions require more support.
`differentiability_witness_function [transpose]` instructions do not yet
have IRGen.
- Nothing currently generates `differentiability_witness_function` instructions.
The differentiation transform does, but it hasn't been upstreamed yet.
Resolves TF-1141.
There is currently a difference between the tensorflow branch and the
master branch. On tensorflow, the differentiability support is merged
into the standard library. This changes the decoration of the witness.
Loosen the test to accept either.
We should change the tensorflow branch to generate the
`_Differentiation` module with the support and then auto-import the
module in the longer term. This can be gated by the
`-enable-experimental-autodifferentiation` flag to the driver to gain
the same behaviour on both the branches.
SIL differentiability witnesses are a new top-level SIL construct mapping
an "original" SIL function and derivative configuration to derivative SIL
functions.
This patch adds `SILDifferentiabilityWitness` IRGen.
`SILDifferentiabilityWitness` has a fixed `{ i8*, i8* }` layout:
JVP and VJP derivative function pointers.
Resolves TF-1146.
SIL differentiability witnesses are a new top-level SIL construct mapping
an "original" SIL function and derivative configuration to derivative SIL
functions.
This patch adds `SILDifferentiabilityWitness` serialization/deserialization.
Resolves TF-1136.
SIL differentiability witnesses are a new top-level SIL construct mapping
"original" SIL functions to derivative SIL functions.
SIL differentiability witnesses have the following components:
- "Original" `SILFunction`.
- SIL linkage.
- Differentiability parameter indices (`IndexSubset`).
- Differentiability result indices (`IndexSubset`).
- Derivative `GenericSignature` representing differentiability generic
requirements (optional).
- JVP derivative `SILFunction` (optional).
- VJP derivative `SILFunction` (optional).
- "Is serialized?" bit.
This patch adds the `SILDifferentiabilityWitness` data structure, with
documentation, parsing, and printing.
Resolves TF-911.
Todos:
- TF-1136: upstream `SILDifferentiabilityWitness` serialization.
- TF-1137: upstream `SILDifferentiabilityWitness` verification.
- TF-1138: upstream `SILDifferentiabilityWitness` SILGen from
`@differentiable` and `@derivative` attributes.
- TF-20: robust mangling for `SILDifferentiabilityWitness` names.
The `@noDerivative` attribute marks the non-differentiability parameters of a
`@differentiable` function type. All parameters except those marked with
`@noDerivative` are differentiability parameters.
For example, `@differentiable (Float, @noDerivative Float) -> Float` is only
differentiable with respect to its first parameter.
The `@noDerivative` attribute is represented as a
`SILParameterDifferentiability` bit on `SILParameterInfo`.
Add round-trip serialization tests.
Resolves TF-872.
For protocol requirements and class members with `@differentiable` attribute,
conforming types and subclasses must have the same `@differentiable` attribute
(or one with a superset of differentiability parameters) on implementing/
overriding declarations.
For implementing/overriding declarations that are missing a `@differentiable`
attribute, emit a fix-it that adds the missing attribute.
Resolves TF-1118.
The `@differentiable` attribute marks a function as differentiable.
Example:
```
@differentiable(wrt: x, jvp: derivativeFoo where T: Differentiable)
func id<T>(_ x: T) -> T { x }
```
The `@differentiable` attribute has an optional `wrt:` clause specifying the
parameters that are differentiated "with respect to", i.e. the differentiability
parameters. The differentiability parameters must conform to the
`Differentiable` protocol.
If the `wrt:` clause is unspecified, the differentiability parameters are
currently inferred to be all parameters that conform to `Differentiable`.
The `@differentiable` attribute also has optional `jvp:` and `vjp:` labels
for registering derivative functions. These labels are deprecated in favor of
the `@derivative` attribute and will be removed soon.
The `@differentiable` attribute also has an optional `where` clause, specifying
extra differentiability requirements for generic functions.
The `@differentiable` attribute is gated by the
`-enable-experimental-differentiable-programming` flag.
Code changes:
- Add `DifferentiableAttributeTypeCheckRequest`.
- Currently, the request returns differentiability parameter indices, while
also resolving `JVPFunction`, `VJPFunction`, and
`DerivativeGenericSignature` and mutating them in-place in
`DifferentiableAttr`. This was the simplest approach that worked without
introducing request cycles.
- Add "is type-checked" bit to `DifferentiableAttr`.
- Alternatively, I tried changing `DifferentiableAttributeTypeCheckRequest` to
use `CacheKind::Cache` instead of `CacheKind::SeparatelyCached`, but it did
not seem to work: `@differentiable` attributes in non-primary-files were
left unchecked.
Type-checking rules (summary):
- `@differentiable` attribute must be declared on a function-like "original"
declaration: `func`, `init`, `subscript`, `var` (computed properties only).
- Parsed differentiability parameters must be valid (if they exist).
- Parsed `where` clause must be valid (if it exists).
- Differentiability parameters must all conform to `Differentiable`.
- Original result must all conform to `Differentiable`.
- If JVP/VJP functions are specified, they must match the expected type.
- `@differentiable(jvp:vjp:)` for derivative registration is deprecated in
favor of `@derivative` attribute, and will be removed soon.
- Duplicate `@differentiable` attributes with the same differentiability
parameters are invalid.
- For protocol requirements and class members with `@differentiable` attribute,
conforming types and subclasses must have the same `@differentiable` attribute
(or one with a superset of differentiability parameter indices) on
implementing/overriding declarations.
Enable qualified declaration names in `@derivative` attribute, just like
`@transpose` attribute.
`DerivativeAttr` now stores a base type `TypeRepr *`, which is non-null for
parsed attributes that reference a qualified original declaration.
Add `TypeResolutionFlags::AllowModule` flag to enable module lookup via
`TypeChecker::lookupMember` given a `ModuleType`.
Add tests for type-qualified and module-qualified declaration names.
Resolves TF-1058.
The new _Differentiable module is not available in any shipping OS release, but its public API currently doesn’t have availability, either.
Temporarily disable tests that import it when we’re testing with OS-provided libraries.
The typecheck test test/AutoDiff/stdlib/differentiable_protocol.swift imports _Differentable but still somehow succeeds in these configs, so leave that one enabled.
rdar://57975086
The `@transpose(of:)` attribute registers a function as a transpose of another
function. This patch adds the `@transpose(of:)` attribute definition, syntax,
parsing, and printing.
Resolves TF-827.
Todos:
- Type-checking (TF-830, TF-1060).
- Enable serialization (TF-838).
- Use module-qualified names instead of custom qualified name syntax/parsing
(TF-1066).
Upstream `@derivative` attribute serialization/deserialization.
Test all original declaration kinds and various `wrt:` parameter clauses.
Resolves TF-837.
The `@derivative` attribute registers a function as a derivative of another
function-like declaration: a `func`, `init`, `subscript`, or `var` computed
property declaration.
The `@derivative` attribute also has an optional `wrt:` clause specifying the
parameters that are differentiated "with respect to", i.e. the differentiation
parameters. The differentiation parameters must conform to the `Differentiable`
protocol.
If the `wrt:` clause is unspecified, the differentiation parameters are inferred
to be all parameters that conform to `Differentiable`.
`@derivative` attribute type-checking verifies that the type of the derivative
function declaration is consistent with the type of the referenced original
declaration and the differentiation parameters.
The `@derivative` attribute is gated by the
`-enable-experimental-differentiable-programming` flag.
Resolves TF-829.
Add `Differentiable` conformances for floating-point types to the
`_Differentiation` module. The `TangentVector` associated type for
floating-point types is `Self`.
This design adheres to the differentiable programming manifesto:
docs/DifferentiableProgramming.md.
Partially resolves TF-1052.
The `@derivative(of:)` attribute registers a function as a derivative of another
function. This patch adds the `@derivative(of:)` attribute definition, syntax,
parsing, and printing.
Resolves TF-826.
Todos:
- Type-checking (TF-829).
- Serialization (TF-837).
- Move `DifferentiableAttr` definition above `DeclAttributes` in
include/swift/AST/Attr.h, like other attributes.
- Remove unnecessary arguments from `DifferentiableAttr::DifferentiableAttr`
and `DifferentiableAttr::setDerivativeGenericSignature`.
- Add libSyntax test for `@differentiable` attributes.