[AutoDiff] Disable differentiable_function_extract explicit type as… (#35239)

`differentiability_function_extract` instruction has an optional explicit
extractee type. This is currently used by TypeSubstCloner and the
LoadableByAddress transform to rewrite `differentiability_function_extract`
instructions while preserving `@differentiable` function type invariants.

There is an assertion that `differentiability_function_extract` instructions do
not have explicit extractee types outside of canonical/lowered SIL. However,
this does not handle the SIL deserialization case above: when a function
containing a `differentiable_function_extract` instruction with an explicit type
is deserialized into a raw SIL module (which happens when optimizations are
enabled).

Removing the assertion unblocks this encountered use case.

A more robust longer-term solution may be to change SIL `@differentiable`
function types to explicitly store component original/JVP/VJP function types.

Also fix `differentiable_function_extract` extractee type serialization.

Resolves SR-14004.
This commit is contained in:
Dan Zheng
2021-01-04 18:40:11 -05:00
committed by GitHub
parent 650c1e600c
commit 126f1ac6fb
7 changed files with 53 additions and 20 deletions

View File

@@ -731,18 +731,6 @@ DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst(
: getExtracteeType(function, extractee, module),
function.getOwnershipKind()),
Extractee(extractee), HasExplicitExtracteeType(extracteeType.hasValue()) {
#ifndef NDEBUG
if (extracteeType.hasValue()) {
// Note: explicit extractee type is used to avoid inconsistent typing in:
// - Canonical SIL, due to generic specialization.
// - Lowered SIL, due to LoadableByAddress.
// See `TypeSubstCloner::visitDifferentiableFunctionExtractInst` for an
// explanation of how explicit extractee type is used.
assert((module.getStage() == SILStage::Canonical ||
module.getStage() == SILStage::Lowered) &&
"Explicit type is valid only in canonical or lowered SIL");
}
#endif
}
SILType LinearFunctionExtractInst::