Commit Graph

111 Commits

Author SHA1 Message Date
marcrasi
207958c12f [AutoDiff] fix tests to pass on tensorflow branch (#31264) 2020-04-24 09:53:05 -07:00
ematejska
4cd68edf8c [Autodiff upstream] Add DifferentiabilityWitnessDevirtualizer SILOptimizer pass (#30984)
Add DifferentiabilityWitnessDevirtualizer: an optimization pass that
devirtualizes `differentiability_witness_function` instructions into
`function_ref` instructions.

Co-authored-by: Dan Zheng <danielzheng@google.com>
2020-04-23 02:13:05 -07:00
marcrasi
a48880d13f [AutoDiff upstream] add more validation tests (#31190) 2020-04-22 17:31:11 -07:00
marcrasi
bc0dd81950 [AutoDiff upstream] SILOptimizer tests (#31114) 2020-04-17 20:46:47 -07:00
marcrasi
99356cd9f4 [AutoDiff upstream] add more differentiation tests (#30933) 2020-04-13 09:40:45 -07:00
marcrasi
d99d8da956 [AutoDiff upstream] handle differentiable_function in DiagnoseInvalidEscapingCaptures (#30909) 2020-04-09 09:34:24 -07:00
Dan Zheng
2eb460de4d [AutoDiff upstream] Add forward-mode differentiation. (#30878)
JVP functions are forward-mode derivative functions. They take original
arguments and return original results and a differential function. Differential
functions take derivatives wrt arguments and return derivatives wrt results.

`JVPEmitter` is a cloner that emits JVP and differential functions at the same
time. In JVP functions, function applications are replaced with JVP function
applications. In differential functions, function applications are replaced
with differential function applications.

In JVP functions, each basic block takes a differential struct containing callee
differentials. These structs are consumed by differential functions.
2020-04-08 11:29:21 -07:00
Dan Zheng
d6bbf97886 Add simple generated derivative code FileCheck test. 2020-04-05 21:15:23 -07:00
Dan Zheng
b833271215 Simplify test before stdlib derivatives are upstreamed. 2020-04-05 20:35:35 -07:00
Dan Zheng
146c11ec80 [AutoDiff upstream] Add differentiable_function canonicalization. (#30818)
Canonicalizes `differentiable_function` instructions by filling in missing
derivative function operands.

Derivative function emission rules, based on the original function value:

- `function_ref`: look up differentiability witness with the exact or a minimal
  superset derivative configuration. Emit a `differentiability_witness_function`
  for the derivative function.
- `witness_method`: emit a `witness_method` with the minimal superset derivative
  configuration for the derivative function.
- `class_method`: emit a `class_method` with the minimal superset derivative
  configuration for the derivative function.

If an *actual* emitted derivative function has a superset derivative
configuration versus the *desired* derivative configuration, create a "subset
parameters thunk" to thunk the actual derivative to the desired type.

For `differentiable_function` instructions formed from curry thunk applications:
clone the curry thunk (with type `(Self) -> (T, ...) -> U`) and create a new
version with type `(Self) -> @differentiable (T, ...) -> U`.

Progress towards TF-1211.
2020-04-05 20:19:10 -07:00
Dan Zheng
aa66cce808 [AutoDiff upstream] Add differentiation transform.
The differentiation transform does the following:
- Canonicalizes differentiability witnesses by filling in missing derivative
  function entries.
- Canonicalizes `differentiable_function` instructions by filling in missing
  derivative function operands.
- If necessary, performs automatic differentiation: generating derivative
  functions for original functions.
  - When encountering non-differentiability code, produces a diagnostic and
    errors out.

Partially resolves TF-1211: add the main canonicalization loop.

To incrementally stage changes, derivative functions are currently created
with empty bodies that fatal error with a nice message.

Derivative emitters will be upstreamed separately.
2020-04-02 15:43:57 -07:00