mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Some fixes for coroutines with normal results and `partial_apply` of coroutines were required. Fixes #55084
1483 lines
62 KiB
C++
1483 lines
62 KiB
C++
//===--- VJPCloner.cpp - VJP function generation --------------*- C++ -*---===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines a helper class for generating VJP functions for automatic
|
|
// differentiation.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define DEBUG_TYPE "differentiation"
|
|
|
|
#include "swift/AST/Types.h"
|
|
#include "swift/Basic/Assertions.h"
|
|
|
|
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
|
|
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
|
|
#include "swift/SILOptimizer/Differentiation/ADContext.h"
|
|
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
|
|
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
|
|
#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
|
|
#include "swift/SILOptimizer/Differentiation/Thunk.h"
|
|
|
|
#include "swift/SIL/TerminatorUtils.h"
|
|
#include "swift/SIL/TypeSubstCloner.h"
|
|
#include "swift/SILOptimizer/Analysis/LoopAnalysis.h"
|
|
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
|
|
#include "swift/SILOptimizer/Utils/CFGOptUtils.h"
|
|
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
|
|
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
|
|
namespace swift {
|
|
namespace autodiff {
|
|
|
|
class VJPCloner::Implementation final
|
|
: public TypeSubstCloner<VJPCloner::Implementation> {
|
|
  friend class VJPCloner;
  // NOTE(review): presumably friended so the pullback generator can read
  // VJP-generation state (e.g. `pullbackInfo`); confirm against PullbackCloner.
  friend class PullbackCloner;

  /// The parent VJP cloner.
  VJPCloner &cloner;

  /// The global differentiation context (diagnostics, type converter,
  /// nested-apply bookkeeping, `differentiable_function` worklist).
  ADContext &context;

  /// The original function being differentiated.
  SILFunction *const original;

  /// The differentiability witness whose VJP is being generated; provides the
  /// differentiation config returned by `getConfig()`.
  SILDifferentiabilityWitness *const witness;

  /// The VJP function being generated.
  SILFunction *const vjp;

  /// The pullback function. The VJP returns a closure formed by partially
  /// applying a `function_ref` to it (see `visitReturnInst`).
  SILFunction *pullback;

  /// The differentiation invoker, used when emitting diagnostics.
  DifferentiationInvoker invoker;

  /// Info from activity analysis on the original function.
  const DifferentiableActivityInfo &activityInfo;

  /// The loop info.
  SILLoopInfo *loopInfo;

  /// The linear map info for the original function: pullback tuple types per
  /// block, branching trace (predecessor) enum declarations, and which apply
  /// sites must be differentiated.
  LinearMapInfo pullbackInfo;

  /// The set of *original* basic blocks whose VJP counterparts have already
  /// received their predecessor enum phi argument (see `remapBasicBlock`).
  SmallPtrSet<SILBasicBlock *, 4> remappedBasicBlocks;

  /// The owned `AutoDiffLinearMapContext` object, created in the VJP entry
  /// block by `emitLinearMapContextInitializationIfNeeded` when
  /// `pullbackInfo.hasHeapAllocatedContext()` holds. If null, no explicit
  /// context is needed (e.g. no loops).
  SILValue pullbackContextValue;

  /// The unique, borrowed context object. This is valid until the exit block,
  /// where `visitReturnInst` / `visitUnwindInst` end the borrow.
  SILValue borrowedPullbackContextValue;

  /// The generic signature of the `Builtin.autoDiffAllocateSubcontext(_:_:)`
  /// declaration. It is used for creating a builtin call. Lazily computed and
  /// cached by `getBuiltinAutoDiffAllocateSubcontextDecl()`.
  GenericSignature builtinAutoDiffAllocateSubcontextGenericSignature;

  /// Set once a diagnostic has been emitted; the visitors then stop cloning.
  bool errorOccurred = false;

  /// Mapping from original blocks to pullback values. Used to build pullback
  /// tuple instances. Nested `begin_apply` sites reserve a null slot here that
  /// the matching `end_apply` fills in.
  llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> pullbackValues;
|
|
|
|
ASTContext &getASTContext() const { return vjp->getASTContext(); }
|
|
SILModule &getModule() const { return vjp->getModule(); }
|
|
const AutoDiffConfig &getConfig() const {
|
|
return witness->getConfig();
|
|
}
|
|
|
|
Implementation(VJPCloner &parent, ADContext &context,
|
|
SILDifferentiabilityWitness *witness, SILFunction *vjp,
|
|
DifferentiationInvoker invoker);
|
|
|
|
/// Creates an empty pullback function, to be filled in by `PullbackCloner`.
|
|
SILFunction *createEmptyPullback();
|
|
|
|
/// Run VJP generation. Returns true on error.
|
|
bool run();
|
|
|
|
/// Initializes a context object if needed.
|
|
void emitLinearMapContextInitializationIfNeeded() {
|
|
if (!pullbackInfo.hasHeapAllocatedContext())
|
|
return;
|
|
|
|
// Get linear map struct size.
|
|
auto *returnBB = &*original->findReturnBB();
|
|
auto pullbackTupleType =
|
|
remapASTType(pullbackInfo.getLinearMapTupleType(returnBB)->getCanonicalType());
|
|
Builder.setInsertionPoint(vjp->getEntryBlock());
|
|
|
|
auto pbTupleMetatypeType =
|
|
CanMetatypeType::get(pullbackTupleType, MetatypeRepresentation::Thick);
|
|
auto pbTupleMetatypeSILType =
|
|
SILType::getPrimitiveObjectType(pbTupleMetatypeType);
|
|
auto pbTupleMetatype =
|
|
Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType);
|
|
|
|
// Create an context.
|
|
pullbackContextValue = Builder.createBuiltin(
|
|
original->getLocation(),
|
|
getASTContext().getIdentifier(getBuiltinName(
|
|
BuiltinValueKind::AutoDiffCreateLinearMapContextWithType)),
|
|
SILType::getNativeObjectType(getASTContext()), SubstitutionMap(),
|
|
{pbTupleMetatype});
|
|
borrowedPullbackContextValue = Builder.createBeginBorrow(
|
|
original->getLocation(), pullbackContextValue);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Context object initialized because there are loops\n"
|
|
<< *vjp->getEntryBlock() << '\n'
|
|
<< "pullback tuple type: " << pullbackTupleType << '\n');
|
|
}
|
|
|
|
/// Get the lowered SIL type of the given AST type.
|
|
SILType getLoweredType(Type type) {
|
|
auto vjpGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature();
|
|
Lowering::AbstractionPattern pattern(vjpGenSig,
|
|
type->getReducedType(vjpGenSig));
|
|
return vjp->getLoweredType(pattern, type);
|
|
}
|
|
|
|
/// Returns the SIL type of the VJP's last result, i.e. the pullback it
/// returns, in the VJP's type expansion context.
SILType getPullbackType() {
  auto pullbackResult = vjp->getLoweredFunctionType()->getResults().back();
  return vjp->getConventions().getSILType(pullbackResult,
                                          vjp->getTypeExpansionContext());
}
|
|
|
|
/// Returns the generic signature of the
/// `AutoDiffAllocateSubcontextWithType` builtin declaration. The signature is
/// looked up on first use and cached in
/// `builtinAutoDiffAllocateSubcontextGenericSignature`.
GenericSignature getBuiltinAutoDiffAllocateSubcontextDecl() {
  auto &cachedSig = builtinAutoDiffAllocateSubcontextGenericSignature;
  if (!cachedSig) {
    auto &astCtx = getASTContext();
    auto builtinId = astCtx.getIdentifier(
        getBuiltinName(BuiltinValueKind::AutoDiffAllocateSubcontextWithType));
    auto *builtinDecl = cast<FuncDecl>(getBuiltinValueDecl(astCtx, builtinId));
    cachedSig = builtinDecl->getGenericSignature();
    assert(cachedSig);
  }
  return cachedSig;
}
|
|
|
|
// Creates a trampoline block for given original terminator instruction, the
|
|
// pullback struct value for its parent block, and a successor basic block.
|
|
//
|
|
// The trampoline block has the same arguments as and branches to the remapped
|
|
// successor block, but drops the last predecessor enum argument.
|
|
//
|
|
// Used for cloning branching terminator instructions with specific
|
|
// requirements on successor block arguments, where an additional predecessor
|
|
// enum argument is not acceptable.
|
|
SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst,
|
|
TupleInst *pbTupleVal,
|
|
SILBasicBlock *succBB);
|
|
|
|
/// Build a pullback tuple value for the given original terminator
|
|
/// instruction.
|
|
TupleInst *buildPullbackValueTupleValue(TermInst *termInst);
|
|
llvm::SmallVector<SILValue, 8> getPullbackValues(SILBasicBlock *origBB);
|
|
|
|
/// Build a predecessor enum instance using the given builder for the given
|
|
/// original predecessor/successor blocks and pullback struct value.
|
|
EnumInst *buildPredecessorEnumValue(SILBuilder &builder,
|
|
SILBasicBlock *predBB,
|
|
SILBasicBlock *succBB,
|
|
SILValue pbTupleVal);
|
|
|
|
public:
/// Remaps the original basic block `bb` to its VJP counterpart.
///
/// The first time a block is remapped (and no error has occurred), an owned
/// phi argument of the block's branching trace (predecessor) enum type is
/// added to the VJP block. Later calls, and all calls after an error, return
/// the VJP block unchanged.
SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) {
  auto *vjpBB = BBMap[bb];
  // After an error has been diagnosed, leave blocks untouched.
  if (errorOccurred)
    return vjpBB;
  // `insert` reports whether `bb` was newly added, so the membership check
  // and the insertion share one hash lookup. Already-remapped blocks must not
  // receive a second predecessor enum argument.
  if (!remappedBasicBlocks.insert(bb).second)
    return vjpBB;
  // Add predecessor enum argument to the remapped block.
  auto *predEnum = pullbackInfo.getBranchingTraceDecl(bb);
  auto enumTy =
      getOpASTType(predEnum->getDeclaredInterfaceType()->getCanonicalType());
  auto enumLoweredTy = context.getTypeConverter().getLoweredType(
      enumTy, TypeExpansionContext::minimal());
  vjpBB->createPhiArgument(enumLoweredTy, OwnershipKind::Owned);
  return vjpBB;
}
|
|
|
|
/// General visitor for all instructions. If any error is emitted by previous
|
|
/// visits, bail out.
|
|
void visit(SILInstruction *inst) {
|
|
if (errorOccurred)
|
|
return;
|
|
TypeSubstCloner::visit(inst);
|
|
}
|
|
|
|
void visitSILInstruction(SILInstruction *inst) {
|
|
context.emitNondifferentiabilityError(
|
|
inst, invoker, diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
}
|
|
|
|
void postProcess(SILInstruction *orig, SILInstruction *cloned) {
|
|
if (errorOccurred)
|
|
return;
|
|
SILClonerWithScopes::postProcess(orig, cloned);
|
|
}
|
|
|
|
/// Lowers the original `return` into the VJP's `return` of
/// `(original results..., pullback)`.
///
/// The pullback closure partially applies the pullback function to one of
/// two things:
///   - the linear map context, when `borrowedPullbackContextValue` is set. The
///     exit block's pullback tuple is first stored into the context's
///     top-level subcontext.
///   - otherwise, the pullback values of the original exit block.
///
/// If the closure's type differs from the VJP's declared pullback result type
/// but is ABI-compatible, it is converted with `convert_function`. An
/// incompatible type is a fatal error.
void visitReturnInst(ReturnInst *ri) {
  Builder.setCurrentDebugScope(getOpScope(ri->getDebugScope()));
  auto loc = ri->getOperand().getLoc();
  auto *origExit = ri->getParent();

  // Get the value(s) in the VJP corresponding to the original result. `ri` is
  // itself the exit block's terminator, so its operand is used directly.
  auto origResult = getOpValue(ri->getOperand());
  SmallVector<SILValue, 8> origResults;
  extractAllElements(origResult, Builder, origResults);

  // Get and partially apply the pullback.
  auto vjpSubstMap = vjp->getForwardingSubstitutionMap();
  auto *pullbackRef = Builder.createFunctionRef(loc, pullback);

  PartialApplyInst *pullbackPartialApply;
  if (borrowedPullbackContextValue) {
    // Build the pullback tuple value for the original exit block and use it
    // to initialize the top-level subcontext buffer.
    auto *pbTupleVal = buildPullbackValueTupleValue(ri);
    auto addr = emitProjectTopLevelSubcontext(
        Builder, loc, borrowedPullbackContextValue, pbTupleVal->getType());
    // NOTE(review): triviality is queried against `pullback` although the
    // store is emitted into the VJP; presumably equivalent here — confirm.
    Builder.createStore(
        loc, pbTupleVal, addr,
        pbTupleVal->getType().isTrivial(*pullback)
            ? StoreOwnershipQualifier::Trivial
            : StoreOwnershipQualifier::Init);

    // End the borrow begun at context creation, then hand the owned context
    // to the pullback closure.
    Builder.createEndBorrow(loc, borrowedPullbackContextValue);
    pullbackPartialApply = Builder.createPartialApply(
        loc, pullbackRef, vjpSubstMap, {pullbackContextValue},
        ParameterConvention::Direct_Guaranteed);
  } else {
    pullbackPartialApply = Builder.createPartialApply(
        loc, pullbackRef, vjpSubstMap, getPullbackValues(origExit),
        ParameterConvention::Direct_Guaranteed);
  }

  auto pullbackType = vjp->mapTypeIntoContext(getPullbackType());
  auto pullbackFnType = pullbackType.castTo<SILFunctionType>();
  auto pullbackSubstType =
      pullbackPartialApply->getType().castTo<SILFunctionType>();

  // If necessary, convert the pullback value to the returned pullback
  // function type.
  SILValue pullbackValue;
  if (pullbackSubstType == pullbackFnType) {
    pullbackValue = pullbackPartialApply;
  } else if (pullbackSubstType->isABICompatibleWith(pullbackFnType, *vjp)
                 .isCompatible()) {
    pullbackValue =
        Builder.createConvertFunction(loc, pullbackPartialApply, pullbackType,
                                      /*withoutActuallyEscaping*/ false);
  } else {
    llvm::report_fatal_error("Pullback value type is not ABI-compatible "
                             "with the returned pullback type");
  }

  // Return a tuple of the original result and pullback.
  SmallVector<SILValue, 8> directResults;
  directResults.append(origResults.begin(), origResults.end());
  directResults.push_back(pullbackValue);

  Builder.createReturn(ri->getLoc(),
                       joinElements(directResults, Builder, loc));
}
|
|
|
|
/// Lowers the original `unwind`. No pullback is produced on this path, so the
/// pullback values of the exit block are consumed here. With a linear map
/// context, they are stored into the top-level subcontext and the context is
/// then destroyed. Without one, each value is destroyed directly.
void visitUnwindInst(UnwindInst *ui) {
  Builder.setCurrentDebugScope(getOpScope(ui->getDebugScope()));
  auto unwindLoc = ui->getLoc();

  if (!borrowedPullbackContextValue) {
    // No context: release every pullback value of the exit block.
    for (SILValue pbVal : getPullbackValues(ui->getParent()))
      Builder.emitDestroyValueOperation(unwindLoc, pbVal);
  } else {
    // Hand the exit block's pullback tuple to the context, end the borrow,
    // and release the context.
    auto *tupleVal = buildPullbackValueTupleValue(ui);
    auto tupleTy = tupleVal->getType();
    auto bufferAddr = emitProjectTopLevelSubcontext(
        Builder, unwindLoc, borrowedPullbackContextValue, tupleTy);
    auto qualifier = tupleTy.isTrivial(*pullback)
                         ? StoreOwnershipQualifier::Trivial
                         : StoreOwnershipQualifier::Init;
    Builder.createStore(unwindLoc, tupleVal, bufferAddr, qualifier);
    Builder.createEndBorrow(unwindLoc, borrowedPullbackContextValue);
    Builder.emitDestroyValueOperation(unwindLoc, pullbackContextValue);
  }

  Builder.createUnwind(unwindLoc);
}
|
|
|
|
/// Lowers an unconditional `br`. The remapped destination takes a trailing
/// predecessor enum argument. That argument is built from the source block's
/// pullback tuple and appended to the remapped branch arguments.
void visitBranchInst(BranchInst *bi) {
  Builder.setCurrentDebugScope(getOpScope(bi->getDebugScope()));
  auto *srcBB = bi->getParent();
  auto *destBB = bi->getDestBB();

  auto *tupleVal = buildPullbackValueTupleValue(bi);
  auto *predEnumVal =
      buildPredecessorEnumValue(getBuilder(), srcBB, destBB, tupleVal);

  SmallVector<SILValue, 8> newArgs;
  for (auto origArg : bi->getArgs())
    newArgs.push_back(getOpValue(origArg));
  newArgs.push_back(predEnumVal);

  getBuilder().createBranch(bi->getLoc(), getOpBasicBlock(destBB), newArgs);
}
|
|
|
|
/// Lowers a `yield`. Resume and unwind blocks cannot have arguments, so each
/// successor is reached through a trampoline block that supplies the branch
/// tracing (predecessor) enum value to the remapped successor.
void visitYieldInst(YieldInst *yi) {
  Builder.setCurrentDebugScope(getOpScope(yi->getDebugScope()));
  // Build pullback tuple value for original block.
  auto *pbTupleVal = buildPullbackValueTupleValue(yi);
  // Create the trampolines as separate, ordered statements. The evaluation
  // order of function arguments is unspecified in C++, and each call creates
  // a new block. Creating them inline in the `createYield` call would make
  // the resulting block order depend on the host compiler.
  auto *resumeTrampoline =
      createTrampolineBasicBlock(yi, pbTupleVal, yi->getResumeBB());
  auto *unwindTrampoline =
      createTrampolineBasicBlock(yi, pbTupleVal, yi->getUnwindBB());
  // Create a new `yield` instruction.
  getBuilder().createYield(yi->getLoc(),
                           getOpValueArray<1>(yi->getOperandValues()),
                           resumeTrampoline, unwindTrampoline);
}
|
|
|
|
/// Lowers a `cond_br`. Each successor is reached through a trampoline block
/// that supplies the predecessor enum argument expected by the remapped
/// successor.
void visitCondBranchInst(CondBranchInst *cbi) {
  Builder.setCurrentDebugScope(getOpScope(cbi->getDebugScope()));
  // Build pullback tuple value for original block.
  auto *pbTupleVal = buildPullbackValueTupleValue(cbi);
  // Create the trampolines in a fixed order. Function-argument evaluation
  // order is unspecified in C++, and each call creates a new block.
  auto *trueTrampoline =
      createTrampolineBasicBlock(cbi, pbTupleVal, cbi->getTrueBB());
  auto *falseTrampoline =
      createTrampolineBasicBlock(cbi, pbTupleVal, cbi->getFalseBB());
  // Create a new `cond_br` instruction.
  getBuilder().createCondBranch(cbi->getLoc(), getOpValue(cbi->getCondition()),
                                trueTrampoline, falseTrampoline);
}
|
|
|
|
/// Shared lowering for `switch_enum` and `switch_enum_addr`. Every case
/// successor, and the default successor if present, is replaced by a
/// trampoline block that passes the predecessor enum value to the remapped
/// successor.
void visitSwitchEnumTermInst(SwitchEnumTermInst inst) {
  Builder.setCurrentDebugScope(getOpScope(inst->getDebugScope()));
  auto *tupleVal = buildPullbackValueTupleValue(*inst);

  // One trampoline per case, created in case order.
  SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> newCases;
  for (unsigned caseIdx = 0, numCases = inst.getNumCases();
       caseIdx != numCases; ++caseIdx) {
    auto origCase = inst.getCase(caseIdx);
    auto *caseTrampoline =
        createTrampolineBasicBlock(inst, tupleVal, origCase.second);
    newCases.push_back({origCase.first, caseTrampoline});
  }

  SILBasicBlock *newDefaultBB = nullptr;
  if (auto *origDefaultBB = inst.getDefaultBBOrNull().getPtrOrNull())
    newDefaultBB = createTrampolineBasicBlock(inst, tupleVal, origDefaultBB);

  auto switchLoc = inst->getLoc();
  switch (inst->getKind()) {
  case SILInstructionKind::SwitchEnumInst:
    getBuilder().createSwitchEnum(switchLoc, getOpValue(inst.getOperand()),
                                  newDefaultBB, newCases);
    return;
  case SILInstructionKind::SwitchEnumAddrInst:
    getBuilder().createSwitchEnumAddr(switchLoc, getOpValue(inst.getOperand()),
                                      newDefaultBB, newCases);
    return;
  default:
    llvm_unreachable("Expected `switch_enum` or `switch_enum_addr`");
  }
}
|
|
|
|
void visitSwitchEnumInst(SwitchEnumInst *sei) {
|
|
visitSwitchEnumTermInst(sei);
|
|
}
|
|
|
|
void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) {
|
|
visitSwitchEnumTermInst(seai);
|
|
}
|
|
|
|
/// Lowers a `checked_cast_br`. The success and failure successors are reached
/// through trampoline blocks that supply the predecessor enum argument. The
/// original profile counts are preserved.
void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) {
  Builder.setCurrentDebugScope(getOpScope(ccbi->getDebugScope()));
  // Build pullback tuple value for original block.
  auto *pbTupleVal = buildPullbackValueTupleValue(ccbi);
  // Create the trampolines in a fixed order. Function-argument evaluation
  // order is unspecified in C++, and each call creates a new block.
  auto *successTrampoline =
      createTrampolineBasicBlock(ccbi, pbTupleVal, ccbi->getSuccessBB());
  auto *failureTrampoline =
      createTrampolineBasicBlock(ccbi, pbTupleVal, ccbi->getFailureBB());
  // Create a new `checked_cast_br` instruction.
  getBuilder().createCheckedCastBranch(
      ccbi->getLoc(), ccbi->isExact(), ccbi->getCheckedCastOptions(),
      getOpValue(ccbi->getOperand()),
      getOpASTType(ccbi->getSourceFormalType()),
      getOpType(ccbi->getTargetLoweredType()),
      getOpASTType(ccbi->getTargetFormalType()), successTrampoline,
      failureTrampoline, ccbi->getTrueBBCount(), ccbi->getFalseBBCount());
}
|
|
|
|
/// Lowers a `checked_cast_addr_br`. The success and failure successors are
/// reached through trampoline blocks that supply the predecessor enum
/// argument. The original profile counts are preserved.
void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi) {
  Builder.setCurrentDebugScope(getOpScope(ccabi->getDebugScope()));
  // Build pullback tuple value for original block.
  auto *pbTupleVal = buildPullbackValueTupleValue(ccabi);
  // Create the trampolines in a fixed order. Function-argument evaluation
  // order is unspecified in C++, and each call creates a new block.
  auto *successTrampoline =
      createTrampolineBasicBlock(ccabi, pbTupleVal, ccabi->getSuccessBB());
  auto *failureTrampoline =
      createTrampolineBasicBlock(ccabi, pbTupleVal, ccabi->getFailureBB());
  // Create a new `checked_cast_addr_br` instruction.
  getBuilder().createCheckedCastAddrBranch(
      ccabi->getLoc(),
      ccabi->getCheckedCastOptions(),
      ccabi->getConsumptionKind(),
      getOpValue(ccabi->getSrc()), getOpASTType(ccabi->getSourceFormalType()),
      getOpValue(ccabi->getDest()),
      getOpASTType(ccabi->getTargetFormalType()),
      successTrampoline, failureTrampoline,
      ccabi->getTrueBBCount(), ccabi->getFalseBBCount());
}
|
|
|
|
/// Lowers an `end_apply` whose `begin_apply` was replaced by
/// `visitBeginApplyInst` with a `begin_apply` of the callee's VJP coroutine.
///
/// This ends the VJP coroutine and releases the VJP callee. The VJP's direct
/// results are split into the original results (mapped to `eai`) and the
/// pullback (last element). The pullback is reabstracted if its type differs
/// from the lowered linear map type, and is then stored into the
/// `pullbackValues` slot that `visitBeginApplyInst` reserved.
void visitEndApplyInst(EndApplyInst *eai) {
  BeginApplyInst *bai = eai->getBeginApply();

  // If callee should not be differentiated, do standard cloning.
  if (!pullbackInfo.shouldDifferentiateApplySite(bai)) {
    LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *bai << '\n');
    TypeSubstCloner::visitEndApplyInst(eai);
    return;
  }
  // If the original function is a semantic member accessor, do standard
  // cloning. Semantic member accessors have special pullback generation
  // logic, so all `end_apply` instructions can be directly cloned to the VJP.
  if (isSemanticMemberAccessor(original)) {
    LLVM_DEBUG(getADDebugStream()
               << "Cloning `end_apply` in semantic member accessor:\n"
               << *eai << '\n');
    TypeSubstCloner::visitEndApplyInst(eai);
    return;
  }

  Builder.setCurrentDebugScope(getOpScope(eai->getDebugScope()));
  auto loc = eai->getLoc();
  auto &builder = getBuilder();
  // The mapped token belongs to the VJP `begin_apply` emitted by
  // `visitBeginApplyInst`.
  auto token = getMappedValue(bai->getTokenResult());

  LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *eai << '\n');

  FullApplySite fai(token->getDefiningInstruction());
  auto vjpResult = builder.createEndApply(loc, token, fai.getType());
  LLVM_DEBUG(getADDebugStream() << "Created end_apply\n" << *vjpResult);

  // Release the VJP callee (the copy `visitBeginApplyInst` made of the
  // extracted VJP); it must stay alive until the coroutine has ended.
  builder.emitDestroyValueOperation(loc, fai.getCallee());

  // Checkpoint the pullback.
  // The last VJP direct result is the pullback; the rest are the original
  // direct results.
  SmallVector<SILValue, 8> vjpDirectResults;
  extractAllElements(vjpResult, getBuilder(), vjpDirectResults);
  ArrayRef<SILValue> originalDirectResults =
      ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
  SILValue originalDirectResult =
      joinElements(originalDirectResults, getBuilder(), loc);
  // NOTE(review): this local shadows the `pullback` member (the pullback
  // SILFunction); consider renaming it.
  SILValue pullback = vjpDirectResults.back();
  {
    // Convert the pullback to its unsubstituted function type, if different.
    auto pullbackFnType = pullback->getType().castTo<SILFunctionType>();
    auto pullbackUnsubstFnType =
        pullbackFnType->getUnsubstitutedType(getModule());
    if (pullbackFnType != pullbackUnsubstFnType) {
      pullback = builder.createConvertFunction(
          loc, pullback,
          SILType::getPrimitiveObjectType(pullbackUnsubstFnType),
          /*withoutActuallyEscaping*/ false);
    }
  }

  // Store the original result to the value map.
  mapValue(eai, originalDirectResult);

  auto pullbackType = pullbackInfo.lookUpLinearMapType(bai);

  // If actual pullback type does not match lowered pullback type, reabstract
  // the pullback using a thunk.
  auto actualPullbackType =
      getOpType(pullback->getType()).getAs<SILFunctionType>();
  auto loweredPullbackType =
      getOpType(getLoweredType(pullbackType)).castTo<SILFunctionType>();

  // The nested apply info was recorded by `visitBeginApplyInst`.
  auto applyInfoIt = context.getNestedApplyInfo().find(bai);
  assert(applyInfoIt != context.getNestedApplyInfo().end());
  if (!loweredPullbackType->isEqual(actualPullbackType)) {
    // Set non-reabstracted original pullback type in nested apply info.
    applyInfoIt->second.originalPullbackType = actualPullbackType;
    SILOptFunctionBuilder fb(context.getTransform());
    pullback = reabstractCoroutine(
        getBuilder(), fb, loc, pullback, loweredPullbackType,
        [this](SubstitutionMap subs) -> SubstitutionMap {
          return this->getOpSubstitutionMap(subs);
        });
  }
  // Fill the slot reserved for this apply site in `visitBeginApplyInst`.
  // NOTE(review): the slot is keyed by the `begin_apply`'s block. This
  // assumes that block's pullback tuple has not been built yet when the
  // `end_apply` is visited (e.g. both are in the same block) — confirm.
  unsigned pullbackIdx = applyInfoIt->second.pullbackIdx;
  pullbackValues[bai->getParent()][pullbackIdx] = pullback;

  // Some instructions that produce the callee may have been cloned.
  // If the original callee did not have any users beyond this `apply`,
  // recursively kill the cloned callee.
  if (auto *origCallee = cast_or_null<SingleValueInstruction>(
          bai->getCallee()->getDefiningInstruction()))
    if (origCallee->hasOneUse())
      recursivelyDeleteTriviallyDeadInstructions(
          getOpValue(origCallee)->getDefiningInstruction());
}
|
|
|
|
// Check and diagnose non-differentiable original function type.
|
|
bool diagnoseNondifferentiableOriginalFunctionType(CanSILFunctionType originalFnTy,
|
|
FullApplySite fai, SILValue origCallee,
|
|
const AutoDiffConfig &config) const {
|
|
// Check and diagnose non-differentiable arguments.
|
|
for (auto paramIndex : config.parameterIndices->getIndices()) {
|
|
if (!originalFnTy->getParameters()[paramIndex]
|
|
.getSILStorageInterfaceType()
|
|
.isDifferentiable(getModule())) {
|
|
auto arg = fai.getArgumentsWithoutIndirectResults()[paramIndex];
|
|
// FIXME: This shouldn't be necessary and might indicate a bug in
|
|
// the transformation.
|
|
RegularLocation nonAutoGenLoc(arg.getLoc());
|
|
nonAutoGenLoc.markNonAutoGenerated();
|
|
auto startLoc = nonAutoGenLoc.getStartSourceLoc();
|
|
auto endLoc = nonAutoGenLoc.getEndSourceLoc();
|
|
context.emitNondifferentiabilityError(
|
|
arg, invoker, diag::autodiff_nondifferentiable_argument)
|
|
.fixItInsert(startLoc, "withoutDerivative(at: ")
|
|
.fixItInsertAfter(endLoc, ")");
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// Check and diagnose non-differentiable results.
|
|
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
|
|
unsigned firstYieldResultIndex = originalFnTy->getNumResults() +
|
|
originalFnTy->getNumAutoDiffSemanticResultsParameters();
|
|
|
|
for (auto resultIndex : config.resultIndices->getIndices()) {
|
|
SILType remappedResultType;
|
|
if (resultIndex >= firstYieldResultIndex) {
|
|
auto yieldResultIdx = resultIndex - firstYieldResultIndex;
|
|
const auto& yield = originalFnTy->getYields()[yieldResultIdx];
|
|
// We do not have a good way to differentiate direct yields
|
|
if (yield.isAutoDiffSemanticResult())
|
|
remappedResultType = yield.getSILStorageInterfaceType();
|
|
else {
|
|
context.emitNondifferentiabilityError(
|
|
origCallee, invoker,
|
|
diag::autodiff_cannot_differentiate_through_direct_yield);
|
|
return true;
|
|
}
|
|
} else if (resultIndex >= firstSemanticParamResultIdx) {
|
|
auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx;
|
|
auto semanticResultArg =
|
|
*std::next(fai.getAutoDiffSemanticResultArguments().begin(),
|
|
semanticResultArgIdx);
|
|
remappedResultType = semanticResultArg->getType();
|
|
} else {
|
|
remappedResultType = originalFnTy->getResults()[resultIndex]
|
|
.getSILStorageInterfaceType();
|
|
}
|
|
|
|
if (!remappedResultType || !remappedResultType.isDifferentiable(getModule())) {
|
|
auto startLoc = fai.getLoc().getStartSourceLoc();
|
|
auto endLoc = fai.getLoc().getEndSourceLoc();
|
|
context.emitNondifferentiabilityError(
|
|
origCallee, invoker,
|
|
diag::autodiff_nondifferentiable_result)
|
|
.fixItInsert(startLoc, "withoutDerivative(at: ")
|
|
.fixItInsertAfter(endLoc, ")");
|
|
return true;
|
|
}
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Replaces a `begin_apply` of an active callee with a `begin_apply` of the
/// callee's VJP coroutine.
///
/// The steps are:
///   - compute the minimal active parameter/result indices;
///   - diagnose non-differentiable callee types;
///   - specialize the callee if needed;
///   - wrap it in `differentiable_function` and extract a copy of its VJP;
///   - begin-apply that VJP with the remapped original arguments;
///   - map the original results (yields and token) to the VJP call's results.
///
/// A null slot is reserved in `pullbackValues` for the pullback. The slot,
/// along with the rest of the lowering, is completed in `visitEndApplyInst`.
void visitBeginApplyInst(BeginApplyInst *bai) {
  // If callee should not be differentiated, do standard cloning.
  if (!pullbackInfo.shouldDifferentiateApplySite(bai)) {
    LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *bai << '\n');
    TypeSubstCloner::visitBeginApplyInst(bai);
    return;
  }
  // If the original function is a semantic member accessor, do standard
  // cloning. Semantic member accessors have special pullback generation
  // logic, so all `begin_apply` instructions can be directly cloned to the VJP.
  if (isSemanticMemberAccessor(original)) {
    LLVM_DEBUG(getADDebugStream()
               << "Cloning `begin_apply` in semantic member accessor:\n"
               << *bai << '\n');
    TypeSubstCloner::visitBeginApplyInst(bai);
    return;
  }

  Builder.setCurrentDebugScope(getOpScope(bai->getDebugScope()));
  auto loc = bai->getLoc();
  auto &builder = getBuilder();
  auto origCallee = getOpValue(bai->getCallee());
  auto originalFnTy = origCallee->getType().castTo<SILFunctionType>();

  LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *bai << '\n');

  // Get the minimal parameter and result indices required for differentiating
  // this `begin_apply`.
  SmallVector<SILValue, 4> allResults;
  SmallVector<unsigned, 8> activeParamIndices;
  SmallVector<unsigned, 8> activeResultIndices;
  collectMinimalIndicesForFunctionCall(bai, getConfig(), activityInfo,
                                       allResults, activeParamIndices,
                                       activeResultIndices);
  assert(!activeParamIndices.empty() && "Parameter indices cannot be empty");
  assert(!activeResultIndices.empty() && "Result indices cannot be empty");
  LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=(";
             llvm::interleave(
                 activeParamIndices.begin(), activeParamIndices.end(),
                 [&s](unsigned i) { s << i; }, [&s] { s << ", "; });
             s << "), results=("; llvm::interleave(
                 activeResultIndices.begin(), activeResultIndices.end(),
                 [&s](unsigned i) { s << i; }, [&s] { s << ", "; });
             s << ")\n";);

  // Form expected indices.
  AutoDiffConfig config(
      IndexSubset::get(getASTContext(),
                       bai->getArgumentsWithoutIndirectResults().size(),
                       activeParamIndices),
      IndexSubset::get(getASTContext(),
                       bai->getSubstCalleeType()->getNumAutoDiffSemanticResults(),
                       activeResultIndices));

  if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy,
                                                    bai, origCallee, config)) {
    errorOccurred = true;
    return;
  }

  // Emit the VJP.
  SILValue vjpValue;

  // If the original `begin_apply` instruction has a substitution map, then
  // the applied function is specialized.
  // In the VJP, specialization is also necessary for parity. The original
  // function operand is specialized with a remapped version of same
  // substitution map using an argument-less `partial_apply`.
  if (bai->getSubstitutionMap().empty()) {
    // Copy the callee so `differentiable_function` below has its own value.
    origCallee = builder.emitCopyValueOperation(loc, origCallee);
  } else {
    auto substMap = getOpSubstitutionMap(bai->getSubstitutionMap());
    auto vjpPartialApply = getBuilder().createPartialApply(
        bai->getLoc(), origCallee, substMap, {},
        ParameterConvention::Direct_Guaranteed);
    origCallee = vjpPartialApply;
    originalFnTy = origCallee->getType().castTo<SILFunctionType>();

    // Diagnose if new original function type is non-differentiable.
    if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy,
                                                      bai, origCallee, config)) {
      errorOccurred = true;
      return;
    }
  }

  auto *diffFuncInst =
      context.createDifferentiableFunction(getBuilder(), loc,
                                           config.parameterIndices, config.resultIndices,
                                           origCallee);

  // Record the `differentiable_function` instruction.
  context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);

  // Extract the VJP under a borrow and copy it out, so it outlives
  // `diffFuncInst`, which is destroyed immediately afterwards.
  builder.emitScopedBorrowOperation(
      loc, diffFuncInst,
      [&](SILValue borrowedADFunc) {
        auto extractedVJP =
            getBuilder().createDifferentiableFunctionExtract(
                loc, NormalDifferentiableFunctionTypeComponent::VJP,
                borrowedADFunc);
        vjpValue = builder.emitCopyValueOperation(loc, extractedVJP);
      });
  builder.emitDestroyValueOperation(loc, diffFuncInst);

  // Record desired/actual VJP indices.
  // Temporarily set original pullback type to `None`; `visitEndApplyInst`
  // sets it if the pullback needs reabstraction.
  // `try_emplace` does not overwrite an existing entry, so the assignment
  // below ensures the fresh `info` is stored either way.
  // NOTE(review): `nestedApplyInfo` is a reference into a DenseMap and stays
  // valid only while nothing is inserted into that map before its last use.
  NestedApplyInfo info{config, /*originalPullbackType*/ std::nullopt};
  auto insertion = context.getNestedApplyInfo().try_emplace(bai, info);
  auto &nestedApplyInfo = insertion.first->getSecond();
  nestedApplyInfo = info;

  // Call the VJP using the original parameters.
  SmallVector<SILValue, 8> vjpArgs;
  auto vjpFnTy = getOpType(vjpValue->getType()).castTo<SILFunctionType>();
  auto numVJPArgs =
      vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults();
  vjpArgs.reserve(numVJPArgs);
  // Collect substituted arguments.
  for (auto origArg : bai->getArguments())
    vjpArgs.push_back(getOpValue(origArg));

  // Apply the VJP.
  // The VJP should be specialized, so no substitution map is necessary.
  auto *vjpCall = getBuilder().createBeginApply(loc, vjpValue, SubstitutionMap(),
                                                vjpArgs, bai->getApplyOptions());
  LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall);
  // Note that vjpValue is destroyed after end_apply (see visitEndApplyInst,
  // which destroys the callee of this `begin_apply`).

  // Store all the results (yields and token) to the value map.
  assert(bai->getNumResults() == vjpCall->getNumResults());
  for (unsigned i = 0; i < vjpCall->getNumResults(); ++i)
    mapValue(bai->getResult(i), vjpCall->getResult(i));

  // Checkpoint the pullback.
  // Reserve a null slot for the pullback; `visitEndApplyInst` fills it in at
  // index `pullbackIdx`.
  nestedApplyInfo.pullbackIdx = pullbackValues[bai->getParent()].size();
  pullbackValues[bai->getParent()].push_back(SILValue());

  // The rest of the cloning magic happens during `end_apply` cloning.
}
|
|
|
|
// If an `apply` has active results or active inout arguments, replace it
|
|
// with an `apply` of its VJP.
|
|
void visitApplyInst(ApplyInst *ai) {
|
|
// If callee should not be differentiated, do standard cloning.
|
|
if (!pullbackInfo.shouldDifferentiateApplySite(ai)) {
|
|
LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
|
|
TypeSubstCloner::visitApplyInst(ai);
|
|
return;
|
|
}
|
|
// If callee is `array.uninitialized_intrinsic`, do standard cloning.
|
|
// `array.uninitialized_intrinsic` differentiation is handled separately.
|
|
if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Cloning `array.uninitialized_intrinsic` `apply`:\n"
|
|
<< *ai << '\n');
|
|
TypeSubstCloner::visitApplyInst(ai);
|
|
return;
|
|
}
|
|
// If callee is `array.finalize_intrinsic`, do standard cloning.
|
|
// `array.finalize_intrinsic` has special-case pullback generation.
|
|
if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Cloning `array.finalize_intrinsic` `apply`:\n"
|
|
<< *ai << '\n');
|
|
TypeSubstCloner::visitApplyInst(ai);
|
|
return;
|
|
}
|
|
// If the original function is a semantic member accessor, do standard
|
|
// cloning. Semantic member accessors have special pullback generation
|
|
// logic, so all `apply` instructions can be directly cloned to the VJP.
|
|
if (isSemanticMemberAccessor(original)) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Cloning `apply` in semantic member accessor:\n"
|
|
<< *ai << '\n');
|
|
TypeSubstCloner::visitApplyInst(ai);
|
|
return;
|
|
}
|
|
|
|
Builder.setCurrentDebugScope(getOpScope(ai->getDebugScope()));
|
|
auto loc = ai->getLoc();
|
|
auto &builder = getBuilder();
|
|
auto origCallee = getOpValue(ai->getCallee());
|
|
auto originalFnTy = origCallee->getType().castTo<SILFunctionType>();
|
|
|
|
LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n');
|
|
|
|
// Get the minimal parameter and result indices required for differentiating
|
|
// this `apply`.
|
|
SmallVector<SILValue, 4> allResults;
|
|
SmallVector<unsigned, 8> activeParamIndices;
|
|
SmallVector<unsigned, 8> activeResultIndices;
|
|
collectMinimalIndicesForFunctionCall(ai, getConfig(), activityInfo,
|
|
allResults, activeParamIndices,
|
|
activeResultIndices);
|
|
assert(!activeParamIndices.empty() && "Parameter indices cannot be empty");
|
|
assert(!activeResultIndices.empty() && "Result indices cannot be empty");
|
|
LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=(";
|
|
llvm::interleave(
|
|
activeParamIndices.begin(), activeParamIndices.end(),
|
|
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
|
|
s << "), results=("; llvm::interleave(
|
|
activeResultIndices.begin(), activeResultIndices.end(),
|
|
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
|
|
s << ")\n";);
|
|
|
|
// Form expected indices.
|
|
AutoDiffConfig config(
|
|
IndexSubset::get(getASTContext(),
|
|
ai->getArgumentsWithoutIndirectResults().size(),
|
|
activeParamIndices),
|
|
IndexSubset::get(getASTContext(),
|
|
ai->getSubstCalleeType()->getNumAutoDiffSemanticResults(),
|
|
activeResultIndices));
|
|
|
|
// Emit the VJP.
|
|
SILValue vjpValue;
|
|
// If functionSource is a `@differentiable` function, just extract it.
|
|
if (originalFnTy->isDifferentiable()) {
|
|
auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices();
|
|
for (auto i : config.parameterIndices->getIndices()) {
|
|
if (!paramIndices->contains(i)) {
|
|
context.emitNondifferentiabilityError(
|
|
origCallee, invoker,
|
|
diag::
|
|
autodiff_function_noderivative_parameter_not_differentiable);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
}
|
|
builder.emitScopedBorrowOperation(
|
|
loc, origCallee, [&](SILValue borrowedDiffFunc) {
|
|
auto origFnType = origCallee->getType().castTo<SILFunctionType>();
|
|
auto origFnUnsubstType =
|
|
origFnType->getUnsubstitutedType(getModule());
|
|
if (origFnType != origFnUnsubstType) {
|
|
borrowedDiffFunc = builder.createConvertFunction(
|
|
loc, borrowedDiffFunc,
|
|
SILType::getPrimitiveObjectType(origFnUnsubstType),
|
|
/*withoutActuallyEscaping*/ false);
|
|
}
|
|
vjpValue = builder.createDifferentiableFunctionExtract(
|
|
loc, NormalDifferentiableFunctionTypeComponent::VJP,
|
|
borrowedDiffFunc);
|
|
vjpValue = builder.emitCopyValueOperation(loc, vjpValue);
|
|
});
|
|
auto vjpFnType = vjpValue->getType().castTo<SILFunctionType>();
|
|
auto vjpFnUnsubstType = vjpFnType->getUnsubstitutedType(getModule());
|
|
if (vjpFnType != vjpFnUnsubstType) {
|
|
vjpValue = builder.createConvertFunction(
|
|
loc, vjpValue, SILType::getPrimitiveObjectType(vjpFnUnsubstType),
|
|
/*withoutActuallyEscaping*/ false);
|
|
}
|
|
}
|
|
|
|
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy,
|
|
ai, origCallee, config)) {
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
|
|
// If VJP has not yet been found, emit an `differentiable_function`
|
|
// instruction on the remapped original function operand and
|
|
// an `differentiable_function_extract` instruction to get the VJP.
|
|
// The `differentiable_function` instruction will be canonicalized during
|
|
// the transform main loop.
|
|
if (!vjpValue) {
|
|
// FIXME: Handle indirect differentiation invokers. This may require some
|
|
// redesign: currently, each original function + witness pair is mapped
|
|
// only to one invoker.
|
|
/*
|
|
DifferentiationInvoker indirect(ai, attr);
|
|
auto insertion =
|
|
context.getInvokers().try_emplace({original, attr}, indirect);
|
|
auto &invoker = insertion.first->getSecond();
|
|
invoker = indirect;
|
|
*/
|
|
|
|
// If the original `apply` instruction has a substitution map, then the
|
|
// applied function is specialized.
|
|
// In the VJP, specialization is also necessary for parity. The original
|
|
// function operand is specialized with a remapped version of same
|
|
// substitution map using an argument-less `partial_apply`.
|
|
if (ai->getSubstitutionMap().empty()) {
|
|
origCallee = builder.emitCopyValueOperation(loc, origCallee);
|
|
} else {
|
|
auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap());
|
|
auto vjpPartialApply = getBuilder().createPartialApply(
|
|
ai->getLoc(), origCallee, substMap, {},
|
|
ParameterConvention::Direct_Guaranteed);
|
|
origCallee = vjpPartialApply;
|
|
originalFnTy = origCallee->getType().castTo<SILFunctionType>();
|
|
|
|
// Diagnose if new original function type is non-differentiable.
|
|
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy,
|
|
ai, origCallee, config)) {
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
}
|
|
|
|
auto *diffFuncInst = context.createDifferentiableFunction(
|
|
getBuilder(), loc, config.parameterIndices, config.resultIndices,
|
|
origCallee);
|
|
|
|
// Record the `differentiable_function` instruction.
|
|
context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);
|
|
|
|
builder.emitScopedBorrowOperation(
|
|
loc, diffFuncInst, [&](SILValue borrowedADFunc) {
|
|
auto extractedVJP =
|
|
getBuilder().createDifferentiableFunctionExtract(
|
|
loc, NormalDifferentiableFunctionTypeComponent::VJP,
|
|
borrowedADFunc);
|
|
vjpValue = builder.emitCopyValueOperation(loc, extractedVJP);
|
|
});
|
|
builder.emitDestroyValueOperation(loc, diffFuncInst);
|
|
}
|
|
|
|
// Record desired/actual VJP indices.
|
|
// Temporarily set original pullback type to `None`.
|
|
NestedApplyInfo info{config, /*originalPullbackType*/ std::nullopt};
|
|
auto insertion = context.getNestedApplyInfo().try_emplace(ai, info);
|
|
auto &nestedApplyInfo = insertion.first->getSecond();
|
|
nestedApplyInfo = info;
|
|
|
|
// Call the VJP using the original parameters.
|
|
SmallVector<SILValue, 8> vjpArgs;
|
|
auto vjpFnTy = getOpType(vjpValue->getType()).castTo<SILFunctionType>();
|
|
auto numVJPArgs =
|
|
vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults();
|
|
vjpArgs.reserve(numVJPArgs);
|
|
// Collect substituted arguments.
|
|
for (auto origArg : ai->getArguments())
|
|
vjpArgs.push_back(getOpValue(origArg));
|
|
assert(vjpArgs.size() == numVJPArgs);
|
|
// Apply the VJP.
|
|
// The VJP should be specialized, so no substitution map is necessary.
|
|
auto *vjpCall = getBuilder().createApply(loc, vjpValue, SubstitutionMap(),
|
|
vjpArgs, ai->getApplyOptions());
|
|
LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall);
|
|
builder.emitDestroyValueOperation(loc, vjpValue);
|
|
|
|
// Get the VJP results (original results and pullback).
|
|
SmallVector<SILValue, 8> vjpDirectResults;
|
|
extractAllElements(vjpCall, getBuilder(), vjpDirectResults);
|
|
ArrayRef<SILValue> originalDirectResults =
|
|
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
|
|
SILValue originalDirectResult =
|
|
joinElements(originalDirectResults, getBuilder(), vjpCall->getLoc());
|
|
SILValue pullback = vjpDirectResults.back();
|
|
{
|
|
auto pullbackFnType = pullback->getType().castTo<SILFunctionType>();
|
|
auto pullbackUnsubstFnType =
|
|
pullbackFnType->getUnsubstitutedType(getModule());
|
|
if (pullbackFnType != pullbackUnsubstFnType) {
|
|
pullback = builder.createConvertFunction(
|
|
loc, pullback,
|
|
SILType::getPrimitiveObjectType(pullbackUnsubstFnType),
|
|
/*withoutActuallyEscaping*/ false);
|
|
}
|
|
}
|
|
|
|
// Store the original result to the value map.
|
|
mapValue(ai, originalDirectResult);
|
|
|
|
// Checkpoint the pullback.
|
|
auto pullbackType = pullbackInfo.lookUpLinearMapType(ai);
|
|
|
|
// If actual pullback type does not match lowered pullback type, reabstract
|
|
// the pullback using a thunk.
|
|
auto actualPullbackType =
|
|
getOpType(pullback->getType()).getAs<SILFunctionType>();
|
|
auto loweredPullbackType =
|
|
getOpType(getLoweredType(pullbackType)).castTo<SILFunctionType>();
|
|
if (!loweredPullbackType->isEqual(actualPullbackType)) {
|
|
// Set non-reabstracted original pullback type in nested apply info.
|
|
nestedApplyInfo.originalPullbackType = actualPullbackType;
|
|
SILOptFunctionBuilder fb(context.getTransform());
|
|
pullback = reabstractFunction(
|
|
getBuilder(), fb, ai->getLoc(), pullback, loweredPullbackType,
|
|
[this](SubstitutionMap subs) -> SubstitutionMap {
|
|
return this->getOpSubstitutionMap(subs);
|
|
});
|
|
}
|
|
nestedApplyInfo.pullbackIdx = pullbackValues[ai->getParent()].size();
|
|
pullbackValues[ai->getParent()].push_back(pullback);
|
|
|
|
// Some instructions that produce the callee may have been cloned.
|
|
// If the original callee did not have any users beyond this `apply`,
|
|
// recursively kill the cloned callee.
|
|
if (auto *origCallee = cast_or_null<SingleValueInstruction>(
|
|
ai->getCallee()->getDefiningInstruction()))
|
|
if (origCallee->hasOneUse())
|
|
recursivelyDeleteTriviallyDeadInstructions(
|
|
getOpValue(origCallee)->getDefiningInstruction());
|
|
}
|
|
|
|
void visitTryApplyInst(TryApplyInst *tai) {
|
|
Builder.setCurrentDebugScope(getOpScope(tai->getDebugScope()));
|
|
// Build pullback struct value for original block.
|
|
auto *pbTupleVal = buildPullbackValueTupleValue(tai);
|
|
// Create a new `try_apply` instruction.
|
|
auto args = getOpValueArray<8>(tai->getArguments());
|
|
getBuilder().createTryApply(
|
|
tai->getLoc(), getOpValue(tai->getCallee()),
|
|
getOpSubstitutionMap(tai->getSubstitutionMap()), args,
|
|
createTrampolineBasicBlock(tai, pbTupleVal, tai->getNormalBB()),
|
|
createTrampolineBasicBlock(tai, pbTupleVal, tai->getErrorBB()),
|
|
tai->getApplyOptions());
|
|
}
|
|
|
|
void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) {
|
|
// Clone `differentiable_function` from original to VJP, then add the cloned
|
|
// instruction to the `differentiable_function` worklist.
|
|
TypeSubstCloner::visitDifferentiableFunctionInst(dfi);
|
|
auto *newDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi));
|
|
context.getDifferentiableFunctionInstWorklist().push_back(newDFI);
|
|
}
|
|
|
|
void visitLinearFunctionInst(LinearFunctionInst *lfi) {
|
|
// Clone `linear_function` from original to VJP, then add the cloned
|
|
// instruction to the `linear_function` worklist.
|
|
TypeSubstCloner::visitLinearFunctionInst(lfi);
|
|
auto *newLFI = cast<LinearFunctionInst>(getOpValue(lfi));
|
|
context.getLinearFunctionInstWorklist().push_back(newLFI);
|
|
}
|
|
};
|
|
|
|
/// Initialization helper function.
|
|
///
|
|
/// Returns the substitution map used for type remapping.
|
|
static SubstitutionMap getSubstitutionMap(SILFunction *original,
|
|
SILFunction *vjp) {
|
|
auto substMap = original->getForwardingSubstitutionMap();
|
|
if (auto *vjpGenEnv = vjp->getGenericEnvironment()) {
|
|
auto vjpSubstMap = vjpGenEnv->getForwardingSubstitutionMap();
|
|
substMap = SubstitutionMap::get(
|
|
vjpGenEnv->getGenericSignature(), QuerySubstitutionMap{vjpSubstMap},
|
|
LookUpConformanceInSubstitutionMap(vjpSubstMap));
|
|
}
|
|
return substMap;
|
|
}
|
|
|
|
/// Initialization helper function.
|
|
///
|
|
/// Returns the activity info for the given original function, autodiff indices,
|
|
/// and VJP generic signature.
|
|
static const DifferentiableActivityInfo &
|
|
getActivityInfoHelper(ADContext &context, SILFunction *original,
|
|
const AutoDiffConfig &config, SILFunction *vjp) {
|
|
// Get activity info of the original function.
|
|
auto &passManager = context.getPassManager();
|
|
auto *activityAnalysis =
|
|
passManager.getAnalysis<DifferentiableActivityAnalysis>();
|
|
auto &activityCollection = *activityAnalysis->get(original);
|
|
auto &activityInfo = activityCollection.getActivityInfo(
|
|
vjp->getLoweredFunctionType()->getSubstGenericSignature(),
|
|
AutoDiffDerivativeFunctionKind::VJP);
|
|
LLVM_DEBUG(activityInfo.dump(config, getADDebugStream()));
|
|
return activityInfo;
|
|
}
|
|
|
|
VJPCloner::Implementation::Implementation(VJPCloner &cloner, ADContext &context,
|
|
SILDifferentiabilityWitness *witness,
|
|
SILFunction *vjp,
|
|
DifferentiationInvoker invoker)
|
|
: TypeSubstCloner(*vjp, *witness->getOriginalFunction(),
|
|
getSubstitutionMap(witness->getOriginalFunction(), vjp)),
|
|
cloner(cloner), context(context),
|
|
original(witness->getOriginalFunction()), witness(witness),
|
|
vjp(vjp), invoker(invoker),
|
|
activityInfo(getActivityInfoHelper(
|
|
context, original, witness->getConfig(), vjp)),
|
|
loopInfo(context.getPassManager().getAnalysis<SILLoopAnalysis>()
|
|
->get(original)),
|
|
pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp,
|
|
witness->getConfig(), activityInfo, loopInfo) {
|
|
// Create empty pullback function.
|
|
pullback = createEmptyPullback();
|
|
context.recordGeneratedFunction(pullback);
|
|
}
|
|
|
|
VJPCloner::VJPCloner(ADContext &context,
|
|
SILDifferentiabilityWitness *witness, SILFunction *vjp,
|
|
DifferentiationInvoker invoker)
|
|
: impl(*new Implementation(*this, context, witness, vjp, invoker)) {}
|
|
|
|
VJPCloner::~VJPCloner() { delete &impl; }
|
|
|
|
ADContext &VJPCloner::getContext() const { return impl.context; }
|
|
SILModule &VJPCloner::getModule() const { return impl.getModule(); }
|
|
SILFunction &VJPCloner::getOriginal() const { return *impl.original; }
|
|
SILFunction &VJPCloner::getVJP() const { return *impl.vjp; }
|
|
SILFunction &VJPCloner::getPullback() const { return *impl.pullback; }
|
|
SILDifferentiabilityWitness *VJPCloner::getWitness() const {
|
|
return impl.witness;
|
|
}
|
|
const AutoDiffConfig &VJPCloner::getConfig() const {
|
|
return impl.getConfig();
|
|
}
|
|
DifferentiationInvoker VJPCloner::getInvoker() const { return impl.invoker; }
|
|
LinearMapInfo &VJPCloner::getPullbackInfo() const { return impl.pullbackInfo; }
|
|
SILLoopInfo *VJPCloner::getLoopInfo() const { return impl.loopInfo; }
|
|
const DifferentiableActivityInfo &VJPCloner::getActivityInfo() const {
|
|
return impl.activityInfo;
|
|
}
|
|
|
|
SILFunction *VJPCloner::Implementation::createEmptyPullback() {
|
|
auto origTy = original->getLoweredFunctionType();
|
|
// Get witness generic signature for remapping types.
|
|
// Witness generic signature may have more requirements than VJP generic
|
|
// signature: when witness generic signature has same-type requirements
|
|
// binding all generic parameters to concrete types, VJP function type uses
|
|
// all the concrete types and VJP generic signature is null.
|
|
auto witnessCanGenSig = witness->getDerivativeGenericSignature().getCanonicalSignature();
|
|
auto lookupConformance = LookUpConformanceInModule();
|
|
|
|
// Given a type, returns its formal SIL parameter info.
|
|
auto getTangentParameterInfoForOriginalResult =
|
|
[&](CanType tanType, ResultConvention origResConv) -> SILParameterInfo {
|
|
tanType = tanType->getReducedType(witnessCanGenSig);
|
|
Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType);
|
|
auto &tl = context.getTypeConverter().getTypeLowering(
|
|
pattern, tanType, TypeExpansionContext::minimal());
|
|
ParameterConvention conv;
|
|
switch (origResConv) {
|
|
case ResultConvention::Unowned:
|
|
case ResultConvention::UnownedInnerPointer:
|
|
case ResultConvention::Owned:
|
|
case ResultConvention::Autoreleased:
|
|
if (tl.isAddressOnly()) {
|
|
conv = ParameterConvention::Indirect_In_Guaranteed;
|
|
} else {
|
|
conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
|
|
: ParameterConvention::Direct_Guaranteed;
|
|
}
|
|
break;
|
|
case ResultConvention::Indirect:
|
|
conv = ParameterConvention::Indirect_In_Guaranteed;
|
|
break;
|
|
case ResultConvention::Pack:
|
|
conv = ParameterConvention::Pack_Guaranteed;
|
|
break;
|
|
}
|
|
return {tanType, conv};
|
|
};
|
|
|
|
// Given a type, returns its formal SIL result info.
|
|
auto getTangentResultInfoForOriginalParameter =
|
|
[&](CanType tanType, ParameterConvention origParamConv) -> SILResultInfo {
|
|
tanType = tanType->getReducedType(witnessCanGenSig);
|
|
Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType);
|
|
auto &tl = context.getTypeConverter().getTypeLowering(
|
|
pattern, tanType, TypeExpansionContext::minimal());
|
|
ResultConvention conv;
|
|
switch (origParamConv) {
|
|
case ParameterConvention::Direct_Owned:
|
|
case ParameterConvention::Direct_Guaranteed:
|
|
case ParameterConvention::Direct_Unowned:
|
|
if (tl.isAddressOnly()) {
|
|
conv = ResultConvention::Indirect;
|
|
} else {
|
|
conv = tl.isTrivial() ? ResultConvention::Unowned
|
|
: ResultConvention::Owned;
|
|
}
|
|
break;
|
|
case ParameterConvention::Indirect_In:
|
|
case ParameterConvention::Indirect_Inout:
|
|
case ParameterConvention::Indirect_In_Guaranteed:
|
|
case ParameterConvention::Indirect_InoutAliasable:
|
|
case ParameterConvention::Indirect_In_CXX:
|
|
conv = ResultConvention::Indirect;
|
|
break;
|
|
case ParameterConvention::Pack_Guaranteed:
|
|
case ParameterConvention::Pack_Owned:
|
|
case ParameterConvention::Pack_Inout:
|
|
conv = ResultConvention::Pack;
|
|
break;
|
|
}
|
|
return {tanType, conv};
|
|
};
|
|
|
|
// Parameters of the pullback are:
|
|
// - the tangent vectors of the original results, and
|
|
// - a pullback struct.
|
|
// Results of the pullback are in the tangent space of the original
|
|
// parameters.
|
|
SmallVector<SILParameterInfo, 8> pbParams;
|
|
SmallVector<SILYieldInfo, 8> pbYields;
|
|
SmallVector<SILResultInfo, 8> adjResults;
|
|
auto origParams = origTy->getParameters();
|
|
auto config = witness->getConfig();
|
|
|
|
// Add pullback parameters based on original result indices.
|
|
SmallVector<unsigned, 4> semanticResultParamIndices;
|
|
for (auto i : range(origTy->getNumParameters())) {
|
|
auto origParam = origParams[i];
|
|
if (!origParam.isAutoDiffSemanticResult())
|
|
continue;
|
|
semanticResultParamIndices.push_back(i);
|
|
}
|
|
|
|
unsigned firstSemanticParamResultIdx = origTy->getNumResults();
|
|
unsigned firstYieldResultIndex = firstSemanticParamResultIdx +
|
|
origTy->getNumAutoDiffSemanticResultsParameters();
|
|
for (auto resultIndex : config.resultIndices->getIndices()) {
|
|
// Handle formal result.
|
|
if (resultIndex < firstSemanticParamResultIdx) {
|
|
auto origResult = origTy->getResults()[resultIndex];
|
|
origResult = origResult.getWithInterfaceType(
|
|
origResult.getInterfaceType()->getReducedType(witnessCanGenSig));
|
|
auto paramInfo = getTangentParameterInfoForOriginalResult(
|
|
origResult.getInterfaceType()
|
|
->getAutoDiffTangentSpace(lookupConformance)
|
|
->getType()
|
|
->getReducedType(witnessCanGenSig),
|
|
origResult.getConvention());
|
|
pbParams.push_back(paramInfo);
|
|
} else if (resultIndex < firstYieldResultIndex) {
|
|
// Handle semantic result parameter.
|
|
unsigned paramIndex = 0;
|
|
unsigned resultParamIndex = 0;
|
|
for (auto i : range(origTy->getNumParameters())) {
|
|
auto origParam = origTy->getParameters()[i];
|
|
if (!origParam.isAutoDiffSemanticResult()) {
|
|
++paramIndex;
|
|
continue;
|
|
}
|
|
if (resultParamIndex == resultIndex - firstSemanticParamResultIdx)
|
|
break;
|
|
++paramIndex;
|
|
++resultParamIndex;
|
|
}
|
|
auto resultParam = origParams[paramIndex];
|
|
auto origResult = resultParam.getWithInterfaceType(
|
|
resultParam.getInterfaceType()->getReducedType(witnessCanGenSig));
|
|
|
|
auto resultParamTanConvention = resultParam.getConvention();
|
|
if (!config.isWrtParameter(paramIndex))
|
|
resultParamTanConvention = ParameterConvention::Indirect_In_Guaranteed;
|
|
|
|
pbParams.emplace_back(origResult.getInterfaceType()
|
|
->getAutoDiffTangentSpace(lookupConformance)
|
|
->getType()
|
|
->getReducedType(witnessCanGenSig),
|
|
resultParamTanConvention);
|
|
} else {
|
|
assert(origTy->isCoroutine());
|
|
assert(origTy->getCoroutineKind() == SILCoroutineKind::YieldOnce);
|
|
|
|
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
|
|
auto yieldResult = origTy->getYields()[yieldResultIndex];
|
|
auto origYield =
|
|
yieldResult.getWithInterfaceType(
|
|
yieldResult.getInterfaceType()->getReducedType(witnessCanGenSig));
|
|
assert(yieldResult.getConvention() == ParameterConvention::Indirect_Inout);
|
|
|
|
pbYields.emplace_back(origYield.getInterfaceType()
|
|
->getAutoDiffTangentSpace(lookupConformance)
|
|
->getType()
|
|
->getReducedType(witnessCanGenSig),
|
|
yieldResult.getConvention());
|
|
}
|
|
}
|
|
|
|
if (pullbackInfo.hasHeapAllocatedContext()) {
|
|
// Accept a `AutoDiffLinarMapContext` heap object if there are loops.
|
|
pbParams.push_back({
|
|
getASTContext().TheNativeObjectType,
|
|
ParameterConvention::Direct_Guaranteed
|
|
});
|
|
} else {
|
|
// Accept a pullback tuple in the pullback parameter list. This is the
|
|
// returned pullback's closure context.
|
|
auto *origExit = &*original->findReturnBB();
|
|
auto pbTupleType =
|
|
pullbackInfo.getLinearMapTupleLoweredType(origExit).getAs<TupleType>();
|
|
for (Type eltTy : pbTupleType->getElementTypes())
|
|
pbParams.emplace_back(CanType(eltTy), ParameterConvention::Direct_Owned);
|
|
}
|
|
|
|
// Add pullback results for the requested wrt parameters.
|
|
for (auto i : config.parameterIndices->getIndices()) {
|
|
auto origParam = origParams[i];
|
|
if (origParam.isAutoDiffSemanticResult())
|
|
continue;
|
|
origParam = origParam.getWithInterfaceType(
|
|
origParam.getInterfaceType()->getReducedType(witnessCanGenSig));
|
|
adjResults.push_back(getTangentResultInfoForOriginalParameter(
|
|
origParam.getInterfaceType()
|
|
->getAutoDiffTangentSpace(lookupConformance)
|
|
->getType()
|
|
->getReducedType(witnessCanGenSig),
|
|
origParam.getConvention()));
|
|
}
|
|
|
|
Mangle::DifferentiationMangler mangler(getASTContext());
|
|
auto pbName = mangler.mangleLinearMap(
|
|
original->getName(), AutoDiffLinearMapKind::Pullback, config);
|
|
// Set pullback generic signature equal to VJP generic signature.
|
|
// Do not use witness generic signature, which may have same-type requirements
|
|
// binding all generic parameters to concrete types.
|
|
auto pbGenericSig = vjp->getLoweredFunctionType()->getSubstGenericSignature();
|
|
auto *pbGenericEnv = pbGenericSig.getGenericEnvironment();
|
|
auto pbType = SILFunctionType::get(
|
|
pbGenericSig, SILExtInfo::getThin(), origTy->getCoroutineKind(),
|
|
origTy->getCalleeConvention(), pbParams, pbYields, adjResults, std::nullopt,
|
|
origTy->getPatternSubstitutions(), origTy->getInvocationSubstitutions(),
|
|
original->getASTContext());
|
|
|
|
SILOptFunctionBuilder fb(context.getTransform());
|
|
auto linkage = vjp->isSerialized() ? SILLinkage::Public : SILLinkage::Private;
|
|
auto *pullback = fb.createFunction(
|
|
linkage, context.getASTContext().getIdentifier(pbName).str(), pbType,
|
|
pbGenericEnv, original->getLocation(), original->isBare(),
|
|
IsNotTransparent, vjp->getSerializedKind(),
|
|
original->isDynamicallyReplaceable(), original->isDistributed(),
|
|
original->isRuntimeAccessible());
|
|
auto &module = context.getModule();
|
|
pullback->setDebugScope(new (module)
|
|
SILDebugScope(original->getLocation(), pullback));
|
|
|
|
return pullback;
|
|
}
|
|
|
|
SILBasicBlock *VJPCloner::Implementation::createTrampolineBasicBlock(
|
|
TermInst *termInst, TupleInst *pbTupleVal, SILBasicBlock *succBB) {
|
|
assert(llvm::find(termInst->getSuccessorBlocks(), succBB) !=
|
|
termInst->getSuccessorBlocks().end() &&
|
|
"Basic block is not a successor of terminator instruction");
|
|
// Create the trampoline block.
|
|
auto *vjpSuccBB = getOpBasicBlock(succBB);
|
|
auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB);
|
|
for (auto *arg : vjpSuccBB->getArguments().drop_back())
|
|
trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind());
|
|
// In the trampoline block, build predecessor enum value for VJP successor
|
|
// block and branch to it.
|
|
SILBuilder trampolineBuilder(trampolineBB);
|
|
trampolineBuilder.setCurrentDebugScope(getOpScope(termInst->getDebugScope()));
|
|
auto *origBB = termInst->getParent();
|
|
auto *succEnumVal =
|
|
buildPredecessorEnumValue(trampolineBuilder, origBB, succBB, pbTupleVal);
|
|
SmallVector<SILValue, 4> forwardedArguments(
|
|
trampolineBB->getArguments().begin(), trampolineBB->getArguments().end());
|
|
forwardedArguments.push_back(succEnumVal);
|
|
trampolineBuilder.createBranch(termInst->getLoc(), vjpSuccBB,
|
|
forwardedArguments);
|
|
return trampolineBB;
|
|
}
|
|
|
|
llvm::SmallVector<SILValue, 8>
|
|
VJPCloner::Implementation::getPullbackValues(SILBasicBlock *origBB) {
|
|
auto *vjpBB = BBMap[origBB];
|
|
auto bbPullbackValues = pullbackValues[origBB];
|
|
if (!origBB->isEntry()) {
|
|
auto *predEnumArg = vjpBB->getArguments().back();
|
|
bbPullbackValues.insert(bbPullbackValues.begin(), predEnumArg);
|
|
}
|
|
|
|
return bbPullbackValues;
|
|
}
|
|
|
|
TupleInst *
|
|
VJPCloner::Implementation::buildPullbackValueTupleValue(TermInst *termInst) {
|
|
assert(termInst->getFunction() == original);
|
|
auto loc = RegularLocation::getAutoGeneratedLocation();
|
|
auto origBB = termInst->getParent();
|
|
auto tupleLoweredTy =
|
|
remapType(pullbackInfo.getLinearMapTupleLoweredType(origBB));
|
|
auto bbPullbackValues = getPullbackValues(origBB);
|
|
return getBuilder().createTuple(loc, tupleLoweredTy, bbPullbackValues);
|
|
}
|
|
|
|
EnumInst *VJPCloner::Implementation::buildPredecessorEnumValue(
|
|
SILBuilder &builder, SILBasicBlock *predBB, SILBasicBlock *succBB,
|
|
SILValue pbTupleVal) {
|
|
auto loc = RegularLocation::getAutoGeneratedLocation();
|
|
auto enumLoweredTy =
|
|
remapType(pullbackInfo.getBranchingTraceEnumLoweredType(succBB));
|
|
auto *enumEltDecl =
|
|
pullbackInfo.lookUpBranchingTraceEnumElement(predBB, succBB);
|
|
auto enumEltType = getOpType(enumLoweredTy.getEnumElementType(
|
|
enumEltDecl, getModule(), TypeExpansionContext::minimal()));
|
|
// If the predecessor block is in a loop, its predecessor enum payload is a
|
|
// `Builtin.RawPointer`.
|
|
if (loopInfo->getLoopFor(predBB)) {
|
|
auto rawPtrType = SILType::getRawPointerType(getASTContext());
|
|
assert(enumEltType == rawPtrType);
|
|
auto pbTupleType =
|
|
remapASTType(pullbackInfo.getLinearMapTupleType(predBB)->getCanonicalType());
|
|
|
|
auto pbTupleMetatypeType =
|
|
CanMetatypeType::get(pbTupleType, MetatypeRepresentation::Thick);
|
|
auto pbTupleMetatypeSILType =
|
|
SILType::getPrimitiveObjectType(pbTupleMetatypeType);
|
|
auto pbTupleMetatype =
|
|
Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType);
|
|
|
|
auto rawBufferValue = builder.createBuiltin(
|
|
loc,
|
|
getASTContext().getIdentifier(getBuiltinName(
|
|
BuiltinValueKind::AutoDiffAllocateSubcontextWithType)),
|
|
rawPtrType, SubstitutionMap(),
|
|
{borrowedPullbackContextValue, pbTupleMetatype});
|
|
|
|
auto typedBufferValue =
|
|
builder.createPointerToAddress(
|
|
loc, rawBufferValue, pbTupleVal->getType().getAddressType(),
|
|
/*isStrict*/ true);
|
|
builder.createStore(
|
|
loc, pbTupleVal, typedBufferValue,
|
|
pbTupleVal->getType().isTrivial(*pullback) ?
|
|
StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init);
|
|
return builder.createEnum(loc, rawBufferValue, enumEltDecl, enumLoweredTy);
|
|
}
|
|
return builder.createEnum(loc, pbTupleVal, enumEltDecl, enumLoweredTy);
|
|
}
|
|
|
|
bool VJPCloner::Implementation::run() {
|
|
PrettyStackTraceSILFunction trace("generating VJP for", original);
|
|
LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName()
|
|
<< " to vjp @" << vjp->getName() << '\n');
|
|
|
|
// Create entry BB and arguments.
|
|
auto *entry = vjp->createBasicBlock();
|
|
createEntryArguments(vjp);
|
|
|
|
emitLinearMapContextInitializationIfNeeded();
|
|
|
|
// Clone.
|
|
SmallVector<SILValue, 4> entryArgs;
|
|
entryArgs.assign(entry->getArguments().begin(),
|
|
entry->getArguments().end());
|
|
|
|
cloneFunctionBody(original, entry, entryArgs);
|
|
// If errors occurred, back out.
|
|
if (errorOccurred)
|
|
return true;
|
|
|
|
// Merge VJP basic blocks. This is significant for control flow
|
|
// differentiation: trampoline destination bbs are merged into trampoline bbs.
|
|
// NOTE(TF-990): Merging basic blocks ensures that `@guaranteed` trampoline
|
|
// bb arguments have a lifetime-ending `end_borrow` use, and is robust when
|
|
// `-enable-strip-ownership-after-serialization` is true.
|
|
mergeBasicBlocks(vjp);
|
|
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Generated VJP for " << original->getName() << ":\n"
|
|
<< *vjp);
|
|
|
|
// Generate pullback code.
|
|
PullbackCloner PullbackCloner(cloner);
|
|
if (PullbackCloner.run()) {
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
return errorOccurred;
|
|
}
|
|
|
|
bool VJPCloner::run() {
|
|
bool foundError = impl.run();
|
|
#ifndef NDEBUG
|
|
if (!foundError)
|
|
getVJP().verify();
|
|
#endif
|
|
return foundError;
|
|
}
|
|
|
|
} // end namespace autodiff
|
|
} // end namespace swift
|