mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Merge pull request #30711 from rxwei/differential-operators
[AutoDiff upstream] Add differential operators and some utilities.
This commit is contained in:
@@ -1106,14 +1106,9 @@ static ManagedValue emitBuiltinAutoDiffApplyTransposeFunction(
|
||||
origFnArgVals.push_back(arg.getValue());
|
||||
|
||||
// Get the transpose function.
|
||||
// TODO(TF-1142): Create a linear_function_extract instead of an undef.
|
||||
auto fnTy = origFnVal->getType().castTo<SILFunctionType>();
|
||||
auto transposeFnType =
|
||||
fnTy->getWithoutDifferentiability()->getAutoDiffTransposeFunctionType(
|
||||
fnTy->getDifferentiabilityParameterIndices(), SGF.SGM.M.Types,
|
||||
LookUpConformanceInModule(SGF.SGM.M.getSwiftModule()));
|
||||
SILValue transposeFn =
|
||||
SILUndef::get(SILType::getPrimitiveObjectType(transposeFnType), SGF.F);
|
||||
SILValue transposeFn = SGF.B.createLinearFunctionExtract(
|
||||
loc, LinearDifferentiableFunctionTypeComponent::Transpose, origFnVal);
|
||||
auto transposeFnType = transposeFn->getType().castTo<SILFunctionType>();
|
||||
auto transposeFnUnsubstType =
|
||||
transposeFnType->getUnsubstitutedType(SGF.getModule());
|
||||
if (transposeFnType != transposeFnUnsubstType) {
|
||||
@@ -1204,19 +1199,16 @@ static ManagedValue emitBuiltinLinearFunction(
|
||||
assert(args.size() == 2);
|
||||
auto origFn = args.front();
|
||||
auto origType = origFn.getType().castTo<SILFunctionType>();
|
||||
// TODO(TF-1142): Create a linear_function instead of an undef.
|
||||
auto linearFnTy = origType->getWithDifferentiability(
|
||||
DifferentiabilityKind::Linear,
|
||||
auto linearFn = SGF.B.createLinearFunction(
|
||||
loc,
|
||||
IndexSubset::getDefault(
|
||||
SGF.getASTContext(), origType->getNumParameters(),
|
||||
/*includeAll*/ true));
|
||||
SILValue linearFn = SILUndef::get(
|
||||
SILType::getPrimitiveObjectType(linearFnTy), SGF.F);
|
||||
SGF.getASTContext(),
|
||||
origType->getNumParameters(),
|
||||
/*includeAll*/ true),
|
||||
origFn.forward(SGF), args[1].forward(SGF));
|
||||
return SGF.emitManagedRValueWithCleanup(linearFn);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Emit SIL for the named builtin: globalStringTablePointer. Unlike the default
|
||||
/// ownership convention for named builtins, which is to take (non-trivial)
|
||||
/// arguments as Owned, this builtin accepts owned as well as guaranteed
|
||||
|
||||
Reference in New Issue
Block a user