Commit Graph

9 Commits

Author SHA1 Message Date
Dan Zheng
24445dd2e2 [AutoDiff upstream] Add differentiability witness SILGen. (#30545)
Generate SIL differentiability witnesses from `@differentiable` and
`@derivative` declaration attributes.

Add SILGen utilities for:
- Emiting differentiability witnesses.
- Creating derivative function thunks, which are used as entries in
  differentiability witnesses.

When users register a custom derivative function, it is necessary to create a
thunk with the expected derivative type computed from the original function's
type. This is important for consistent typing and consistent differentiability
witness entry mangling.

See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details.

Resolves TF-1138.
2020-03-21 02:05:04 -07:00
John McCall
ceff414820 Distinguish invocation and pattern substitutions on SILFunctionType.
In order to allow this, I've had to rework the syntax of substituted function types; what was previously spelled `<T> in () -> T for <X>` is now spelled `@substituted <T> () -> T for <X>`.  I think this is a nice improvement for readability, but it did require me to churn a lot of test cases.

Distinguishing the substitutions has two chief advantages over the existing representation.  First, the semantics seem quite a bit clearer at use points; the `implicit` bit was very subtle and not always obvious how to use.  More importantly, it allows the expression of generic function types that must satisfy a particular generic abstraction pattern, which was otherwise impossible to express.

As an example of the latter, consider the following protocol conformance:

```
protocol P { func foo() }
struct A<T> : P { func foo() {} }
```

The lowered signature of `P.foo` is `<Self: P> (@in_guaranteed Self) -> ()`.  Without this change, the lowered signature of `A.foo`'s witness would be `<T> (@in_guaranteed A<T>) -> ()`, which does not preserve information about the conformance substitution in any useful way.  With this change, the lowered signature of this witness could be `<T> @substituted <Self: P> (@in_guaranteed Self) -> () for <A<T>>`, which nicely preserves the exact substitutions which relate the witness to the requirement.

When we adopt this, it will both obviate the need for the special witness-table conformance field in SILFunctionType and make it far simpler for the SILOptimizer to devirtualize witness methods.  This patch does not actually take that step, however; it merely makes it possible to do so.

As another piece of unfinished business, while `SILFunctionType::substGenericArgs()` conceptually ought to simply set the given substitutions as the invocation substitutions, that would disturb a number of places that expect that method to produce an unsubstituted type.  This patch only set invocation arguments when the generic type is a substituted type, which we currently never produce in type-lowering.

My plan is to start by producing substituted function types for accessors.  Accessors are an important case because the coroutine continuation function is essentially an implicit component of the function type which the current substitution rules simply erase the intended abstraction of.  They're also used in narrower ways that should exercise less of the optimizer.
2020-03-07 16:25:59 -05:00
Dan Zheng
697c722a5f [AutoDiff] Type-checking support for inout parameter differentiation. (#29959)
Semantically, an `inout` parameter is both a parameter and a result.

`@differentiable` and `@derivative` attributes now support original functions
with one "semantic result": either a formal result or an `inout` parameter.

Derivative typing rules for functions with `inout` parameters are now defined.

The differential/pullback type of a function with `inout` differentiability
parameters also has `inout` parameters. This is ideal for performance.

Differential typing rules:
- Case 1: original function has no `inout` parameters.
  - Original:     `(T0, T1, ...) -> R`
  - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
- Case 2: original function has a non-wrt `inout` parameter.
  - Original:     `(T0, inout T1, ...) -> Void`
  - Differential: `(T0.Tan, ...) -> T1.Tan`
- Case 3: original function has a wrt `inout` parameter.
  - Original:     `(T0, inout T1, ...) -> Void`
  - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`

Pullback typing rules:
- Case 1: original function has no `inout` parameters.
  - Original: `(T0, T1, ...) -> R`
  - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
- Case 2: original function has a non-wrt `inout` parameter.
  - Original: `(T0, inout T1, ...) -> Void`
  - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
- Case 3: original function has a wrt `inout` parameter.
  - Original: `(T0, inout T1, ...) -> Void`
  - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`

Resolves TF-1164.
2020-02-21 09:47:53 -08:00
Dan Zheng
a49428ca7c [AutoDiff upstream] Add differentiability_witness_function instruction. (#29765)
The `differentiability_witness_function` instruction looks up a
differentiability witness function (JVP, VJP, or transpose) for a referenced
function via SIL differentiability witnesses.

Add round-trip parsing/serialization and IRGen tests.

Notes:
- Differentiability witnesses for linear functions require more support.
  `differentiability_witness_function [transpose]` instructions do not yet
  have IRGen.
- Nothing currently generates `differentiability_witness_function` instructions.
  The differentiation transform does, but it hasn't been upstreamed yet.

Resolves TF-1141.
2020-02-13 16:55:46 -08:00
Dan Zheng
c2ac96f09f [AutoDiff upstream] Add SIL transpose function type calculation. (#29755)
Add `SILFunctionType::getAutoDiffTransposeFunctionType`.

It computes the transpose `SILFucntionType` for an original `SILFunctionType`,
given:

- Linearity parameter indices
- Transpose function generic signature (optional)
- Other auxiliary parameters

Add doc comments explaining typing rules, preconditions, and other details.
Add `isTranspose` flag to `autodiff::getConstrainedDerivativeGenericSignature`.

Partially resolves TF-1125.
Unblocks TF-1141: upstream `differentiability_witness_function` instruction.
2020-02-11 04:08:25 -08:00
Dan Zheng
a174243159 [AutoDiff upstream] Add SIL differentiability witness IRGen. (#29704)
SIL differentiability witnesses are a new top-level SIL construct mapping
an "original" SIL function and derivative configuration to derivative SIL
functions.

This patch adds `SILDifferentiabilityWitness` IRGen.

`SILDifferentiabilityWitness` has a fixed `{ i8*, i8* }` layout:
JVP and VJP derivative function pointers.

Resolves TF-1146.
2020-02-07 14:10:34 -08:00
Dan Zheng
0c0018fb9b [AutoDiff] Improve getConstrainedDerivativeGenericSignature helper. (#29620)
Move `getConstrainedDerivativeGenericSignature` under `autodiff` namespace.
Improve naming and documentation.
2020-02-03 17:40:49 -08:00
Dan Zheng
a98fe74044 [AutoDiff upstream] Add TangentSpace. (#29107)
`TangentSpace` represents the tangent space of a type.

- For `Differentiable`-conforming types:
  - The tangent space is the `TangentVector` associated type.
- For tuple types:
  - The tangent space is a tuple of the elements' tangent space types, for the
    elements that have a tangent space.
- Other types have no tangent space.

`TypeBase::getAutoDiffTangentSpace` gets the tangent space of a type.

`TangentSpace` is used to:
- Compute the derivative function type of a given original function type.
- Compute the type of tangent/adjoint values during automatic differentiation.

Progress towards TF-828: upstream `@differentiable` attribute type-checking.
2020-01-12 12:29:08 -08:00
Dan Zheng
bb1052ca3e [AutoDiff upstream] Upstream @derivative attribute type-checking. (#28738)
The `@derivative` attribute registers a function as a derivative of another
function-like declaration: a `func`, `init`, `subscript`, or `var` computed
property declaration.

The `@derivative` attribute also has an optional `wrt:` clause specifying the
parameters that are differentiated "with respect to", i.e. the differentiation
parameters. The differentiation parameters must conform to the `Differentiable`
protocol.

If the `wrt:` clause is unspecified, the differentiation parameters are inferred
to be all parameters that conform to `Differentiable`.

`@derivative` attribute type-checking verifies that the type of the derivative
function declaration is consistent with the type of the referenced original
declaration and the differentiation parameters.

The `@derivative` attribute is gated by the
`-enable-experimental-differentiable-programming` flag.

Resolves TF-829.
2019-12-12 18:18:18 -08:00