Make `SynthesizedFileUnit` attached to a `SourceFile`. This seemed like the
least ad-hoc approach to avoid doing unnecessary work for other `FileUnit`s.
TBDGen: when visiting a `SourceFile`, also visit its `SynthesizedFileUnit` if
it exists.
Serialization: do not treat `SynthesizedFileUnit` declarations as xrefs when
serializing the companion `SourceFile`.
Resolves TF-1239: AutoDiff test failures.
JVP functions are forward-mode derivative functions. They take original
arguments and return original results and a differential function. Differential
functions take derivatives wrt arguments and return derivatives wrt results.
`JVPEmitter` is a cloner that emits JVP and differential functions at the same
time. In JVP functions, function applications are replaced with JVP function
applications. In differential functions, function applications are replaced
with differential function applications.
In JVP functions, each basic block takes a differential struct containing callee
differentials. These structs are consumed by differential functions.
Make `ADContext` lazily create a `SynthesizedFileUnit` instead of creating one
during `ADContext` construction. This avoids always creating a
`SynthesizedFileUnit` in every module, since differentiation is a mandatory
transform that always runs.
It was nonetheless useful to test always creating a `SynthesizedFileUnit` for
testing purposes.
Add implicit declarations generated by the differentiation transform to a
`SynthesizedFileUnit` instead of an ad-hoc pre-existing `SourceFile`.
Resolves TF-1232: type reconstruction for AutoDiff-generated declarations.
Previously, type reconstruction failed because retroactively adding declarations
to a `SourceFile` did not update name lookup caches.
`PullbackEmitter` is a visitor that emits pullback functions. It implements
reverse-mode automatic differentiation, along with `VJPEmitter`.
Pullback functions take derivatives with respect to outputs and return
derivatives with respect to inputs. Every active value/address in an original
function has a corresponding adjoint value/buffer in the pullback function.
Pullback functions consume pullback structs and predecessor enums constructed
by VJP functions.
`VJPEmitter` is a cloner that emits VJP functions. It implements reverse-mode
automatic differentiation, along with `PullbackEmitter`.
`VJPEmitter` clones an original function, replacing function applications with
VJP function applications. In VJP functions, each basic block takes a pullback
struct (containing callee pullbacks) and produces a predecessor enum: these data
structures are consumed by pullback functions.
`LinearMapInfo` contains information about linear map structs and branching
trace enums, which are auxiliary data structures created by the differentiation
transform.
These data structures are constructed in JVP/VJP functions and consumed in
differential/pullback functions.
Canonicalizes `differentiable_function` instructions by filling in missing
derivative function operands.
Derivative function emission rules, based on the original function value:
- `function_ref`: look up differentiability witness with the exact or a minimal
superset derivative configuration. Emit a `differentiability_witness_function`
for the derivative function.
- `witness_method`: emit a `witness_method` with the minimal superset derivative
configuration for the derivative function.
- `class_method`: emit a `class_method` with the minimal superset derivative
configuration for the derivative function.
If an *actual* emitted derivative function has a superset derivative
configuration versus the *desired* derivative configuration, create a "subset
parameters thunk" to thunk the actual derivative to the desired type.
For `differentiable_function` instructions formed from curry thunk applications:
clone the curry thunk (with type `(Self) -> (T, ...) -> U`) and create a new
version with type `(Self) -> @differentiable (T, ...) -> U`.
Progress towards TF-1211.
The differentiation transform does the following:
- Canonicalizes differentiability witnesses by filling in missing derivative
function entries.
- Canonicalizes `differentiable_function` instructions by filling in missing
derivative function operands.
- If necessary, performs automatic differentiation: generating derivative
functions for original functions.
- When encountering non-differentiability code, produces a diagnostic and
errors out.
Partially resolves TF-1211: add the main canonicalization loop.
To incrementally stage changes, derivative functions are currently created
with empty bodies that fatal error with a nice message.
Derivative emitters will be upstreamed separately.