Commit Graph

4 Commits

Author SHA1 Message Date
Dan Zheng
d3b6b89de6 [AutoDiff] Support multiple differentiability result indices in SIL. (#32206)
`DifferentiableFunctionInst` now stores result indices.
`SILAutoDiffIndices` now stores result indices instead of a source index.

`@differentiable` SIL function types may now have multiple differentiability
result indices and `@noDerivative` resutls.

`@differentiable` AST function types do not have `@noDerivative` results (yet),
so this functionality is not exposed to users.

Resolves TF-689 and TF-1256.

Infrastructural support for TF-983: supporting differentiation of `apply`
instructions with multiple active semantic results.
2020-06-05 16:25:17 -07:00
Dan Zheng
2eb460de4d [AutoDiff upstream] Add forward-mode differentiation. (#30878)
JVP functions are forward-mode derivative functions. They take original
arguments and return original results and a differential function. Differential
functions take derivatives wrt arguments and return derivatives wrt results.

`JVPEmitter` is a cloner that emits JVP and differential functions at the same
time. In JVP functions, function applications are replaced with JVP function
applications. In differential functions, function applications are replaced
with differential function applications.

In JVP functions, each basic block takes a differential struct containing callee
differentials. These structs are consumed by differential functions.
2020-04-08 11:29:21 -07:00
Dan Zheng
146c11ec80 [AutoDiff upstream] Add differentiable_function canonicalization. (#30818)
Canonicalizes `differentiable_function` instructions by filling in missing
derivative function operands.

Derivative function emission rules, based on the original function value:

- `function_ref`: look up differentiability witness with the exact or a minimal
  superset derivative configuration. Emit a `differentiability_witness_function`
  for the derivative function.
- `witness_method`: emit a `witness_method` with the minimal superset derivative
  configuration for the derivative function.
- `class_method`: emit a `class_method` with the minimal superset derivative
  configuration for the derivative function.

If an *actual* emitted derivative function has a superset derivative
configuration versus the *desired* derivative configuration, create a "subset
parameters thunk" to thunk the actual derivative to the desired type.

For `differentiable_function` instructions formed from curry thunk applications:
clone the curry thunk (with type `(Self) -> (T, ...) -> U`) and create a new
version with type `(Self) -> @differentiable (T, ...) -> U`.

Progress towards TF-1211.
2020-04-05 20:19:10 -07:00
Dan Zheng
aa66cce808 [AutoDiff upstream] Add differentiation transform.
The differentiation transform does the following:
- Canonicalizes differentiability witnesses by filling in missing derivative
  function entries.
- Canonicalizes `differentiable_function` instructions by filling in missing
  derivative function operands.
- If necessary, performs automatic differentiation: generating derivative
  functions for original functions.
  - When encountering non-differentiability code, produces a diagnostic and
    errors out.

Partially resolves TF-1211: add the main canonicalization loop.

To incrementally stage changes, derivative functions are currently created
with empty bodies that fatal error with a nice message.

Derivative emitters will be upstreamed separately.
2020-04-02 15:43:57 -07:00