[AutoDiff] Support differentiable functions with multiple semantic results (#66873)

Add support for differentiable functions having multiple semantic results

Co-authored-by: Brad Larson <larson@sunsetlakesoftware.com>
This commit is contained in:
Anton Korobeynikov
2023-07-06 16:31:39 -07:00
committed by GitHub
parent 29ce7a341d
commit eb82df6bc6
26 changed files with 617 additions and 174 deletions

View File

@@ -722,9 +722,10 @@ void ModuleFile::loadDerivativeFunctionConfigurations(
}
auto derivativeGenSig = derivativeGenSigOrError.get();
// NOTE(TF-1038): Result indices are currently unsupported in derivative
// registration attributes. In the meantime, always use `{0}` (wrt the
// first and only result).
auto resultIndices = IndexSubset::get(ctx, 1, {0});
// registration attributes. In the meantime, always use all results.
auto *resultIndices =
autodiff::getFunctionSemanticResultIndices(originalAFD,
parameterIndices);
results.insert({parameterIndices, resultIndices, derivativeGenSig});
}
}