mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
983 lines
40 KiB
C++
983 lines
40 KiB
C++
//===--- Thunk.cpp - Automatic differentiation thunks ---------*- C++ -*---===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Automatic differentiation thunk generation utilities.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define DEBUG_TYPE "differentiation"
|
|
|
|
#include "swift/SILOptimizer/Differentiation/Thunk.h"
|
|
#include "swift/SILOptimizer/Differentiation/Common.h"
|
|
|
|
#include "swift/AST/AnyFunctionRef.h"
|
|
#include "swift/AST/GenericSignatureBuilder.h"
|
|
#include "swift/AST/Requirement.h"
|
|
#include "swift/AST/SubstitutionMap.h"
|
|
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
|
|
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
|
|
|
|
namespace swift {
|
|
namespace autodiff {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Thunk helpers
|
|
//===----------------------------------------------------------------------===//
|
|
// These helpers are copied/adapted from SILGen. They should be refactored and
|
|
// moved to a shared location.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig,
|
|
OpenedArchetypeType *openedExistential,
|
|
GenericEnvironment *&genericEnv,
|
|
SubstitutionMap &contextSubs,
|
|
SubstitutionMap &interfaceSubs,
|
|
ArchetypeType *&newArchetype) {
|
|
// If there's no opened existential, we just inherit the generic environment
|
|
// from the parent function.
|
|
if (openedExistential == nullptr) {
|
|
auto genericSig = fn->getLoweredFunctionType()->getSubstGenericSignature();
|
|
genericEnv = fn->getGenericEnvironment();
|
|
interfaceSubs = fn->getForwardingSubstitutionMap();
|
|
contextSubs = interfaceSubs;
|
|
return genericSig;
|
|
}
|
|
|
|
auto &ctx = fn->getASTContext();
|
|
GenericSignatureBuilder builder(ctx);
|
|
|
|
// Add the existing generic signature.
|
|
int depth = 0;
|
|
if (inheritGenericSig) {
|
|
if (auto genericSig =
|
|
fn->getLoweredFunctionType()->getSubstGenericSignature()) {
|
|
builder.addGenericSignature(genericSig);
|
|
depth = genericSig->getGenericParams().back()->getDepth() + 1;
|
|
}
|
|
}
|
|
|
|
// Add a new generic parameter to replace the opened existential.
|
|
auto *newGenericParam = GenericTypeParamType::get(depth, 0, ctx);
|
|
|
|
builder.addGenericParameter(newGenericParam);
|
|
Requirement newRequirement(RequirementKind::Conformance, newGenericParam,
|
|
openedExistential->getOpenedExistentialType());
|
|
auto source =
|
|
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
|
|
builder.addRequirement(newRequirement, source, nullptr);
|
|
|
|
auto genericSig = std::move(builder).computeGenericSignature(
|
|
/*allowConcreteGenericParams=*/true);
|
|
genericEnv = genericSig->getGenericEnvironment();
|
|
|
|
newArchetype =
|
|
genericEnv->mapTypeIntoContext(newGenericParam)->castTo<ArchetypeType>();
|
|
|
|
// Calculate substitutions to map the caller's archetypes to the thunk's
|
|
// archetypes.
|
|
if (auto calleeGenericSig =
|
|
fn->getLoweredFunctionType()->getSubstGenericSignature()) {
|
|
contextSubs = SubstitutionMap::get(
|
|
calleeGenericSig,
|
|
[&](SubstitutableType *type) -> Type {
|
|
return genericEnv->mapTypeIntoContext(type);
|
|
},
|
|
MakeAbstractConformanceForGenericType());
|
|
}
|
|
|
|
// Calculate substitutions to map interface types to the caller's archetypes.
|
|
interfaceSubs = SubstitutionMap::get(
|
|
genericSig,
|
|
[&](SubstitutableType *type) -> Type {
|
|
if (type->isEqual(newGenericParam))
|
|
return openedExistential;
|
|
return fn->mapTypeIntoContext(type);
|
|
},
|
|
MakeAbstractConformanceForGenericType());
|
|
|
|
return genericSig->getCanonicalSignature();
|
|
}
|
|
|
|
CanSILFunctionType buildThunkType(SILFunction *fn,
|
|
CanSILFunctionType &sourceType,
|
|
CanSILFunctionType &expectedType,
|
|
GenericEnvironment *&genericEnv,
|
|
SubstitutionMap &interfaceSubs,
|
|
bool withoutActuallyEscaping,
|
|
DifferentiationThunkKind thunkKind) {
|
|
assert(!expectedType->isPolymorphic() &&
|
|
!expectedType->getCombinedSubstitutions());
|
|
assert(!sourceType->isPolymorphic() &&
|
|
!sourceType->getCombinedSubstitutions());
|
|
|
|
// Cannot build a reabstraction thunk without context. Ownership semantics
|
|
// on the result type are required.
|
|
if (thunkKind == DifferentiationThunkKind::Reabstraction)
|
|
assert(expectedType->getExtInfo().hasContext());
|
|
|
|
// This may inherit @noescape from the expected type. The `@noescape`
|
|
// attribute is only stripped when using this type to materialize a new decl.
|
|
// Use `@convention(thin)` if:
|
|
// - Building a reabstraction thunk type.
|
|
// - Building an index subset thunk type, where the expected type has context
|
|
// (i.e. is `@convention(thick)`).
|
|
auto extInfoBuilder = expectedType->getExtInfo().intoBuilder();
|
|
if (thunkKind == DifferentiationThunkKind::Reabstraction ||
|
|
extInfoBuilder.hasContext()) {
|
|
extInfoBuilder = extInfoBuilder.withRepresentation(
|
|
SILFunctionType::Representation::Thin);
|
|
}
|
|
if (withoutActuallyEscaping)
|
|
extInfoBuilder = extInfoBuilder.withNoEscape(false);
|
|
|
|
// Does the thunk type involve archetypes other than opened existentials?
|
|
bool hasArchetypes = false;
|
|
// Does the thunk type involve an open existential type?
|
|
CanOpenedArchetypeType openedExistential;
|
|
auto archetypeVisitor = [&](CanType t) {
|
|
if (auto archetypeTy = dyn_cast<OpenedArchetypeType>(t)) {
|
|
if (auto opened = dyn_cast<OpenedArchetypeType>(archetypeTy)) {
|
|
assert((openedExistential == CanArchetypeType() ||
|
|
openedExistential == opened) &&
|
|
"one too many open existentials");
|
|
openedExistential = opened;
|
|
} else {
|
|
hasArchetypes = true;
|
|
}
|
|
}
|
|
};
|
|
|
|
// Use the generic signature from the context if the thunk involves
|
|
// generic parameters.
|
|
CanGenericSignature genericSig;
|
|
SubstitutionMap contextSubs;
|
|
ArchetypeType *newArchetype = nullptr;
|
|
|
|
if (expectedType->hasArchetype() || sourceType->hasArchetype()) {
|
|
expectedType.visit(archetypeVisitor);
|
|
sourceType.visit(archetypeVisitor);
|
|
genericSig =
|
|
buildThunkSignature(fn, hasArchetypes, openedExistential, genericEnv,
|
|
contextSubs, interfaceSubs, newArchetype);
|
|
}
|
|
|
|
auto substTypeHelper = [&](SubstitutableType *type) -> Type {
|
|
if (CanType(type) == openedExistential)
|
|
return newArchetype;
|
|
return Type(type).subst(contextSubs);
|
|
};
|
|
auto substConformanceHelper = LookUpConformanceInSubstitutionMap(contextSubs);
|
|
|
|
// Utility function to apply contextSubs, and also replace the
|
|
// opened existential with the new archetype.
|
|
auto substLoweredTypeIntoThunkContext =
|
|
[&](CanSILFunctionType t) -> CanSILFunctionType {
|
|
return SILType::getPrimitiveObjectType(t)
|
|
.subst(fn->getModule(), substTypeHelper, substConformanceHelper)
|
|
.castTo<SILFunctionType>();
|
|
};
|
|
|
|
sourceType = substLoweredTypeIntoThunkContext(sourceType);
|
|
expectedType = substLoweredTypeIntoThunkContext(expectedType);
|
|
|
|
// If our parent function was pseudogeneric, this thunk must also be
|
|
// pseudogeneric, since we have no way to pass generic parameters.
|
|
if (genericSig)
|
|
if (fn->getLoweredFunctionType()->isPseudogeneric())
|
|
extInfoBuilder = extInfoBuilder.withIsPseudogeneric();
|
|
|
|
// Add the function type as the parameter.
|
|
auto contextConvention =
|
|
SILType::getPrimitiveObjectType(sourceType).isTrivial(*fn)
|
|
? ParameterConvention::Direct_Unowned
|
|
: ParameterConvention::Direct_Guaranteed;
|
|
SmallVector<SILParameterInfo, 4> params;
|
|
params.append(expectedType->getParameters().begin(),
|
|
expectedType->getParameters().end());
|
|
// Add reabstraction function parameter only if building a reabstraction thunk
|
|
// type.
|
|
if (thunkKind == DifferentiationThunkKind::Reabstraction)
|
|
params.push_back({sourceType, sourceType->getExtInfo().hasContext()
|
|
? contextConvention
|
|
: ParameterConvention::Direct_Unowned});
|
|
|
|
auto mapTypeOutOfContext = [&](CanType type) -> CanType {
|
|
return type->mapTypeOutOfContext()->getCanonicalType(genericSig);
|
|
};
|
|
|
|
// Map the parameter and expected types out of context to get the interface
|
|
// type of the thunk.
|
|
SmallVector<SILParameterInfo, 4> interfaceParams;
|
|
interfaceParams.reserve(params.size());
|
|
for (auto ¶m : params) {
|
|
auto interfaceParam = param.map(mapTypeOutOfContext);
|
|
interfaceParams.push_back(interfaceParam);
|
|
}
|
|
|
|
SmallVector<SILYieldInfo, 4> interfaceYields;
|
|
for (auto &yield : expectedType->getYields()) {
|
|
auto interfaceYield = yield.map(mapTypeOutOfContext);
|
|
interfaceYields.push_back(interfaceYield);
|
|
}
|
|
|
|
SmallVector<SILResultInfo, 4> interfaceResults;
|
|
for (auto &result : expectedType->getResults()) {
|
|
auto interfaceResult = result.map(mapTypeOutOfContext);
|
|
interfaceResults.push_back(interfaceResult);
|
|
}
|
|
|
|
Optional<SILResultInfo> interfaceErrorResult;
|
|
if (expectedType->hasErrorResult()) {
|
|
auto errorResult = expectedType->getErrorResult();
|
|
interfaceErrorResult = errorResult.map(mapTypeOutOfContext);
|
|
}
|
|
|
|
// The type of the thunk function.
|
|
return SILFunctionType::get(
|
|
genericSig, extInfoBuilder.build(), expectedType->getCoroutineKind(),
|
|
ParameterConvention::Direct_Unowned, interfaceParams, interfaceYields,
|
|
interfaceResults, interfaceErrorResult,
|
|
expectedType->getPatternSubstitutions(), SubstitutionMap(),
|
|
fn->getASTContext());
|
|
}
|
|
|
|
/// Forward function arguments, handling ownership convention mismatches.
|
|
/// Adapted from `forwardFunctionArguments` in SILGenPoly.cpp.
|
|
///
|
|
/// Forwarded arguments are appended to `forwardedArgs`.
|
|
///
|
|
/// Local allocations are appended to `localAllocations`. They need to be
|
|
/// deallocated via `dealloc_stack`.
|
|
///
|
|
/// Local values requiring cleanup are appended to `valuesToCleanup`.
|
|
static void forwardFunctionArgumentsConvertingOwnership(
|
|
SILBuilder &builder, SILLocation loc, CanSILFunctionType fromTy,
|
|
CanSILFunctionType toTy, ArrayRef<SILArgument *> originalArgs,
|
|
SmallVectorImpl<SILValue> &forwardedArgs,
|
|
SmallVectorImpl<AllocStackInst *> &localAllocations,
|
|
SmallVectorImpl<SILValue> &valuesToCleanup) {
|
|
auto fromParameters = fromTy->getParameters();
|
|
auto toParameters = toTy->getParameters();
|
|
assert(fromParameters.size() == toParameters.size());
|
|
assert(fromParameters.size() == originalArgs.size());
|
|
for (auto index : indices(originalArgs)) {
|
|
auto &arg = originalArgs[index];
|
|
auto fromParam = fromParameters[index];
|
|
auto toParam = toParameters[index];
|
|
// To convert guaranteed argument to be owned, create a copy.
|
|
if (fromParam.isConsumed() && !toParam.isConsumed()) {
|
|
// If the argument has an object type, create a `copy_value`.
|
|
if (arg->getType().isObject()) {
|
|
auto argCopy = builder.emitCopyValueOperation(loc, arg);
|
|
forwardedArgs.push_back(argCopy);
|
|
continue;
|
|
}
|
|
// If the argument has an address type, create a local allocation and
|
|
// `copy_addr` its contents to the local allocation.
|
|
auto *alloc = builder.createAllocStack(loc, arg->getType());
|
|
builder.createCopyAddr(loc, arg, alloc, IsNotTake, IsInitialization);
|
|
localAllocations.push_back(alloc);
|
|
forwardedArgs.push_back(alloc);
|
|
continue;
|
|
}
|
|
// To convert owned argument to be guaranteed, borrow the argument.
|
|
if (fromParam.isGuaranteed() && !toParam.isGuaranteed()) {
|
|
auto bbi = builder.emitBeginBorrowOperation(loc, arg);
|
|
forwardedArgs.push_back(bbi);
|
|
valuesToCleanup.push_back(bbi);
|
|
valuesToCleanup.push_back(arg);
|
|
continue;
|
|
}
|
|
// Otherwise, simply forward the argument.
|
|
forwardedArgs.push_back(arg);
|
|
}
|
|
}
|
|
|
|
SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
|
|
SILModule &module, SILLocation loc,
|
|
SILFunction *caller,
|
|
CanSILFunctionType fromType,
|
|
CanSILFunctionType toType) {
|
|
assert(!fromType->getCombinedSubstitutions());
|
|
assert(!toType->getCombinedSubstitutions());
|
|
|
|
SubstitutionMap interfaceSubs;
|
|
GenericEnvironment *genericEnv = nullptr;
|
|
auto thunkType =
|
|
buildThunkType(caller, fromType, toType, genericEnv, interfaceSubs,
|
|
/*withoutActuallyEscaping*/ false,
|
|
DifferentiationThunkKind::Reabstraction);
|
|
auto thunkDeclType =
|
|
thunkType->getWithExtInfo(thunkType->getExtInfo().withNoEscape(false));
|
|
|
|
auto fromInterfaceType = fromType->mapTypeOutOfContext()->getCanonicalType();
|
|
auto toInterfaceType = toType->mapTypeOutOfContext()->getCanonicalType();
|
|
|
|
Mangle::ASTMangler mangler;
|
|
std::string name = mangler.mangleReabstractionThunkHelper(
|
|
thunkType, fromInterfaceType, toInterfaceType, Type(),
|
|
module.getSwiftModule());
|
|
|
|
auto *thunk = fb.getOrCreateSharedFunction(
|
|
loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
|
|
ProfileCounter(), IsReabstractionThunk, IsNotDynamic);
|
|
if (!thunk->empty())
|
|
return thunk;
|
|
|
|
thunk->setGenericEnvironment(genericEnv);
|
|
auto *entry = thunk->createBasicBlock();
|
|
SILBuilder builder(entry);
|
|
createEntryArguments(thunk);
|
|
|
|
SILFunctionConventions fromConv(fromType, module);
|
|
SILFunctionConventions toConv(toType, module);
|
|
assert(toConv.useLoweredAddresses());
|
|
|
|
// Forward thunk arguments, handling ownership convention mismatches.
|
|
SmallVector<SILValue, 4> forwardedArgs;
|
|
for (auto indRes : thunk->getIndirectResults())
|
|
forwardedArgs.push_back(indRes);
|
|
SmallVector<AllocStackInst *, 4> localAllocations;
|
|
SmallVector<SILValue, 4> valuesToCleanup;
|
|
forwardFunctionArgumentsConvertingOwnership(
|
|
builder, loc, fromType, toType,
|
|
thunk->getArgumentsWithoutIndirectResults().drop_back(), forwardedArgs,
|
|
localAllocations, valuesToCleanup);
|
|
|
|
SmallVector<SILValue, 4> arguments;
|
|
auto toArgIter = forwardedArgs.begin();
|
|
auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
|
|
|
|
auto createAllocStack = [&](SILType type) {
|
|
auto *alloc = builder.createAllocStack(loc, type);
|
|
localAllocations.push_back(alloc);
|
|
return alloc;
|
|
};
|
|
|
|
// Handle indirect results.
|
|
assert(fromType->getNumResults() == toType->getNumResults());
|
|
for (unsigned resIdx : range(toType->getNumResults())) {
|
|
auto fromRes = fromConv.getResults()[resIdx];
|
|
auto toRes = toConv.getResults()[resIdx];
|
|
// No abstraction mismatch.
|
|
if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
|
|
// If result types are indirect, directly pass as next argument.
|
|
if (toRes.isFormalIndirect())
|
|
useNextArgument();
|
|
continue;
|
|
}
|
|
// Convert indirect result to direct result.
|
|
if (fromRes.isFormalIndirect()) {
|
|
SILType resultTy =
|
|
fromConv.getSILType(fromRes, builder.getTypeExpansionContext());
|
|
assert(resultTy.isAddress());
|
|
auto *indRes = createAllocStack(resultTy);
|
|
arguments.push_back(indRes);
|
|
continue;
|
|
}
|
|
// Convert direct result to indirect result.
|
|
// Increment thunk argument iterator; reabstraction handled later.
|
|
++toArgIter;
|
|
}
|
|
|
|
// Reabstract parameters.
|
|
assert(toType->getNumParameters() == fromType->getNumParameters());
|
|
for (unsigned paramIdx : range(toType->getNumParameters())) {
|
|
auto fromParam = fromConv.getParameters()[paramIdx];
|
|
auto toParam = toConv.getParameters()[paramIdx];
|
|
// No abstraction mismatch. Directly use next argument.
|
|
if (fromParam.isFormalIndirect() == toParam.isFormalIndirect()) {
|
|
useNextArgument();
|
|
continue;
|
|
}
|
|
// Convert indirect parameter to direct parameter.
|
|
if (fromParam.isFormalIndirect()) {
|
|
auto paramTy = fromConv.getSILType(fromType->getParameters()[paramIdx],
|
|
builder.getTypeExpansionContext());
|
|
if (!paramTy.hasArchetype())
|
|
paramTy = thunk->mapTypeIntoContext(paramTy);
|
|
assert(paramTy.isAddress());
|
|
auto toArg = *toArgIter++;
|
|
auto *buf = createAllocStack(toArg->getType());
|
|
toArg = builder.emitCopyValueOperation(loc, toArg);
|
|
builder.emitStoreValueOperation(loc, toArg, buf,
|
|
StoreOwnershipQualifier::Init);
|
|
valuesToCleanup.push_back(buf);
|
|
arguments.push_back(buf);
|
|
continue;
|
|
}
|
|
// Convert direct parameter to indirect parameter.
|
|
assert(toParam.isFormalIndirect());
|
|
auto toArg = *toArgIter++;
|
|
auto load = builder.emitLoadBorrowOperation(loc, toArg);
|
|
if (isa<LoadBorrowInst>(load))
|
|
valuesToCleanup.push_back(load);
|
|
arguments.push_back(load);
|
|
}
|
|
|
|
auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back();
|
|
auto *apply = builder.createApply(loc, fnArg, SubstitutionMap(), arguments);
|
|
|
|
// Get return elements.
|
|
SmallVector<SILValue, 4> results;
|
|
// Extract all direct results.
|
|
SmallVector<SILValue, 4> directResults;
|
|
extractAllElements(apply, builder, directResults);
|
|
|
|
auto fromDirResultsIter = directResults.begin();
|
|
auto fromIndResultsIter = apply->getIndirectSILResults().begin();
|
|
auto toIndResultsIter = thunk->getIndirectResults().begin();
|
|
// Reabstract results.
|
|
for (unsigned resIdx : range(toType->getNumResults())) {
|
|
auto fromRes = fromConv.getResults()[resIdx];
|
|
auto toRes = toConv.getResults()[resIdx];
|
|
// Check function-typed results.
|
|
if (isa<SILFunctionType>(fromRes.getInterfaceType()) &&
|
|
isa<SILFunctionType>(toRes.getInterfaceType())) {
|
|
auto fromFnType = cast<SILFunctionType>(fromRes.getInterfaceType());
|
|
auto toFnType = cast<SILFunctionType>(toRes.getInterfaceType());
|
|
auto fromUnsubstFnType = fromFnType->getUnsubstitutedType(module);
|
|
auto toUnsubstFnType = toFnType->getUnsubstitutedType(module);
|
|
// If unsubstituted function types are not equal, perform reabstraction.
|
|
if (fromUnsubstFnType != toUnsubstFnType) {
|
|
auto fromFn = *fromDirResultsIter++;
|
|
auto newFromFn = reabstractFunction(
|
|
builder, fb, loc, fromFn, toFnType,
|
|
[](SubstitutionMap substMap) { return substMap; });
|
|
results.push_back(newFromFn);
|
|
continue;
|
|
}
|
|
}
|
|
// No abstraction mismatch.
|
|
if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
|
|
// If result types are direct, add call result as direct thunk result.
|
|
if (toRes.isFormalDirect())
|
|
results.push_back(*fromDirResultsIter++);
|
|
// If result types are indirect, increment indirect result iterators.
|
|
else {
|
|
++fromIndResultsIter;
|
|
++toIndResultsIter;
|
|
}
|
|
continue;
|
|
}
|
|
// Load direct results from indirect results.
|
|
if (fromRes.isFormalIndirect()) {
|
|
auto indRes = *fromIndResultsIter++;
|
|
auto load = builder.emitLoadValueOperation(loc, indRes,
|
|
LoadOwnershipQualifier::Take);
|
|
results.push_back(load);
|
|
continue;
|
|
}
|
|
// Store direct results to indirect results.
|
|
assert(toRes.isFormalIndirect());
|
|
#ifndef NDEBUG
|
|
SILType resultTy =
|
|
toConv.getSILType(toRes, builder.getTypeExpansionContext());
|
|
assert(resultTy.isAddress());
|
|
#endif
|
|
auto indRes = *toIndResultsIter++;
|
|
auto dirRes = *fromDirResultsIter++;
|
|
builder.emitStoreValueOperation(loc, dirRes, indRes,
|
|
StoreOwnershipQualifier::Init);
|
|
}
|
|
auto retVal = joinElements(results, builder, loc);
|
|
|
|
// Clean up local values.
|
|
// Guaranteed values need an `end_borrow`.
|
|
// Owned values need to be destroyed.
|
|
for (auto arg : valuesToCleanup) {
|
|
switch (arg.getOwnershipKind()) {
|
|
case OwnershipKind::Any:
|
|
llvm_unreachable("value with any ownership kind?!");
|
|
case OwnershipKind::Guaranteed:
|
|
builder.emitEndBorrowOperation(loc, arg);
|
|
break;
|
|
case OwnershipKind::Owned:
|
|
case OwnershipKind::Unowned:
|
|
case OwnershipKind::None:
|
|
builder.emitDestroyOperation(loc, arg);
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Deallocate local allocations.
|
|
for (auto *alloc : llvm::reverse(localAllocations))
|
|
builder.createDeallocStack(loc, alloc);
|
|
|
|
// Create return.
|
|
builder.createReturn(loc, retVal);
|
|
|
|
LLVM_DEBUG(auto &s = getADDebugStream() << "Created reabstraction thunk.\n";
|
|
s << " From type: " << fromType << '\n';
|
|
s << " To type: " << toType << '\n'; s << '\n'
|
|
<< *thunk);
|
|
|
|
return thunk;
|
|
}
|
|
|
|
SILValue reabstractFunction(
|
|
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
|
|
SILValue fn, CanSILFunctionType toType,
|
|
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions) {
|
|
auto &module = *fn->getModule();
|
|
auto fromType = fn->getType().getAs<SILFunctionType>();
|
|
auto unsubstFromType = fromType->getUnsubstitutedType(module);
|
|
auto unsubstToType = toType->getUnsubstitutedType(module);
|
|
|
|
auto *thunk = getOrCreateReabstractionThunk(fb, module, loc,
|
|
/*caller*/ fn->getFunction(),
|
|
unsubstFromType, unsubstToType);
|
|
auto *thunkRef = builder.createFunctionRef(loc, thunk);
|
|
|
|
if (fromType != unsubstFromType)
|
|
fn = builder.createConvertFunction(
|
|
loc, fn, SILType::getPrimitiveObjectType(unsubstFromType),
|
|
/*withoutActuallyEscaping*/ false);
|
|
|
|
fn = builder.createPartialApply(
|
|
loc, thunkRef, remapSubstitutions(thunk->getForwardingSubstitutionMap()),
|
|
{fn}, fromType->getCalleeConvention());
|
|
|
|
if (toType != unsubstToType)
|
|
fn = builder.createConvertFunction(loc, fn,
|
|
SILType::getPrimitiveObjectType(toType),
|
|
/*withoutActuallyEscaping*/ false);
|
|
|
|
return fn;
|
|
}
|
|
|
|
std::pair<SILFunction *, SubstitutionMap>
|
|
getOrCreateSubsetParametersThunkForLinearMap(
|
|
SILOptFunctionBuilder &fb, SILFunction *parentThunk,
|
|
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
|
|
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
|
|
const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig,
|
|
ADContext &adContext) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Getting a subset parameters thunk for " << linearMapType
|
|
<< " from " << actualConfig << " to " << desiredConfig << '\n');
|
|
|
|
assert(!linearMapType->getCombinedSubstitutions());
|
|
assert(!targetType->getCombinedSubstitutions());
|
|
SubstitutionMap interfaceSubs;
|
|
GenericEnvironment *genericEnv = nullptr;
|
|
auto thunkType = buildThunkType(parentThunk, linearMapType, targetType,
|
|
genericEnv, interfaceSubs,
|
|
/*withoutActuallyEscaping*/ true,
|
|
DifferentiationThunkKind::Reabstraction);
|
|
|
|
Mangle::DifferentiationMangler mangler;
|
|
auto fromInterfaceType =
|
|
linearMapType->mapTypeOutOfContext()->getCanonicalType();
|
|
|
|
auto thunkName = mangler.mangleLinearMapSubsetParametersThunk(
|
|
fromInterfaceType, kind.getLinearMapKind(),
|
|
actualConfig.parameterIndices, actualConfig.resultIndices,
|
|
desiredConfig.parameterIndices);
|
|
|
|
auto loc = parentThunk->getLocation();
|
|
auto *thunk = fb.getOrCreateSharedFunction(
|
|
loc, thunkName, thunkType, IsBare, IsTransparent, IsSerialized,
|
|
ProfileCounter(), IsThunk, IsNotDynamic);
|
|
|
|
if (!thunk->empty())
|
|
return {thunk, interfaceSubs};
|
|
|
|
thunk->setGenericEnvironment(genericEnv);
|
|
auto *entry = thunk->createBasicBlock();
|
|
TangentBuilder builder(entry, adContext);
|
|
createEntryArguments(thunk);
|
|
|
|
// Get arguments.
|
|
SmallVector<SILValue, 4> arguments;
|
|
SmallVector<AllocStackInst *, 4> localAllocations;
|
|
SmallVector<SILValue, 4> valuesToCleanup;
|
|
auto cleanupValues = [&]() {
|
|
for (auto value : llvm::reverse(valuesToCleanup))
|
|
builder.emitDestroyOperation(loc, value);
|
|
|
|
for (auto *alloc : llvm::reverse(localAllocations))
|
|
builder.createDeallocStack(loc, alloc);
|
|
};
|
|
|
|
// Build a `.zero` argument for the given `Differentiable`-conforming type.
|
|
auto buildZeroArgument = [&](SILType zeroSILType) {
|
|
auto zeroSILObjType = zeroSILType.getObjectType();
|
|
auto zeroType = zeroSILType.getASTType();
|
|
auto *swiftMod = parentThunk->getModule().getSwiftModule();
|
|
auto tangentSpace =
|
|
zeroType->getAutoDiffTangentSpace(LookUpConformanceInModule(swiftMod));
|
|
assert(tangentSpace && "No tangent space for this type");
|
|
switch (tangentSpace->getKind()) {
|
|
case TangentSpace::Kind::TangentVector: {
|
|
auto *buf = builder.createAllocStack(loc, zeroSILObjType);
|
|
localAllocations.push_back(buf);
|
|
builder.emitZeroIntoBuffer(loc, buf, IsInitialization);
|
|
if (zeroSILType.isAddress()) {
|
|
valuesToCleanup.push_back(buf);
|
|
arguments.push_back(buf);
|
|
} else {
|
|
auto arg = builder.emitLoadValueOperation(loc, buf,
|
|
LoadOwnershipQualifier::Take);
|
|
valuesToCleanup.push_back(arg);
|
|
arguments.push_back(arg);
|
|
}
|
|
break;
|
|
}
|
|
case TangentSpace::Kind::Tuple: {
|
|
llvm_unreachable("Unimplemented: Handle zero initialization for tuples");
|
|
}
|
|
}
|
|
};
|
|
|
|
// The indices in `actualConfig` and `desiredConfig` are with respect to the
|
|
// original function. However, the differential parameters and pullback
|
|
// results may already be w.r.t. a subset. We create a map between the
|
|
// original function's actual parameter indices and the linear map's actual
|
|
// indices.
|
|
// Example:
|
|
// Original: (T0, T1, T2) -> R
|
|
// Actual indices: 0, 2
|
|
// Original differential: (T0, T2) -> R
|
|
// Original pullback: R -> (T0, T2)
|
|
// Desired indices w.r.t. original: 2
|
|
// Desired indices w.r.t. linear map: 1
|
|
SmallVector<unsigned, 4> actualParamIndicesMap(
|
|
actualConfig.parameterIndices->getCapacity(), UINT_MAX);
|
|
{
|
|
unsigned indexInBitVec = 0;
|
|
for (auto index : actualConfig.parameterIndices->getIndices()) {
|
|
actualParamIndicesMap[index] = indexInBitVec;
|
|
++indexInBitVec;
|
|
}
|
|
}
|
|
auto mapOriginalParameterIndex = [&](unsigned index) -> unsigned {
|
|
auto mappedIndex = actualParamIndicesMap[index];
|
|
assert(mappedIndex < actualConfig.parameterIndices->getCapacity());
|
|
return mappedIndex;
|
|
};
|
|
|
|
switch (kind) {
|
|
// Differential arguments are:
|
|
// - All indirect results, followed by:
|
|
// - An interleaving of:
|
|
// - Thunk arguments (when parameter index is in both desired and actual
|
|
// indices).
|
|
// - Zeros (when parameter is not in desired indices).
|
|
case AutoDiffDerivativeFunctionKind::JVP: {
|
|
// Forward all indirect results.
|
|
arguments.append(thunk->getIndirectResults().begin(),
|
|
thunk->getIndirectResults().end());
|
|
auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin();
|
|
auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
|
|
// Iterate over actual indices.
|
|
for (unsigned i : actualConfig.parameterIndices->getIndices()) {
|
|
// If index is desired, use next argument.
|
|
if (desiredConfig.isWrtParameter(i)) {
|
|
useNextArgument();
|
|
}
|
|
// Otherwise, construct and use a zero argument.
|
|
else {
|
|
auto zeroSILType =
|
|
linearMapType->getParameters()[mapOriginalParameterIndex(i)]
|
|
.getSILStorageInterfaceType();
|
|
buildZeroArgument(zeroSILType);
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
// Pullback arguments are:
|
|
// - An interleaving of:
|
|
// - Thunk indirect results (when parameter index is in both desired and
|
|
// actual indices).
|
|
// - Zeros (when parameter is not in desired indices).
|
|
// - All actual arguments.
|
|
case AutoDiffDerivativeFunctionKind::VJP: {
|
|
auto toIndirectResultsIter = thunk->getIndirectResults().begin();
|
|
auto useNextIndirectResult = [&]() {
|
|
arguments.push_back(*toIndirectResultsIter++);
|
|
};
|
|
// Collect pullback arguments.
|
|
unsigned pullbackResultIndex = 0;
|
|
for (unsigned i : actualConfig.parameterIndices->getIndices()) {
|
|
auto origParamInfo = origFnType->getParameters()[i];
|
|
// Skip original `inout` parameters. All non-indirect-result pullback
|
|
// arguments (including `inout` arguments) are appended to `arguments`
|
|
// later.
|
|
if (origParamInfo.isIndirectMutating())
|
|
continue;
|
|
auto resultInfo = linearMapType->getResults()[pullbackResultIndex];
|
|
assert(pullbackResultIndex < linearMapType->getNumResults());
|
|
++pullbackResultIndex;
|
|
// Skip pullback direct results. Only indirect results are relevant as
|
|
// arguments.
|
|
if (resultInfo.isFormalDirect())
|
|
continue;
|
|
// If index is desired, use next pullback indirect result.
|
|
if (desiredConfig.isWrtParameter(i)) {
|
|
useNextIndirectResult();
|
|
continue;
|
|
}
|
|
// Otherwise, allocate and use an uninitialized pullback indirect result.
|
|
auto *indirectResult = builder.createAllocStack(
|
|
loc, resultInfo.getSILStorageInterfaceType());
|
|
localAllocations.push_back(indirectResult);
|
|
arguments.push_back(indirectResult);
|
|
}
|
|
// Forward all actual non-indirect-result arguments.
|
|
arguments.append(thunk->getArgumentsWithoutIndirectResults().begin(),
|
|
thunk->getArgumentsWithoutIndirectResults().end() - 1);
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Get the linear map thunk argument and apply it.
|
|
auto *linearMap = thunk->getArguments().back();
|
|
auto *ai = builder.createApply(loc, linearMap, SubstitutionMap(), arguments);
|
|
|
|
// If differential thunk, deallocate local allocations and directly return
|
|
// `apply` result.
|
|
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
|
|
cleanupValues();
|
|
builder.createReturn(loc, ai);
|
|
return {thunk, interfaceSubs};
|
|
}
|
|
|
|
// If pullback thunk, return only the desired results and clean up the
|
|
// undesired results.
|
|
SmallVector<SILValue, 8> pullbackDirectResults;
|
|
extractAllElements(ai, builder, pullbackDirectResults);
|
|
SmallVector<SILValue, 8> allResults;
|
|
collectAllActualResultsInTypeOrder(ai, pullbackDirectResults, allResults);
|
|
// Collect pullback `inout` arguments in type order.
|
|
unsigned inoutArgIdx = 0;
|
|
SILFunctionConventions origConv(origFnType, thunk->getModule());
|
|
for (auto paramIdx : actualConfig.parameterIndices->getIndices()) {
|
|
auto paramInfo = origConv.getParameters()[paramIdx];
|
|
if (!paramInfo.isIndirectMutating())
|
|
continue;
|
|
auto inoutArg = *std::next(ai->getInoutArguments().begin(), inoutArgIdx++);
|
|
unsigned mappedParamIdx = mapOriginalParameterIndex(paramIdx);
|
|
allResults.insert(allResults.begin() + mappedParamIdx, inoutArg);
|
|
}
|
|
assert(allResults.size() == actualConfig.parameterIndices->getNumIndices() &&
|
|
"Number of pullback results should match number of differentiability "
|
|
"parameters");
|
|
|
|
SmallVector<SILValue, 8> results;
|
|
for (unsigned i : actualConfig.parameterIndices->getIndices()) {
|
|
unsigned mappedIndex = mapOriginalParameterIndex(i);
|
|
// If result is desired:
|
|
// - Do nothing if result is indirect.
|
|
// (It was already forwarded to the `apply` instruction).
|
|
// - Push it to `results` if result is direct.
|
|
auto result = allResults[mappedIndex];
|
|
if (desiredConfig.isWrtParameter(i)) {
|
|
if (result->getType().isObject())
|
|
results.push_back(result);
|
|
}
|
|
// Otherwise, cleanup the unused results.
|
|
else {
|
|
if (result->getType().isAddress())
|
|
builder.emitDestroyAddrAndFold(loc, result);
|
|
else
|
|
builder.emitDestroyValueOperation(loc, result);
|
|
}
|
|
}
|
|
// Deallocate local allocations and return final direct result.
|
|
cleanupValues();
|
|
auto result = joinElements(results, builder, loc);
|
|
builder.createReturn(loc, result);
|
|
|
|
return {thunk, interfaceSubs};
|
|
}
|
|
|
|
std::pair<SILFunction *, SubstitutionMap>
|
|
getOrCreateSubsetParametersThunkForDerivativeFunction(
|
|
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
|
|
AutoDiffDerivativeFunctionKind kind, const AutoDiffConfig &desiredConfig,
|
|
const AutoDiffConfig &actualConfig, ADContext &adContext) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Getting a subset parameters thunk for derivative function "
|
|
<< derivativeFn << " of the original function " << origFnOperand
|
|
<< " from " << actualConfig << " to " << desiredConfig << '\n');
|
|
|
|
auto origFnType = origFnOperand->getType().castTo<SILFunctionType>();
|
|
auto &module = fb.getModule();
|
|
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
|
|
|
|
// Compute target type for thunking.
|
|
auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>();
|
|
auto targetType = origFnType->getAutoDiffDerivativeFunctionType(
|
|
desiredConfig.parameterIndices, desiredConfig.resultIndices, kind,
|
|
module.Types, lookupConformance);
|
|
auto *caller = derivativeFn->getFunction();
|
|
if (targetType->hasArchetype()) {
|
|
auto substTargetType =
|
|
caller->mapTypeIntoContext(targetType->mapTypeOutOfContext())
|
|
->getCanonicalType();
|
|
targetType = SILType::getPrimitiveObjectType(substTargetType)
|
|
.castTo<SILFunctionType>();
|
|
}
|
|
assert(derivativeFnType->getNumParameters() ==
|
|
targetType->getNumParameters());
|
|
assert(derivativeFnType->getNumResults() == targetType->getNumResults());
|
|
|
|
// Build thunk type.
|
|
SubstitutionMap interfaceSubs;
|
|
GenericEnvironment *genericEnv = nullptr;
|
|
auto thunkType = buildThunkType(derivativeFn->getFunction(), derivativeFnType,
|
|
targetType, genericEnv, interfaceSubs,
|
|
/*withoutActuallyEscaping*/ false,
|
|
DifferentiationThunkKind::IndexSubset);
|
|
|
|
// FIXME: The logic for resolving `assocRef` does not reapply function
|
|
// conversions, which is problematic if `derivativeFn` is a `partial_apply`
|
|
// instruction.
|
|
StringRef origName;
|
|
if (auto *origFnRef =
|
|
peerThroughFunctionConversions<FunctionRefInst>(origFnOperand)) {
|
|
origName = origFnRef->getReferencedFunction()->getName();
|
|
} else if (auto *origMethodInst =
|
|
peerThroughFunctionConversions<MethodInst>(origFnOperand)) {
|
|
origName = origMethodInst->getMember()
|
|
.getAnyFunctionRef()
|
|
->getAbstractFunctionDecl()
|
|
->getNameStr();
|
|
}
|
|
assert(!origName.empty() && "Original function name could not be resolved");
|
|
Mangle::DifferentiationMangler mangler;
|
|
auto thunkName = mangler.mangleDerivativeFunctionSubsetParametersThunk(
|
|
origName, targetType->mapTypeOutOfContext()->getCanonicalType(),
|
|
kind, actualConfig.parameterIndices, actualConfig.resultIndices,
|
|
desiredConfig.parameterIndices);
|
|
|
|
auto loc = origFnOperand.getLoc();
|
|
auto *thunk = fb.getOrCreateSharedFunction(
|
|
loc, thunkName, thunkType, IsBare, IsTransparent, caller->isSerialized(),
|
|
ProfileCounter(), IsThunk, IsNotDynamic);
|
|
|
|
if (!thunk->empty())
|
|
return {thunk, interfaceSubs};
|
|
|
|
thunk->setGenericEnvironment(genericEnv);
|
|
auto *entry = thunk->createBasicBlock();
|
|
SILBuilder builder(entry);
|
|
createEntryArguments(thunk);
|
|
|
|
SubstitutionMap assocSubstMap;
|
|
if (auto *partialApply = dyn_cast<PartialApplyInst>(derivativeFn))
|
|
assocSubstMap = partialApply->getSubstitutionMap();
|
|
|
|
// FIXME: The logic for resolving `assocRef` does not reapply function
|
|
// conversions, which is problematic if `derivativeFn` is a `partial_apply`
|
|
// instruction.
|
|
SILValue assocRef;
|
|
if (auto *derivativeFnRef =
|
|
peerThroughFunctionConversions<FunctionRefInst>(derivativeFn)) {
|
|
auto *assoc = derivativeFnRef->getReferencedFunction();
|
|
assocRef = builder.createFunctionRef(loc, assoc);
|
|
} else if (auto *assocMethodInst =
|
|
peerThroughFunctionConversions<WitnessMethodInst>(
|
|
derivativeFn)) {
|
|
assocRef = builder.createWitnessMethod(
|
|
loc, assocMethodInst->getLookupType(),
|
|
assocMethodInst->getConformance(), assocMethodInst->getMember(),
|
|
thunk->mapTypeIntoContext(assocMethodInst->getType()));
|
|
} else if (auto *assocMethodInst =
|
|
peerThroughFunctionConversions<ClassMethodInst>(
|
|
derivativeFn)) {
|
|
auto classOperand = thunk->getArgumentsWithoutIndirectResults().back();
|
|
#ifndef NDEBUG
|
|
auto classOperandType = assocMethodInst->getOperand()->getType();
|
|
assert(classOperand->getType() == classOperandType);
|
|
#endif
|
|
assocRef = builder.createClassMethod(
|
|
loc, classOperand, assocMethodInst->getMember(),
|
|
thunk->mapTypeIntoContext(assocMethodInst->getType()));
|
|
} else if (auto *diffWitFn = peerThroughFunctionConversions<
|
|
DifferentiabilityWitnessFunctionInst>(derivativeFn)) {
|
|
assocRef = builder.createDifferentiabilityWitnessFunction(
|
|
loc, diffWitFn->getWitnessKind(), diffWitFn->getWitness());
|
|
}
|
|
assert(assocRef && "Expected derivative function to be resolved");
|
|
|
|
assocSubstMap = assocSubstMap.subst(thunk->getForwardingSubstitutionMap());
|
|
derivativeFnType = assocRef->getType().castTo<SILFunctionType>();
|
|
|
|
SmallVector<SILValue, 4> arguments;
|
|
arguments.append(thunk->getArguments().begin(), thunk->getArguments().end());
|
|
assert(arguments.size() ==
|
|
derivativeFnType->getNumParameters() +
|
|
derivativeFnType->getNumIndirectFormalResults());
|
|
auto *apply = builder.createApply(loc, assocRef, assocSubstMap, arguments);
|
|
|
|
// Extract all direct results.
|
|
SmallVector<SILValue, 8> directResults;
|
|
extractAllElements(apply, builder, directResults);
|
|
auto originalDirectResults = ArrayRef<SILValue>(directResults).drop_back(1);
|
|
auto originalDirectResult =
|
|
joinElements(originalDirectResults, builder, apply->getLoc());
|
|
auto linearMap = directResults.back();
|
|
|
|
auto linearMapType = linearMap->getType().castTo<SILFunctionType>();
|
|
auto linearMapTargetType = targetType->getResults()
|
|
.back()
|
|
.getSILStorageInterfaceType()
|
|
.castTo<SILFunctionType>();
|
|
auto unsubstLinearMapType = linearMapType->getUnsubstitutedType(module);
|
|
auto unsubstLinearMapTargetType =
|
|
linearMapTargetType->getUnsubstitutedType(module);
|
|
|
|
SILFunction *linearMapThunk;
|
|
SubstitutionMap linearMapSubs;
|
|
std::tie(linearMapThunk, linearMapSubs) =
|
|
getOrCreateSubsetParametersThunkForLinearMap(
|
|
fb, thunk, origFnType, unsubstLinearMapType,
|
|
unsubstLinearMapTargetType, kind, desiredConfig, actualConfig,
|
|
adContext);
|
|
|
|
auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk);
|
|
SILValue thunkedLinearMap = linearMap;
|
|
if (linearMapType != unsubstLinearMapType) {
|
|
thunkedLinearMap = builder.createConvertFunction(
|
|
loc, thunkedLinearMap,
|
|
SILType::getPrimitiveObjectType(unsubstLinearMapType),
|
|
/*withoutActuallyEscaping*/ false);
|
|
}
|
|
thunkedLinearMap = builder.createPartialApply(
|
|
loc, linearMapThunkFRI, linearMapSubs, {thunkedLinearMap},
|
|
ParameterConvention::Direct_Guaranteed);
|
|
if (linearMapTargetType != unsubstLinearMapTargetType) {
|
|
thunkedLinearMap = builder.createConvertFunction(
|
|
loc, thunkedLinearMap,
|
|
SILType::getPrimitiveObjectType(linearMapTargetType),
|
|
/*withoutActuallyEscaping*/ false);
|
|
}
|
|
assert(origFnType->getNumResults() +
|
|
origFnType->getNumIndirectMutatingParameters() ==
|
|
1);
|
|
if (origFnType->getNumResults() > 0 &&
|
|
origFnType->getResults().front().isFormalDirect()) {
|
|
auto result =
|
|
joinElements({originalDirectResult, thunkedLinearMap}, builder, loc);
|
|
builder.createReturn(loc, result);
|
|
} else {
|
|
builder.createReturn(loc, thunkedLinearMap);
|
|
}
|
|
|
|
return {thunk, interfaceSubs};
|
|
}
|
|
|
|
} // end namespace autodiff
|
|
} // end namespace swift
|