Merge pull request #30711 from rxwei/differential-operators

[AutoDiff upstream] Add differential operators and some utilities.
This commit is contained in:
marcrasi
2020-04-01 10:11:35 -07:00
committed by GitHub
12 changed files with 652 additions and 40 deletions

View File

@@ -1106,14 +1106,9 @@ static ManagedValue emitBuiltinAutoDiffApplyTransposeFunction(
origFnArgVals.push_back(arg.getValue());
// Get the transpose function.
// TODO(TF-1142): Create a linear_function_extract instead of an undef.
auto fnTy = origFnVal->getType().castTo<SILFunctionType>();
auto transposeFnType =
fnTy->getWithoutDifferentiability()->getAutoDiffTransposeFunctionType(
fnTy->getDifferentiabilityParameterIndices(), SGF.SGM.M.Types,
LookUpConformanceInModule(SGF.SGM.M.getSwiftModule()));
SILValue transposeFn =
SILUndef::get(SILType::getPrimitiveObjectType(transposeFnType), SGF.F);
SILValue transposeFn = SGF.B.createLinearFunctionExtract(
loc, LinearDifferentiableFunctionTypeComponent::Transpose, origFnVal);
auto transposeFnType = transposeFn->getType().castTo<SILFunctionType>();
auto transposeFnUnsubstType =
transposeFnType->getUnsubstitutedType(SGF.getModule());
if (transposeFnType != transposeFnUnsubstType) {
@@ -1204,19 +1199,16 @@ static ManagedValue emitBuiltinLinearFunction(
assert(args.size() == 2);
auto origFn = args.front();
auto origType = origFn.getType().castTo<SILFunctionType>();
// TODO(TF-1142): Create a linear_function instead of an undef.
auto linearFnTy = origType->getWithDifferentiability(
DifferentiabilityKind::Linear,
auto linearFn = SGF.B.createLinearFunction(
loc,
IndexSubset::getDefault(
SGF.getASTContext(), origType->getNumParameters(),
/*includeAll*/ true));
SILValue linearFn = SILUndef::get(
SILType::getPrimitiveObjectType(linearFnTy), SGF.F);
SGF.getASTContext(),
origType->getNumParameters(),
/*includeAll*/ true),
origFn.forward(SGF), args[1].forward(SGF));
return SGF.emitManagedRValueWithCleanup(linearFn);
}
/// Emit SIL for the named builtin: globalStringTablePointer. Unlike the default
/// ownership convention for named builtins, which is to take (non-trivial)
/// arguments as Owned, this builtin accepts owned as well as guaranteed