Files
swift-mirror/lib/SILOptimizer/Differentiation/Thunk.cpp
2025-09-17 15:46:33 -10:00

893 lines
37 KiB
C++

//===--- Thunk.cpp - Automatic differentiation thunks ---------*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Automatic differentiation thunk generation utilities.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "differentiation"
#include "swift/SILOptimizer/Differentiation/Thunk.h"
#include "swift/SILOptimizer/Differentiation/Common.h"
#include "swift/AST/AnyFunctionRef.h"
#include "swift/AST/Requirement.h"
#include "swift/AST/SubstitutionMap.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/Basic/Assertions.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
namespace swift {
namespace autodiff {
//===----------------------------------------------------------------------===//
// Thunk helpers
//===----------------------------------------------------------------------===//
// These helpers are copied/adapted from SILGen. They should be refactored and
// moved to a shared location.
//===----------------------------------------------------------------------===//
/// Computes the SIL function type of a differentiation thunk that converts a
/// function of type `sourceType` into one of type `expectedType`.
///
/// This is a thin adapter over `buildSILFunctionThunkType`: `genericEnv` and
/// `interfaceSubs` are passed through by reference for the callee to use,
/// while the remaining locals it requires are not read afterwards.
///
/// \param fn                      Function providing the thunk's context.
/// \param sourceType              Type of the function being thunked.
/// \param expectedType            Type the thunk must produce.
/// \param genericEnv              Passed by reference to the underlying
///                                builder.
/// \param interfaceSubs           Passed by reference to the underlying
///                                builder.
/// \param withoutActuallyEscaping Forwarded unchanged.
/// \param thunkKind               Kind of differentiation thunk.
/// \returns The thunk's SIL function type.
CanSILFunctionType buildThunkType(SILFunction *fn,
                                  CanSILFunctionType &sourceType,
                                  CanSILFunctionType &expectedType,
                                  GenericEnvironment *&genericEnv,
                                  SubstitutionMap &interfaceSubs,
                                  bool withoutActuallyEscaping,
                                  DifferentiationThunkKind thunkKind) {
  // Arguments required by `buildSILFunctionThunkType` (presumably
  // out-parameters) whose values autodiff thunks never inspect.
  CanType ignoredInputSubstType, ignoredOutputSubstType,
      ignoredDynamicSelfType;
  auto thunkType = buildSILFunctionThunkType(
      fn, sourceType, expectedType, ignoredInputSubstType,
      ignoredOutputSubstType, genericEnv, interfaceSubs,
      ignoredDynamicSelfType, withoutActuallyEscaping, thunkKind);
  return thunkType;
}
/// Forwards `originalArgs` (arguments of a thunk whose parameters follow
/// `toTy`) as operands for a call to a function of type `fromTy`, inserting
/// any ownership conversion each parameter pair requires.
/// Adapted from `forwardFunctionArguments` in SILGenPoly.cpp.
///
/// Exactly one value per argument is appended to `forwardedArgs`:
///   - a `copy_value` (object) or a stack copy made with `copy_addr`
///     (address) when the called function consumes a parameter that the
///     thunk does not own;
///   - a borrow when the called function expects a guaranteed parameter the
///     thunk does not receive as guaranteed;
///   - the argument itself otherwise.
///
/// Stack allocations are appended to `localAllocations` and must be released
/// with `dealloc_stack`. For borrowed arguments, the borrow and then the
/// original argument are appended to `valuesToCleanup` for the caller to
/// clean up after the call.
static void forwardFunctionArgumentsConvertingOwnership(
    SILBuilder &builder, SILLocation loc, CanSILFunctionType fromTy,
    CanSILFunctionType toTy, ArrayRef<SILArgument *> originalArgs,
    SmallVectorImpl<SILValue> &forwardedArgs,
    SmallVectorImpl<AllocStackInst *> &localAllocations,
    SmallVectorImpl<SILValue> &valuesToCleanup) {
  auto calleeParams = fromTy->getParameters();
  auto thunkParams = toTy->getParameters();
  assert(calleeParams.size() == thunkParams.size());
  assert(calleeParams.size() == originalArgs.size());

  for (size_t i = 0, numArgs = originalArgs.size(); i != numArgs; ++i) {
    SILArgument *arg = originalArgs[i];
    auto calleeParam = calleeParams[i];
    auto thunkParam = thunkParams[i];

    // NOTE(review): the consume check pairs `isConsumedInCaller` (called
    // function) with `isConsumedInCallee` (thunk parameter), whereas the
    // borrow check uses `isGuaranteedInCaller` on both sides; confirm the
    // asymmetry is intended.
    if (calleeParam.isConsumedInCaller() && !thunkParam.isConsumedInCallee()) {
      // Guaranteed -> owned: give the called function its own copy.
      if (arg->getType().isObject()) {
        forwardedArgs.push_back(builder.emitCopyValueOperation(loc, arg));
      } else {
        auto *copyBuf = builder.createAllocStack(loc, arg->getType());
        builder.createCopyAddr(loc, arg, copyBuf, IsNotTake, IsInitialization);
        localAllocations.push_back(copyBuf);
        forwardedArgs.push_back(copyBuf);
      }
    } else if (calleeParam.isGuaranteedInCaller() &&
               !thunkParam.isGuaranteedInCaller()) {
      // Owned -> guaranteed: borrow for the call. Both the borrow and the
      // owned argument are queued for cleanup, in that order.
      SILValue borrowed = builder.emitBeginBorrowOperation(loc, arg);
      forwardedArgs.push_back(borrowed);
      valuesToCleanup.push_back(borrowed);
      valuesToCleanup.push_back(arg);
    } else {
      // Conventions are compatible; pass the argument through unchanged.
      forwardedArgs.push_back(arg);
    }
  }
}
/// Returns a reabstraction thunk that converts a function value of type
/// `fromType` into a function value of type `toType`.
///
/// The thunk is obtained via `getOrCreateSharedFunction` under a name mangled
/// from the thunk type and the interface forms of `fromType` and `toType`. If
/// that function already has a body, it is returned unchanged; otherwise its
/// body is emitted here. The emitted thunk takes the function to convert as
/// its last argument and calls it, adapting:
///   - argument ownership conventions
///     (via `forwardFunctionArgumentsConvertingOwnership`),
///   - direct vs. indirect passing of parameters and results,
///   - function-typed results whose unsubstituted types differ (recursively,
///     via `reabstractFunction`).
///
/// \param fb       Function builder used to get or create the shared thunk.
/// \param module   Module providing type conventions and the mangling context.
/// \param loc      Location used for the thunk and every emitted instruction.
/// \param caller   Function passed to `buildThunkType`.
/// \param fromType Type of the function being converted. Must not have
///                 combined substitutions (asserted).
/// \param toType   Desired function type. Must not have combined
///                 substitutions (asserted).
/// \returns The (possibly pre-existing) thunk function.
SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
SILModule &module, SILLocation loc,
SILFunction *caller,
CanSILFunctionType fromType,
CanSILFunctionType toType) {
assert(!fromType->getCombinedSubstitutions());
assert(!toType->getCombinedSubstitutions());
// Passed by reference to `buildThunkType`; `genericEnv` is later installed
// on the thunk.
SubstitutionMap interfaceSubs;
GenericEnvironment *genericEnv = nullptr;
auto thunkType =
buildThunkType(caller, fromType, toType, genericEnv, interfaceSubs,
/*withoutActuallyEscaping*/ false,
DifferentiationThunkKind::Reabstraction);
// The thunk function itself is declared with an escaping type.
auto thunkDeclType =
thunkType->getWithExtInfo(thunkType->getExtInfo().withNoEscape(false));
// The mangled name uses the interface (out-of-context) forms of the types.
auto fromInterfaceType = fromType->mapTypeOutOfContext()->getCanonicalType();
auto toInterfaceType = toType->mapTypeOutOfContext()->getCanonicalType();
Mangle::ASTMangler mangler(module.getASTContext());
std::string name = mangler.mangleReabstractionThunkHelper(
thunkType, fromInterfaceType, toInterfaceType, Type(), Type(),
module.getSwiftModule());
auto *thunk = fb.getOrCreateSharedFunction(
loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
ProfileCounter(), IsReabstractionThunk, IsNotDynamic, IsNotDistributed,
IsNotRuntimeAccessible);
// A function with this name already has a body: reuse it as is.
if (!thunk->empty())
return thunk;
thunk->setGenericEnvironment(genericEnv);
auto *entry = thunk->createBasicBlock();
SILBuilder builder(entry);
createEntryArguments(thunk);
SILFunctionConventions fromConv(fromType, module);
SILFunctionConventions toConv(toType, module);
assert(toConv.useLoweredAddresses());
// Forward thunk arguments, handling ownership convention mismatches.
// `forwardedArgs` receives the thunk's indirect results first, followed by
// its remaining arguments minus the last one (the function being converted).
SmallVector<SILValue, 4> forwardedArgs;
for (auto indRes : thunk->getIndirectResults())
forwardedArgs.push_back(indRes);
SmallVector<AllocStackInst *, 4> localAllocations;
SmallVector<SILValue, 4> valuesToCleanup;
forwardFunctionArgumentsConvertingOwnership(
builder, loc, fromType, toType,
thunk->getArgumentsWithoutIndirectResults().drop_back(), forwardedArgs,
localAllocations, valuesToCleanup);
// `arguments` collects the operands of the call to the wrapped function;
// `toArgIter` walks `forwardedArgs` in thunk-argument order.
SmallVector<SILValue, 4> arguments;
auto toArgIter = forwardedArgs.begin();
auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
// Stack buffers made here are recorded so they are deallocated at the end.
auto createAllocStack = [&](SILType type) {
auto *alloc = builder.createAllocStack(loc, type);
localAllocations.push_back(alloc);
return alloc;
};
// Handle indirect results.
assert(fromType->getNumResults() == toType->getNumResults());
for (unsigned resIdx : range(toType->getNumResults())) {
auto fromRes = fromConv.getResults()[resIdx];
auto toRes = toConv.getResults()[resIdx];
// No abstraction mismatch.
if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
// If result types are indirect, directly pass as next argument.
if (toRes.isFormalIndirect())
useNextArgument();
continue;
}
// Convert indirect result to direct result.
// The callee writes into a temporary buffer; it is loaded after the call.
if (fromRes.isFormalIndirect()) {
SILType resultTy =
fromConv.getSILType(fromRes, builder.getTypeExpansionContext());
assert(resultTy.isAddress());
auto *indRes = createAllocStack(resultTy);
arguments.push_back(indRes);
continue;
}
// Convert direct result to indirect result.
// Increment thunk argument iterator; reabstraction handled later.
++toArgIter;
}
// Reabstract parameters.
assert(toType->getNumParameters() == fromType->getNumParameters());
for (unsigned paramIdx : range(toType->getNumParameters())) {
auto fromParam = fromConv.getParameters()[paramIdx];
auto toParam = toConv.getParameters()[paramIdx];
// No abstraction mismatch. Directly use next argument.
if (fromParam.isFormalIndirect() == toParam.isFormalIndirect()) {
useNextArgument();
continue;
}
// Convert indirect parameter to direct parameter.
// The callee expects an address but the thunk received an object: store a
// copy into a stack buffer and queue the buffer in `valuesToCleanup`.
if (fromParam.isFormalIndirect()) {
// NOTE(review): `paramTy` is only read by the assertion below.
auto paramTy = fromConv.getSILType(fromType->getParameters()[paramIdx],
builder.getTypeExpansionContext());
if (!paramTy.hasArchetype())
paramTy = thunk->mapTypeIntoContext(paramTy);
assert(paramTy.isAddress());
auto toArg = *toArgIter++;
auto *buf = createAllocStack(toArg->getType());
toArg = builder.emitCopyValueOperation(loc, toArg);
builder.emitStoreValueOperation(loc, toArg, buf,
StoreOwnershipQualifier::Init);
valuesToCleanup.push_back(buf);
arguments.push_back(buf);
continue;
}
// Convert direct parameter to indirect parameter.
// The callee expects an object but the thunk received an address: borrow
// the loaded value. Only an actual `load_borrow` is queued for cleanup.
assert(toParam.isFormalIndirect());
auto toArg = *toArgIter++;
auto load = builder.emitLoadBorrowOperation(loc, toArg);
if (isa<LoadBorrowInst>(load))
valuesToCleanup.push_back(load);
arguments.push_back(load);
}
// The function being converted is the thunk's last argument.
auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back();
auto *apply = builder.createApply(loc, fnArg, SubstitutionMap(), arguments);
// Get return elements.
SmallVector<SILValue, 4> results;
// Extract all direct results.
SmallVector<SILValue, 4> directResults;
extractAllElements(apply, builder, directResults);
auto fromDirResultsIter = directResults.begin();
auto fromIndResultsIter = apply->getIndirectSILResults().begin();
auto toIndResultsIter = thunk->getIndirectResults().begin();
// Reabstract results.
for (unsigned resIdx : range(toType->getNumResults())) {
auto fromRes = fromConv.getResults()[resIdx];
auto toRes = toConv.getResults()[resIdx];
// Check function-typed results.
if (isa<SILFunctionType>(fromRes.getInterfaceType()) &&
isa<SILFunctionType>(toRes.getInterfaceType())) {
auto fromFnType = cast<SILFunctionType>(fromRes.getInterfaceType());
auto toFnType = cast<SILFunctionType>(toRes.getInterfaceType());
auto fromUnsubstFnType = fromFnType->getUnsubstitutedType(module);
auto toUnsubstFnType = toFnType->getUnsubstitutedType(module);
// If unsubstituted function types are not equal, perform reabstraction.
if (fromUnsubstFnType != toUnsubstFnType) {
auto fromFn = *fromDirResultsIter++;
auto newFromFn = reabstractFunction(
builder, fb, loc, fromFn, toFnType,
[](SubstitutionMap substMap) { return substMap; });
results.push_back(newFromFn);
continue;
}
}
// No abstraction mismatch.
if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
// If result types are direct, add call result as direct thunk result.
if (toRes.isFormalDirect())
results.push_back(*fromDirResultsIter++);
// If result types are indirect, increment indirect result iterators.
else {
++fromIndResultsIter;
++toIndResultsIter;
}
continue;
}
// Load direct results from indirect results.
// The temporary buffer allocated earlier is consumed by a `take` load.
if (fromRes.isFormalIndirect()) {
auto indRes = *fromIndResultsIter++;
auto load = builder.emitLoadValueOperation(loc, indRes,
LoadOwnershipQualifier::Take);
results.push_back(load);
continue;
}
// Store direct results to indirect results.
assert(toRes.isFormalIndirect());
#ifndef NDEBUG
SILType resultTy =
toConv.getSILType(toRes, builder.getTypeExpansionContext());
assert(resultTy.isAddress());
#endif
auto indRes = *toIndResultsIter++;
auto dirRes = *fromDirResultsIter++;
builder.emitStoreValueOperation(loc, dirRes, indRes,
StoreOwnershipQualifier::Init);
}
auto retVal = joinElements(results, builder, loc);
// Clean up local values.
// Guaranteed values need an `end_borrow`.
// Owned values need to be destroyed.
for (auto arg : valuesToCleanup) {
switch (arg->getOwnershipKind()) {
case OwnershipKind::Any:
llvm_unreachable("value with any ownership kind?!");
case OwnershipKind::Guaranteed:
builder.emitEndBorrowOperation(loc, arg);
break;
case OwnershipKind::Owned:
case OwnershipKind::Unowned:
case OwnershipKind::None:
builder.emitDestroyOperation(loc, arg);
break;
}
}
// Deallocate local allocations.
// Stack allocations are released in reverse order of creation.
for (auto *alloc : llvm::reverse(localAllocations))
builder.createDeallocStack(loc, alloc);
// Create return.
builder.createReturn(loc, retVal);
LLVM_DEBUG(auto &s = getADDebugStream() << "Created reabstraction thunk.\n";
s << " From type: " << fromType << '\n';
s << " To type: " << toType << '\n'; s << '\n'
<< *thunk);
return thunk;
}
// FIXME: This is pretty rudimentary as of now as there is no proper AST type
// for coroutine and therefore we cannot e.g. store a coroutine into a tuple or
// do other things that are allowed with first-class function types. For now we
// have to unsafely bitcast coroutine to function type and vice versa. This
// function should be rethought when we will have proper AST coroutine types.
//
/// Converts the coroutine value `fn` to `toType` using only a chain of
/// `convert_function` instructions (no thunk is created):
///   1. to the unsubstituted source type (only if it differs),
///   2. to the unsubstituted destination type,
///   3. to `toType` (only if it differs from its unsubstituted form).
///
/// `fb` and `remapSubstitutions` are accepted for signature parity with
/// `reabstractFunction` but are not used here.
///
/// \returns The converted value.
SILValue reabstractCoroutine(
    SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
    SILValue fn, CanSILFunctionType toType,
    std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions) {
  auto &module = *fn->getModule();
  auto fromType = fn->getType().getAs<SILFunctionType>();
  auto unsubstFromType = fromType->getUnsubstitutedType(module);
  auto unsubstToType = toType->getUnsubstitutedType(module);
  LLVM_DEBUG(auto &s = getADDebugStream() << "Converting coroutine\n";
             s << " From type: " << fromType << '\n';
             s << " To type: " << toType << '\n'; s << '\n');

  // Replaces `fn` with a `convert_function` of itself to `targetType`.
  auto convertFnTo = [&](CanSILFunctionType targetType) {
    fn = builder.createConvertFunction(
        loc, fn, SILType::getPrimitiveObjectType(targetType),
        /*withoutActuallyEscaping*/ false);
  };

  if (fromType != unsubstFromType)
    convertFnTo(unsubstFromType);
  convertFnTo(unsubstToType);
  if (toType != unsubstToType)
    convertFnTo(toType);
  return fn;
}
/// Converts the function value `fn` to `toType` by partially applying a
/// reabstraction thunk (from `getOrCreateReabstractionThunk`) to it.
///
/// The thunk is built between the unsubstituted forms of the source and
/// destination types. `convert_function` instructions bridge to and from the
/// substituted types whenever they differ from those forms.
///
/// \param builder            Builder positioned where new code is emitted.
/// \param fb                 Function builder used to create the thunk.
/// \param loc                Location for emitted instructions.
/// \param fn                 Function value to reabstract.
/// \param toType             Desired function type.
/// \param remapSubstitutions Maps the thunk's forwarding substitution map to
///                           the substitutions used by the `partial_apply`.
/// \returns The reabstracted function value.
SILValue reabstractFunction(
    SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
    SILValue fn, CanSILFunctionType toType,
    std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions) {
  auto &module = *fn->getModule();
  auto sourceType = fn->getType().getAs<SILFunctionType>();
  auto sourceUnsubst = sourceType->getUnsubstitutedType(module);
  auto targetUnsubst = toType->getUnsubstitutedType(module);

  SILFunction *thunk = getOrCreateReabstractionThunk(
      fb, module, loc, /*caller*/ fn->getFunction(), sourceUnsubst,
      targetUnsubst);
  auto *thunkRef = builder.createFunctionRef(loc, thunk);

  // Strip substitutions so the operand matches the thunk's parameter type.
  SILValue operand = fn;
  if (sourceType != sourceUnsubst)
    operand = builder.createConvertFunction(
        loc, operand, SILType::getPrimitiveObjectType(sourceUnsubst),
        /*withoutActuallyEscaping*/ false);

  SILValue reabstracted = builder.createPartialApply(
      loc, thunkRef,
      remapSubstitutions(thunk->getForwardingSubstitutionMap()), {operand},
      sourceType->getCalleeConvention());

  // Reapply the substitutions expected by the destination type.
  if (toType != targetUnsubst)
    reabstracted = builder.createConvertFunction(
        loc, reabstracted, SILType::getPrimitiveObjectType(toType),
        /*withoutActuallyEscaping*/ false);
  return reabstracted;
}
/// Returns a thunk that adapts a linear map (differential or pullback) of the
/// original function type `origFnType` so that it matches the linear map type
/// `targetType` for `desiredConfig`. The input linear map was produced for
/// `actualConfig`.
///
/// The thunk takes the linear map to wrap as its last argument.
///   - Inputs the wrapped linear map expects but the thunk does not receive
///     (because they are not in `desiredConfig`) are filled with `.zero`
///     tangent values.
///   - Outputs that are not in `desiredConfig` are destroyed instead of
///     returned.
/// If a function with the mangled thunk name already has a body, it is
/// returned without emitting anything.
///
/// \param fb            Function builder used to get or create the thunk.
/// \param parentThunk   Function that will use this thunk; it is passed to
///                      `buildThunkType` and provides the thunk's location.
/// \param origFnType    Type of the original (non-derivative) function.
/// \param linearMapType Type of the linear map to wrap. Must not have
///                      combined substitutions (asserted).
/// \param targetType    Linear map type the thunk must have. Must not have
///                      combined substitutions (asserted).
/// \param kind          JVP (wraps a differential) or VJP (wraps a pullback).
/// \param desiredConfig Configuration requested by the caller.
/// \param actualConfig  Configuration `linearMapType` was produced for.
/// \param adContext     Context used to construct the `TangentBuilder`.
/// \returns The thunk and the interface substitutions to use when partially
///          applying it.
std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForLinearMap(
    SILOptFunctionBuilder &fb, SILFunction *parentThunk,
    CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
    CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
    const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig,
    ADContext &adContext) {
  LLVM_DEBUG(getADDebugStream()
             << "Getting a subset parameters thunk for "
             << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp")
             << " linear map " << linearMapType
             << " from " << actualConfig << " to " << desiredConfig << '\n');
  assert(!linearMapType->getCombinedSubstitutions());
  assert(!targetType->getCombinedSubstitutions());
  SubstitutionMap interfaceSubs;
  GenericEnvironment *genericEnv = nullptr;
  auto thunkType = buildThunkType(parentThunk, linearMapType, targetType,
                                  genericEnv, interfaceSubs,
                                  /*withoutActuallyEscaping*/ true,
                                  DifferentiationThunkKind::Reabstraction);
  Mangle::DifferentiationMangler mangler(parentThunk->getASTContext());
  auto fromInterfaceType =
      linearMapType->mapTypeOutOfContext()->getCanonicalType();
  // NOTE(review): the mangled name encodes `desiredConfig.parameterIndices`
  // but not `desiredConfig.resultIndices`; confirm thunks requested for
  // different desired result subsets cannot collide.
  auto thunkName = mangler.mangleLinearMapSubsetParametersThunk(
      fromInterfaceType, kind.getLinearMapKind(),
      actualConfig.parameterIndices, actualConfig.resultIndices,
      desiredConfig.parameterIndices);
  auto loc = parentThunk->getLocation();
  auto *thunk = fb.getOrCreateSharedFunction(
      loc, thunkName, thunkType, IsBare, IsTransparent, IsSerialized,
      ProfileCounter(), IsThunk, IsNotDynamic, IsNotDistributed,
      IsNotRuntimeAccessible);
  // A function with this name already has a body: reuse it as is.
  if (!thunk->empty())
    return {thunk, interfaceSubs};
  thunk->setGenericEnvironment(genericEnv);
  auto *entry = thunk->createBasicBlock();
  TangentBuilder builder(entry, adContext);
  createEntryArguments(thunk);
  // Get arguments.
  SmallVector<SILValue, 4> arguments;
  SmallVector<AllocStackInst *, 4> localAllocations;
  SmallVector<SILValue, 4> valuesToCleanup;
  // Destroys queued values, then deallocates stack buffers, both in reverse
  // order of creation.
  auto cleanupValues = [&]() {
    for (auto value : llvm::reverse(valuesToCleanup))
      builder.emitDestroyOperation(loc, value);
    for (auto *alloc : llvm::reverse(localAllocations))
      builder.createDeallocStack(loc, alloc);
  };
  // Build a `.zero` argument for the given `Differentiable`-conforming type.
  // Exactly one value is appended to `arguments` per call.
  auto buildZeroArgument = [&](SILParameterInfo zeroSILParameter) {
    auto zeroSILType = zeroSILParameter.getSILStorageInterfaceType();
    auto zeroSILObjType = zeroSILType.getObjectType();
    auto zeroType = zeroSILType.getASTType();
    auto tangentSpace =
        zeroType->getAutoDiffTangentSpace(LookUpConformanceInModule());
    assert(tangentSpace && "No tangent space for this type");
    switch (tangentSpace->getKind()) {
    case TangentSpace::Kind::TangentVector: {
      auto *buf = builder.createAllocStack(loc, zeroSILObjType);
      localAllocations.push_back(buf);
      builder.emitZeroIntoBuffer(loc, buf, IsInitialization);
      if (zeroSILType.isAddress()) {
        arguments.push_back(buf);
        // A guaranteed argument is not consumed by the callee; destroy it
        // afterwards.
        if (zeroSILParameter.isGuaranteedInCaller()) {
          valuesToCleanup.push_back(buf);
        }
      } else {
        auto arg = builder.emitLoadValueOperation(loc, buf,
                                                  LoadOwnershipQualifier::Take);
        arguments.push_back(arg);
        if (zeroSILParameter.isGuaranteedInCaller()) {
          valuesToCleanup.push_back(arg);
        }
      }
      break;
    }
    case TangentSpace::Kind::Tuple: {
      llvm_unreachable("Unimplemented: Handle zero initialization for tuples");
    }
    }
  };
  // The indices in `actualConfig` and `desiredConfig` are with respect to the
  // original function. However, the differential parameters and pullback
  // results may already be w.r.t. a subset. We create a map between the
  // original function's actual parameter indices and the linear map's actual
  // indices.
  // Example:
  //   Original: (T0, T1, T2) -> R
  //   Actual indices: 0, 2
  //   Original differential: (T0, T2) -> R
  //   Original pullback: R -> (T0, T2)
  //   Desired indices w.r.t. original: 2
  //   Desired indices w.r.t. linear map: 1
  SmallVector<unsigned, 4> actualParamIndicesMap(
      actualConfig.parameterIndices->getCapacity(), UINT_MAX);
  {
    unsigned indexInBitVec = 0;
    for (auto index : actualConfig.parameterIndices->getIndices()) {
      actualParamIndicesMap[index] = indexInBitVec;
      ++indexInBitVec;
    }
  }
  auto mapOriginalParameterIndex = [&](unsigned index) -> unsigned {
    auto mappedIndex = actualParamIndicesMap[index];
    assert(mappedIndex < actualConfig.parameterIndices->getCapacity());
    return mappedIndex;
  };
  auto toIndirectResultsIter = thunk->getIndirectResults().begin();
  auto useNextIndirectResult = [&]() {
    assert(toIndirectResultsIter != thunk->getIndirectResults().end());
    arguments.push_back(*toIndirectResultsIter++);
  };
  switch (kind) {
  // Differential arguments are:
  // - All indirect results, followed by:
  // - An interleaving of:
  //   - Thunk arguments (when parameter index is in both desired and actual
  //     indices).
  //   - Zeros (when parameter is not in desired indices).
  case AutoDiffDerivativeFunctionKind::JVP: {
    unsigned numIndirectResults = linearMapType->getNumIndirectFormalResults();
    // Forward desired indirect results
    for (unsigned idx : *actualConfig.resultIndices) {
      if (idx >= numIndirectResults)
        break;
      assert(idx < linearMapType->getNumResults());
      auto resultInfo = linearMapType->getResults()[idx];
      // Forward result argument in case we do not need to thunk it away
      if (desiredConfig.resultIndices->contains(idx)) {
        useNextIndirectResult();
        continue;
      }
      // Otherwise, allocate and use an uninitialized indirect result
      auto *indirectResult = builder.createAllocStack(
          loc, resultInfo.getSILStorageInterfaceType());
      localAllocations.push_back(indirectResult);
      arguments.push_back(indirectResult);
    }
    assert(toIndirectResultsIter == thunk->getIndirectResults().end());
    auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin();
    auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
    // Iterate over actual indices.
    for (unsigned i : actualConfig.parameterIndices->getIndices()) {
      // If index is desired, use next argument.
      if (desiredConfig.isWrtParameter(i)) {
        useNextArgument();
      }
      // Otherwise, construct and use a zero argument.
      else {
        auto zeroSILParameter =
            linearMapType->getParameters()[mapOriginalParameterIndex(i)];
        buildZeroArgument(zeroSILParameter);
      }
    }
    break;
  }
  // Pullback arguments are:
  // - An interleaving of:
  //   - Thunk indirect results (when parameter index is in both desired and
  //     actual indices).
  //   - Zeros (when parameter is not in desired indices).
  // - All actual arguments.
  case AutoDiffDerivativeFunctionKind::VJP: {
    // Collect pullback arguments.
    unsigned pullbackResultIndex = 0;
    for (unsigned i : actualConfig.parameterIndices->getIndices()) {
      auto origParamInfo = origFnType->getParameters()[i];
      // Skip original semantic result parameters. All non-indirect-result
      // pullback arguments (including semantic result arguments) are appended
      // to `arguments` later.
      if (origParamInfo.isAutoDiffSemanticResult())
        continue;
      assert(pullbackResultIndex < linearMapType->getNumResults());
      auto resultInfo = linearMapType->getResults()[pullbackResultIndex];
      ++pullbackResultIndex;
      // Skip pullback direct results. Only indirect results are relevant as
      // arguments.
      if (resultInfo.isFormalDirect())
        continue;
      // If index is desired, use next pullback indirect result.
      if (desiredConfig.isWrtParameter(i)) {
        useNextIndirectResult();
        continue;
      }
      // Otherwise, allocate and use an uninitialized pullback indirect result.
      auto *indirectResult = builder.createAllocStack(
          loc, resultInfo.getSILStorageInterfaceType());
      localAllocations.push_back(indirectResult);
      arguments.push_back(indirectResult);
    }
    // Forward all actual non-indirect-result arguments.
    auto thunkArgs = thunk->getArgumentsWithoutIndirectResults();
    // Slice out the function to be called
    thunkArgs = thunkArgs.slice(0, thunkArgs.size() - 1);
    unsigned thunkArg = 0;
    // Pullback parameter corresponding to the current actual result. This is
    // tracked separately and not derived from `arguments.size()`, because
    // `arguments` already holds the pullback indirect results collected above.
    unsigned pullbackParamIdx = 0;
    for (unsigned idx : *actualConfig.resultIndices) {
      assert(pullbackParamIdx < linearMapType->getNumParameters());
      // Forward result argument in case we do not need to thunk it away
      if (desiredConfig.resultIndices->contains(idx))
        arguments.push_back(thunkArgs[thunkArg++]);
      else // otherwise, zero it out
        buildZeroArgument(linearMapType->getParameters()[pullbackParamIdx]);
      ++pullbackParamIdx;
    }
    break;
  }
  }
  // Get the linear map thunk argument and apply it.
  auto *linearMap = thunk->getArguments().back();
  auto *ai = builder.createApply(loc, linearMap, SubstitutionMap(), arguments);
  // If differential thunk, deallocate local allocations and directly return
  // `apply` result (if it is desired).
  // TODO: Unify with VJP code below
  if (kind == AutoDiffDerivativeFunctionKind::JVP) {
    SmallVector<SILValue, 8> differentialDirectResults;
    extractAllElements(ai, builder, differentialDirectResults);
    SmallVector<SILValue, 8> allResults;
    collectAllActualResultsInTypeOrder(ai, differentialDirectResults,
                                       allResults);
    SmallVector<SILValue, 8> results;
    // Result indices at or past the original formal result count refer to
    // semantic result (e.g. `inout`) arguments.
    unsigned firstSemanticParamResultIdx = origFnType->getNumResults();
    for (unsigned resultIndex : *actualConfig.resultIndices) {
      SILValue result;
      if (resultIndex >= firstSemanticParamResultIdx) {
        auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx;
        result =
            *std::next(ai->getAutoDiffSemanticResultArguments().begin(),
                       semanticResultArgIdx);
      } else
        result = allResults[resultIndex];
      // If result is desired:
      // - Do nothing if result is indirect.
      //   (It was already forwarded to the `apply` instruction).
      // - Push it to `results` if result is direct.
      if (desiredConfig.isWrtResult(resultIndex)) {
        if (result->getType().isObject())
          results.push_back(result);
      } else { // Otherwise, cleanup the unused results.
        if (result->getType().isAddress())
          builder.emitDestroyAddrAndFold(loc, result);
        else
          builder.emitDestroyValueOperation(loc, result);
      }
    }
    cleanupValues();
    auto result = joinElements(results, builder, loc);
    builder.createReturn(loc, result);
    return {thunk, interfaceSubs};
  }
  // If pullback thunk, return only the desired results and clean up the
  // undesired results.
  SmallVector<SILValue, 8> pullbackDirectResults;
  extractAllElements(ai, builder, pullbackDirectResults);
  SmallVector<SILValue, 8> allResults;
  collectAllActualResultsInTypeOrder(ai, pullbackDirectResults, allResults);
  // Collect pullback semantic result arguments in type order.
  unsigned semanticResultArgIdx = 0;
  SILFunctionConventions origConv(origFnType, thunk->getModule());
  for (auto paramIdx : actualConfig.parameterIndices->getIndices()) {
    auto paramInfo = origConv.getParameters()[paramIdx];
    if (!paramInfo.isAutoDiffSemanticResult())
      continue;
    auto semanticResultArg =
        *std::next(ai->getAutoDiffSemanticResultArguments().begin(),
                   semanticResultArgIdx++);
    unsigned mappedParamIdx = mapOriginalParameterIndex(paramIdx);
    allResults.insert(allResults.begin() + mappedParamIdx, semanticResultArg);
  }
  assert(allResults.size() == actualConfig.parameterIndices->getNumIndices() &&
         "Number of pullback results should match number of differentiability "
         "parameters");
  SmallVector<SILValue, 8> results;
  for (unsigned i : actualConfig.parameterIndices->getIndices()) {
    unsigned mappedIndex = mapOriginalParameterIndex(i);
    // If result is desired:
    // - Do nothing if result is indirect.
    //   (It was already forwarded to the `apply` instruction).
    // - Push it to `results` if result is direct.
    auto result = allResults[mappedIndex];
    if (desiredConfig.isWrtParameter(i)) {
      if (result->getType().isObject())
        results.push_back(result);
    }
    // Otherwise, cleanup the unused results.
    else {
      if (result->getType().isAddress())
        builder.emitDestroyAddrAndFold(loc, result);
      else
        builder.emitDestroyValueOperation(loc, result);
    }
  }
  // Deallocate local allocations and return final direct result.
  cleanupValues();
  auto result = joinElements(results, builder, loc);
  builder.createReturn(loc, result);
  return {thunk, interfaceSubs};
}
/// Returns a thunk that wraps the derivative function `derivativeFn` of
/// `origFnOperand` so that it has the derivative type for `desiredConfig`.
/// `derivativeFn` is a JVP or VJP computed for `actualConfig`.
///
/// The thunk body does the following:
///   - Re-creates a reference to the derivative function (via function_ref,
///     witness_method, class_method, or differentiability_witness_function).
///   - Applies that reference to all of the thunk's arguments.
///   - Replaces the returned linear map with a partial application of the
///     linear-map subset thunk from
///     `getOrCreateSubsetParametersThunkForLinearMap`.
/// If a function with the mangled thunk name already has a body, it is
/// returned without emitting anything.
///
/// \param fb            Function builder used to get or create the thunk.
/// \param origFnOperand Original function value; used for the thunk's
///                      mangled name and location.
/// \param derivativeFn  Derivative function value to wrap.
/// \param kind          JVP or VJP.
/// \param desiredConfig Configuration requested by the caller.
/// \param actualConfig  Configuration `derivativeFn` was produced for.
/// \param adContext     Differentiation context (mangler, linear map thunk).
/// \returns The thunk and the interface substitutions for applying it.
std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForDerivativeFunction(
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
AutoDiffDerivativeFunctionKind kind, const AutoDiffConfig &desiredConfig,
const AutoDiffConfig &actualConfig, ADContext &adContext) {
LLVM_DEBUG(getADDebugStream()
<< "Getting a subset parameters thunk for derivative "
<< (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp")
<< " function " << derivativeFn
<< " of the original function " << origFnOperand
<< " from " << actualConfig << " to " << desiredConfig << '\n');
auto origFnType = origFnOperand->getType().castTo<SILFunctionType>();
auto &module = fb.getModule();
auto lookupConformance = LookUpConformanceInModule();
// Compute target type for thunking.
// The target is the derivative type of the original function for the
// desired configuration, mapped into the caller's context if needed.
auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>();
auto targetType = origFnType->getAutoDiffDerivativeFunctionType(
desiredConfig.parameterIndices, desiredConfig.resultIndices, kind,
module.Types, lookupConformance);
auto *caller = derivativeFn->getFunction();
if (targetType->hasArchetype()) {
auto substTargetType =
caller->mapTypeIntoContext(targetType->mapTypeOutOfContext())
->getCanonicalType();
targetType = SILType::getPrimitiveObjectType(substTargetType)
.castTo<SILFunctionType>();
}
assert(derivativeFnType->getNumParameters() ==
targetType->getNumParameters());
assert(derivativeFnType->getNumResults() == targetType->getNumResults());
// Build thunk type.
SubstitutionMap interfaceSubs;
GenericEnvironment *genericEnv = nullptr;
auto thunkType = buildThunkType(derivativeFn->getFunction(), derivativeFnType,
targetType, genericEnv, interfaceSubs,
/*withoutActuallyEscaping*/ false,
DifferentiationThunkKind::IndexSubset);
// FIXME: The logic for resolving `assocRef` does not reapply function
// conversions, which is problematic if `derivativeFn` is a `partial_apply`
// instruction.
// Resolve the original function's name (for mangling) by looking through
// function conversions to a `function_ref` or method instruction.
StringRef origName;
if (auto *origFnRef =
peerThroughFunctionConversions<FunctionRefInst>(origFnOperand)) {
origName = origFnRef->getReferencedFunction()->getName();
} else if (auto *origMethodInst =
peerThroughFunctionConversions<MethodInst>(origFnOperand)) {
origName = origMethodInst->getMember()
.getAnyFunctionRef()
->getAbstractFunctionDecl()
->getNameStr();
}
assert(!origName.empty() && "Original function name could not be resolved");
Mangle::DifferentiationMangler mangler(adContext.getASTContext());
auto thunkName = mangler.mangleDerivativeFunctionSubsetParametersThunk(
origName, targetType->mapTypeOutOfContext()->getCanonicalType(),
kind, actualConfig.parameterIndices, actualConfig.resultIndices,
desiredConfig.parameterIndices);
auto loc = origFnOperand.getLoc();
auto *thunk = fb.getOrCreateSharedFunction(
loc, thunkName, thunkType, IsBare, IsTransparent,
caller->getSerializedKind(), ProfileCounter(), IsThunk, IsNotDynamic,
IsNotDistributed, IsNotRuntimeAccessible);
// A function with this name already has a body: reuse it as is.
if (!thunk->empty())
return {thunk, interfaceSubs};
thunk->setGenericEnvironment(genericEnv);
auto *entry = thunk->createBasicBlock();
SILBuilder builder(entry);
createEntryArguments(thunk);
// If the derivative is a `partial_apply`, its substitutions are reused (after
// being substituted into the thunk's context below).
SubstitutionMap assocSubstMap;
if (auto *partialApply = dyn_cast<PartialApplyInst>(derivativeFn))
assocSubstMap = partialApply->getSubstitutionMap();
// FIXME: The logic for resolving `assocRef` does not reapply function
// conversions, which is problematic if `derivativeFn` is a `partial_apply`
// instruction.
// Re-create a reference to the derivative function inside the thunk body.
SILValue assocRef;
if (auto *derivativeFnRef =
peerThroughFunctionConversions<FunctionRefInst>(derivativeFn)) {
auto *assoc = derivativeFnRef->getReferencedFunction();
assocRef = builder.createFunctionRef(loc, assoc);
} else if (auto *assocMethodInst =
peerThroughFunctionConversions<WitnessMethodInst>(
derivativeFn)) {
assocRef = builder.createWitnessMethod(
loc, assocMethodInst->getLookupType(),
assocMethodInst->getConformance(), assocMethodInst->getMember(),
thunk->mapTypeIntoContext(assocMethodInst->getType()));
} else if (auto *assocMethodInst =
peerThroughFunctionConversions<ClassMethodInst>(
derivativeFn)) {
// The class instance is taken from the thunk's last non-indirect-result
// argument.
auto classOperand = thunk->getArgumentsWithoutIndirectResults().back();
#ifndef NDEBUG
auto classOperandType = assocMethodInst->getOperand()->getType();
assert(classOperand->getType() == classOperandType);
#endif
assocRef = builder.createClassMethod(
loc, classOperand, assocMethodInst->getMember(),
thunk->mapTypeIntoContext(assocMethodInst->getType()));
} else if (auto *diffWitFn = peerThroughFunctionConversions<
DifferentiabilityWitnessFunctionInst>(derivativeFn)) {
assocRef = builder.createDifferentiabilityWitnessFunction(
loc, diffWitFn->getWitnessKind(), diffWitFn->getWitness());
}
assert(assocRef && "Expected derivative function to be resolved");
assocSubstMap = assocSubstMap.subst(thunk->getForwardingSubstitutionMap());
derivativeFnType = assocRef->getType().castTo<SILFunctionType>();
// All thunk arguments (indirect results included) are forwarded unchanged.
SmallVector<SILValue, 4> arguments;
arguments.append(thunk->getArguments().begin(), thunk->getArguments().end());
assert(arguments.size() ==
derivativeFnType->getNumParameters() +
derivativeFnType->getNumIndirectFormalResults());
auto *apply = builder.createApply(loc, assocRef, assocSubstMap, arguments);
// Extract all direct results.
// The last direct result of a derivative function is its linear map.
SmallVector<SILValue, 8> directResults;
extractAllElements(apply, builder, directResults);
auto linearMap = directResults.back();
directResults.pop_back();
auto linearMapType = linearMap->getType().castTo<SILFunctionType>();
auto linearMapTargetType = targetType->getResults()
.back()
.getSILStorageInterfaceType()
.castTo<SILFunctionType>();
auto unsubstLinearMapType = linearMapType->getUnsubstitutedType(module);
auto unsubstLinearMapTargetType =
linearMapTargetType->getUnsubstitutedType(module);
SILFunction *linearMapThunk;
SubstitutionMap linearMapSubs;
std::tie(linearMapThunk, linearMapSubs) =
getOrCreateSubsetParametersThunkForLinearMap(
fb, thunk, origFnType, unsubstLinearMapType,
unsubstLinearMapTargetType, kind, desiredConfig, actualConfig,
adContext);
// Wrap the linear map: convert to its unsubstituted type if needed,
// partially apply the linear map thunk, then convert to the target type.
auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk);
SILValue thunkedLinearMap = linearMap;
if (linearMapType != unsubstLinearMapType) {
thunkedLinearMap = builder.createConvertFunction(
loc, thunkedLinearMap,
SILType::getPrimitiveObjectType(unsubstLinearMapType),
/*withoutActuallyEscaping*/ false);
}
thunkedLinearMap = builder.createPartialApply(
loc, linearMapThunkFRI, linearMapSubs, {thunkedLinearMap},
ParameterConvention::Direct_Guaranteed);
if (linearMapTargetType != unsubstLinearMapTargetType) {
thunkedLinearMap = builder.createConvertFunction(
loc, thunkedLinearMap,
SILType::getPrimitiveObjectType(linearMapTargetType),
/*withoutActuallyEscaping*/ false);
}
assert(origFnType->getNumAutoDiffSemanticResults() > 0);
// Return the original direct results joined with the wrapped linear map, or
// just the linear map when the original's first result is not direct.
if (origFnType->getNumResults() > 0 &&
origFnType->getResults().front().isFormalDirect()) {
directResults.push_back(thunkedLinearMap);
auto result = joinElements(directResults, builder, loc);
builder.createReturn(loc, result);
} else {
builder.createReturn(loc, thunkedLinearMap);
}
return {thunk, interfaceSubs};
}
} // end namespace autodiff
} // end namespace swift