mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
We are going to need to add more flags to the various checked cast instructions. Generalize the CastingIsolatedConformances bit in all of these SIL instructions to an "options" struct that's easier to extend. Precursor to rdar://152335805.
3795 lines
160 KiB
C++
3795 lines
160 KiB
C++
//===--- PullbackCloner.cpp - Pullback function generation ---*- C++ -*----===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines a helper class for generating pullback functions for
|
|
// automatic differentiation.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define DEBUG_TYPE "differentiation"
|
|
|
|
#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
|
|
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
|
|
#include "swift/SILOptimizer/Differentiation/ADContext.h"
|
|
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
|
|
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
|
|
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
|
|
#include "swift/SILOptimizer/Differentiation/Thunk.h"
|
|
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
|
|
|
|
#include "swift/AST/ConformanceLookup.h"
|
|
#include "swift/AST/Expr.h"
|
|
#include "swift/AST/PropertyWrappers.h"
|
|
#include "swift/AST/TypeCheckRequests.h"
|
|
#include "swift/Basic/Assertions.h"
|
|
#include "swift/Basic/STLExtras.h"
|
|
#include "swift/SIL/ApplySite.h"
|
|
#include "swift/SIL/InstructionUtils.h"
|
|
#include "swift/SIL/Projection.h"
|
|
#include "swift/SIL/TypeSubstCloner.h"
|
|
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
|
|
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/SmallSet.h"
|
|
|
|
namespace swift {
|
|
|
|
class SILDifferentiabilityWitness;
|
|
class SILBasicBlock;
|
|
class SILFunction;
|
|
class SILInstruction;
|
|
|
|
namespace autodiff {
|
|
|
|
class ADContext;
|
|
class VJPCloner;
|
|
|
|
/// The implementation class for `PullbackCloner`.
|
|
///
|
|
/// The implementation class is a `SILInstructionVisitor`. Effectively, it acts
|
|
/// as a `SILCloner` that visits basic blocks in post-order and that visits
|
|
/// instructions per basic block in reverse order. This visitation order is
|
|
/// necessary for generating pullback functions, whose control flow graph is
|
|
/// ~a transposed version of the original function's control flow graph.
|
|
class PullbackCloner::Implementation final
|
|
: public SILInstructionVisitor<PullbackCloner::Implementation> {
|
|
|
|
public:
|
|
explicit Implementation(VJPCloner &vjpCloner);
|
|
|
|
private:
|
|
/// The parent VJP cloner.
|
|
VJPCloner &vjpCloner;
|
|
|
|
/// Dominance info for the original function.
|
|
DominanceInfo *domInfo = nullptr;
|
|
|
|
/// Post-dominance info for the original function.
|
|
PostDominanceInfo *postDomInfo = nullptr;
|
|
|
|
/// Post-order info for the original function.
|
|
PostOrderFunctionInfo *postOrderInfo = nullptr;
|
|
|
|
/// Mapping from original basic blocks to corresponding pullback basic blocks.
|
|
/// Pullback basic blocks always have the predecessor as the single argument.
|
|
llvm::DenseMap<SILBasicBlock *, SILBasicBlock *> pullbackBBMap;
|
|
|
|
/// Mapping from original basic blocks and original values to corresponding
|
|
/// adjoint values.
|
|
llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, AdjointValue> valueMap;
|
|
|
|
/// Mapping from original basic blocks and original values to corresponding
|
|
/// adjoint buffers.
|
|
llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap;
|
|
|
|
/// Mapping from pullback struct field declarations to pullback struct
|
|
/// elements destructured from the linear map basic block argument. In the
|
|
/// beginning of each pullback basic block, the block's pullback struct is
|
|
/// destructured into individual elements stored here.
|
|
llvm::DenseMap<SILBasicBlock*, SmallVector<SILValue, 4>> pullbackTupleElements;
|
|
|
|
/// Mapping from original basic blocks and successor basic blocks to
|
|
/// corresponding pullback trampoline basic blocks. Trampoline basic blocks
|
|
/// take additional arguments in addition to the predecessor enum argument.
|
|
llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, SILBasicBlock *>
|
|
pullbackTrampolineBBMap;
|
|
|
|
/// Mapping from original basic blocks to dominated active values.
|
|
llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues;
|
|
|
|
/// Mapping from original basic blocks and original active values to
|
|
/// corresponding pullback block arguments.
|
|
llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *>
|
|
activeValuePullbackBBArgumentMap;
|
|
|
|
/// Mapping from original basic blocks to local temporary values to be cleaned
|
|
/// up. This is populated when pullback emission is run on one basic block and
|
|
/// cleaned before processing another basic block.
|
|
llvm::DenseMap<SILBasicBlock *, llvm::SmallSetVector<SILValue, 32>>
|
|
blockTemporaries;
|
|
|
|
/// The scope cloner.
|
|
ScopeCloner scopeCloner;
|
|
|
|
/// The main builder.
|
|
TangentBuilder builder;
|
|
|
|
/// An auxiliary local allocation builder.
|
|
TangentBuilder localAllocBuilder;
|
|
|
|
/// The original function's exit block.
|
|
SILBasicBlock *originalExitBlock = nullptr;
|
|
|
|
/// Stack buffers allocated for storing local adjoint values.
|
|
SmallVector<AllocStackInst *, 64> functionLocalAllocations;
|
|
|
|
/// Copies created to deal with destructive enum operations
|
|
/// (unchecked_take_enum_addr)
|
|
llvm::SmallDenseMap<InitEnumDataAddrInst*, SILValue> enumDataAdjCopies;
|
|
|
|
/// A set used to remember local allocations that were destroyed.
|
|
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;
|
|
|
|
/// The seed arguments of the pullback function.
|
|
SmallVector<SILArgument *, 4> seeds;
|
|
|
|
/// The `AutoDiffLinearMapContext` object, if any.
|
|
SILValue contextValue = nullptr;
|
|
|
|
llvm::BumpPtrAllocator allocator;
|
|
|
|
bool errorOccurred = false;
|
|
|
|
ADContext &getContext() const { return vjpCloner.getContext(); }
|
|
SILModule &getModule() const { return getContext().getModule(); }
|
|
ASTContext &getASTContext() const { return getPullback().getASTContext(); }
|
|
SILFunction &getOriginal() const { return vjpCloner.getOriginal(); }
|
|
SILDifferentiabilityWitness *getWitness() const {
|
|
return vjpCloner.getWitness();
|
|
}
|
|
DifferentiationInvoker getInvoker() const { return vjpCloner.getInvoker(); }
|
|
LinearMapInfo &getPullbackInfo() const { return vjpCloner.getPullbackInfo(); }
|
|
const AutoDiffConfig &getConfig() const { return vjpCloner.getConfig(); }
|
|
const DifferentiableActivityInfo &getActivityInfo() const {
|
|
return vjpCloner.getActivityInfo();
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Pullback struct mapping
|
|
//--------------------------------------------------------------------------//
|
|
|
|
void initializePullbackTupleElements(SILBasicBlock *origBB,
|
|
SILInstructionResultArray values) {
|
|
auto *pbTupleTyple = getPullbackInfo().getLinearMapTupleType(origBB);
|
|
assert(pbTupleTyple->getNumElements() == values.size() &&
|
|
"The number of pullback tuple fields must equal the number of "
|
|
"pullback tuple element values");
|
|
auto res = pullbackTupleElements.insert({origBB, { values.begin(), values.end() }});
|
|
(void)res;
|
|
assert(res.second && "A pullback tuple element already exists!");
|
|
}
|
|
|
|
void initializePullbackTupleElements(SILBasicBlock *origBB,
|
|
const llvm::ArrayRef<SILArgument *> &values) {
|
|
auto *pbTupleTyple = getPullbackInfo().getLinearMapTupleType(origBB);
|
|
assert(pbTupleTyple->getNumElements() == values.size() &&
|
|
"The number of pullback tuple fields must equal the number of "
|
|
"pullback tuple element values");
|
|
auto res = pullbackTupleElements.insert({origBB, { values.begin(), values.end() }});
|
|
(void)res;
|
|
assert(res.second && "A pullback struct element already exists!");
|
|
}
|
|
|
|
/// Returns the pullback tuple element value corresponding to the given
|
|
/// original block and apply inst.
|
|
SILValue getPullbackTupleElement(FullApplySite fai) {
|
|
unsigned idx = getPullbackInfo().lookUpLinearMapIndex(fai);
|
|
assert((idx > 0 || (idx == 0 && fai.getParent()->isEntry())) &&
|
|
"impossible linear map index");
|
|
auto values = pullbackTupleElements.lookup(fai.getParent());
|
|
assert(idx < values.size() &&
|
|
"pullback tuple element for this apply does not exist!");
|
|
return values[idx];
|
|
}
|
|
|
|
/// Returns the pullback tuple element value corresponding to the predecessor
|
|
/// for the given original block.
|
|
SILValue getPullbackPredTupleElement(SILBasicBlock *origBB) {
|
|
assert(!origBB->isEntry() && "no predecessors for entry block");
|
|
auto values = pullbackTupleElements.lookup(origBB);
|
|
assert(values.size() && "pullback tuple cannot be empty");
|
|
return values[0];
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Type transformer
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Get the type lowering for the given AST type.
|
|
///
|
|
/// Explicitly use minimal type expansion context: in general, differentiation
|
|
/// happens on function types, so it cannot know if the original function is
|
|
/// resilient or not.
|
|
const Lowering::TypeLowering &getTypeLowering(Type type) {
|
|
auto pbGenSig =
|
|
getPullback().getLoweredFunctionType()->getSubstGenericSignature();
|
|
Lowering::AbstractionPattern pattern(pbGenSig,
|
|
type->getReducedType(pbGenSig));
|
|
return getContext().getTypeConverter().getTypeLowering(
|
|
pattern, type, TypeExpansionContext::minimal());
|
|
}
|
|
|
|
/// Remap any archetypes into the current function's context.
|
|
SILType remapType(SILType ty) {
|
|
if (ty.hasArchetype())
|
|
ty = ty.mapTypeOutOfContext();
|
|
auto remappedType = ty.getASTType()->getReducedType(
|
|
getPullback().getLoweredFunctionType()->getSubstGenericSignature());
|
|
auto remappedSILType =
|
|
SILType::getPrimitiveType(remappedType, ty.getCategory());
|
|
// FIXME: Sometimes getPullback() doesn't have a generic environment, in which
|
|
// case callers are apparently happy to receive an interface type.
|
|
if (getPullback().getGenericEnvironment())
|
|
return getPullback().mapTypeIntoContext(remappedSILType);
|
|
return remappedSILType;
|
|
}
|
|
|
|
std::optional<TangentSpace> getTangentSpace(CanType type) {
|
|
// Use witness generic signature to remap types.
|
|
type =
|
|
getWitness()->getDerivativeGenericSignature().getReducedType(
|
|
type);
|
|
return type->getAutoDiffTangentSpace(
|
|
LookUpConformanceInModule());
|
|
}
|
|
|
|
/// Returns the tangent value category of the given value.
|
|
SILValueCategory getTangentValueCategory(SILValue v) {
|
|
// Tangent value category table:
|
|
//
|
|
// Let $L be a loadable type and $*A be an address-only type.
|
|
//
|
|
// Original type | Tangent type loadable? | Tangent value category and type
|
|
// --------------|------------------------|--------------------------------
|
|
// $L | loadable | object, $L' (no mismatch)
|
|
// $*A | loadable | address, $*L' (create a buffer)
|
|
// $L | address-only | address, $*A' (no alternative)
|
|
// $*A | address-only | address, $*A' (no alternative)
|
|
|
|
// TODO(https://github.com/apple/swift/issues/55523): Make "tangent value category" depend solely on whether the tangent type is loadable or address-only.
|
|
//
|
|
// For loadable tangent types, using symbolic adjoint values instead of
|
|
// concrete adjoint buffers is more efficient.
|
|
|
|
// Quick check: if the value has an address type, the tangent value category
|
|
// is currently always "address".
|
|
if (v->getType().isAddress())
|
|
return SILValueCategory::Address;
|
|
// If the value has an object type and the tangent type is not address-only,
|
|
// then the tangent value category is "object".
|
|
auto tanSpace = getTangentSpace(remapType(v->getType()).getASTType());
|
|
auto tanASTType = tanSpace->getCanonicalType();
|
|
if (v->getType().isObject() && getTypeLowering(tanASTType).isLoadable())
|
|
return SILValueCategory::Object;
|
|
// Otherwise, the tangent value category is "address".
|
|
return SILValueCategory::Address;
|
|
}
|
|
|
|
/// Assuming the given type conforms to `Differentiable` after remapping,
|
|
/// returns the associated tangent space type.
|
|
SILType getRemappedTangentType(SILType type) {
|
|
return SILType::getPrimitiveType(
|
|
getTangentSpace(remapType(type).getASTType())->getCanonicalType(),
|
|
type.getCategory());
|
|
}
|
|
|
|
/// Substitutes all replacement types of the given substitution map using the
|
|
/// pullback function's substitution map.
|
|
SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) {
|
|
return substMap.subst(getPullback().getForwardingSubstitutionMap());
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Temporary value management
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Record a temporary value for cleanup before its block's terminator.
|
|
SILValue recordTemporary(SILValue value) {
|
|
assert(value->getType().isObject());
|
|
assert(value->getFunction() == &getPullback());
|
|
auto inserted = blockTemporaries[value->getParentBlock()].insert(value);
|
|
(void)inserted;
|
|
LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value);
|
|
assert(inserted && "Temporary already recorded?");
|
|
return value;
|
|
}
|
|
|
|
/// Clean up all temporary values for the given pullback block.
|
|
void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc) {
|
|
assert(bb->getParent() == &getPullback());
|
|
LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for pullback bb"
|
|
<< bb->getDebugID() << '\n');
|
|
for (auto temp : blockTemporaries[bb])
|
|
builder.emitDestroyValueOperation(loc, temp);
|
|
blockTemporaries[bb].clear();
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint value factory methods
|
|
//--------------------------------------------------------------------------//
|
|
|
|
AdjointValue makeZeroAdjointValue(SILType type) {
|
|
return AdjointValue::createZero(allocator, remapType(type));
|
|
}
|
|
|
|
AdjointValue makeConcreteAdjointValue(SILValue value) {
|
|
return AdjointValue::createConcrete(allocator, value);
|
|
}
|
|
|
|
AdjointValue makeAggregateAdjointValue(SILType type,
|
|
ArrayRef<AdjointValue> elements) {
|
|
return AdjointValue::createAggregate(allocator, remapType(type), elements);
|
|
}
|
|
|
|
AdjointValue makeAddElementAdjointValue(AdjointValue baseAdjoint,
|
|
AdjointValue eltToAdd,
|
|
FieldLocator fieldLocator) {
|
|
auto *addElementValue =
|
|
new AddElementValue(baseAdjoint, eltToAdd, fieldLocator);
|
|
return AdjointValue::createAddElement(allocator, baseAdjoint.getType(),
|
|
addElementValue);
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint value materialization
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Materializes an adjoint value. The type of the given adjoint value must be
|
|
/// loadable.
|
|
SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc) {
|
|
assert(val.getType().isObject());
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Materializing adjoint for " << val << '\n');
|
|
SILValue result;
|
|
switch (val.getKind()) {
|
|
case AdjointValueKind::Zero:
|
|
result = recordTemporary(builder.emitZero(loc, val.getSwiftType()));
|
|
break;
|
|
case AdjointValueKind::Aggregate: {
|
|
SmallVector<SILValue, 8> elements;
|
|
for (auto i : range(val.getNumAggregateElements())) {
|
|
auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
|
|
elements.push_back(builder.emitCopyValueOperation(loc, eltVal));
|
|
}
|
|
if (val.getType().is<TupleType>())
|
|
result = recordTemporary(
|
|
builder.createTuple(loc, val.getType(), elements));
|
|
else
|
|
result = recordTemporary(
|
|
builder.createStruct(loc, val.getType(), elements));
|
|
break;
|
|
}
|
|
case AdjointValueKind::Concrete:
|
|
result = val.getConcreteValue();
|
|
break;
|
|
case AdjointValueKind::AddElement: {
|
|
auto adjointSILType = val.getAddElementValue()->baseAdjoint.getType();
|
|
auto *baseAdjAlloc = builder.createAllocStack(loc, adjointSILType);
|
|
materializeAdjointIndirect(val, baseAdjAlloc, loc);
|
|
|
|
auto baseAdjConcrete = recordTemporary(builder.emitLoadValueOperation(
|
|
loc, baseAdjAlloc, LoadOwnershipQualifier::Take));
|
|
|
|
builder.createDeallocStack(loc, baseAdjAlloc);
|
|
|
|
result = baseAdjConcrete;
|
|
break;
|
|
}
|
|
}
|
|
if (auto debugInfo = val.getDebugInfo())
|
|
builder.createDebugValue(
|
|
debugInfo->first.getLocation(), result, debugInfo->second);
|
|
return result;
|
|
}
|
|
|
|
/// Materializes an adjoint value indirectly to a SIL buffer.
|
|
void materializeAdjointIndirect(AdjointValue val, SILValue destAddress,
|
|
SILLocation loc) {
|
|
assert(destAddress->getType().isAddress());
|
|
switch (val.getKind()) {
|
|
/// If adjoint value is a symbolic zero, emit a call to
|
|
/// `AdditiveArithmetic.zero`.
|
|
case AdjointValueKind::Zero:
|
|
builder.emitZeroIntoBuffer(loc, destAddress, IsInitialization);
|
|
break;
|
|
/// If adjoint value is a symbolic aggregate (tuple or struct), recursively
|
|
/// materialize the symbolic tuple or struct, filling the
|
|
/// buffer.
|
|
case AdjointValueKind::Aggregate: {
|
|
if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
|
|
for (auto idx : range(val.getNumAggregateElements())) {
|
|
auto eltTy = SILType::getPrimitiveAddressType(
|
|
tupTy->getElementType(idx)->getCanonicalType());
|
|
auto *eltBuf =
|
|
builder.createTupleElementAddr(loc, destAddress, idx, eltTy);
|
|
materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc);
|
|
}
|
|
} else if (auto *structDecl =
|
|
val.getSwiftType()->getStructOrBoundGenericStruct()) {
|
|
auto fieldIt = structDecl->getStoredProperties().begin();
|
|
for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end();
|
|
++fieldIt, ++i) {
|
|
auto eltBuf =
|
|
builder.createStructElementAddr(loc, destAddress, *fieldIt);
|
|
materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc);
|
|
}
|
|
} else {
|
|
llvm_unreachable("Not an aggregate type");
|
|
}
|
|
break;
|
|
}
|
|
/// If adjoint value is concrete, it is already materialized. Store it in
|
|
/// the destination address.
|
|
case AdjointValueKind::Concrete: {
|
|
auto concreteVal = val.getConcreteValue();
|
|
auto copyOfConcreteVal = builder.emitCopyValueOperation(loc, concreteVal);
|
|
builder.emitStoreValueOperation(loc, copyOfConcreteVal, destAddress,
|
|
StoreOwnershipQualifier::Init);
|
|
break;
|
|
}
|
|
case AdjointValueKind::AddElement: {
|
|
auto baseAdjoint = val;
|
|
auto baseAdjointType = baseAdjoint.getType();
|
|
|
|
// Current adjoint may be made up of layers of `AddElement` adjoints.
|
|
// We can iteratively gather the list of elements to add instead of making
|
|
// recursive calls to `materializeAdjointIndirect`.
|
|
SmallVector<AddElementValue *, 4> addEltAdjValues;
|
|
|
|
do {
|
|
auto addElementValue = baseAdjoint.getAddElementValue();
|
|
addEltAdjValues.push_back(addElementValue);
|
|
baseAdjoint = addElementValue->baseAdjoint;
|
|
assert(baseAdjointType == baseAdjoint.getType());
|
|
} while (baseAdjoint.getKind() == AdjointValueKind::AddElement);
|
|
|
|
materializeAdjointIndirect(baseAdjoint, destAddress, loc);
|
|
|
|
for (auto *addElementValue : addEltAdjValues) {
|
|
auto eltToAdd = addElementValue->eltToAdd;
|
|
|
|
SILValue baseAdjEltAddr;
|
|
if (baseAdjoint.getType().is<TupleType>()) {
|
|
baseAdjEltAddr = builder.createTupleElementAddr(
|
|
loc, destAddress, addElementValue->getFieldIndex());
|
|
} else {
|
|
baseAdjEltAddr = builder.createStructElementAddr(
|
|
loc, destAddress, addElementValue->getFieldDecl());
|
|
}
|
|
|
|
auto eltToAddMaterialized = materializeAdjointDirect(eltToAdd, loc);
|
|
// Copy `eltToAddMaterialized` so we have a value with owned ownership
|
|
// semantics, required for using `eltToAddMaterialized` in a `store`
|
|
// instruction.
|
|
auto eltToAddMaterializedCopy =
|
|
builder.emitCopyValueOperation(loc, eltToAddMaterialized);
|
|
auto *eltToAddAlloc = builder.createAllocStack(loc, eltToAdd.getType());
|
|
builder.emitStoreValueOperation(loc, eltToAddMaterializedCopy,
|
|
eltToAddAlloc,
|
|
StoreOwnershipQualifier::Init);
|
|
|
|
builder.emitInPlaceAdd(loc, baseAdjEltAddr, eltToAddAlloc);
|
|
builder.createDestroyAddr(loc, eltToAddAlloc);
|
|
builder.createDeallocStack(loc, eltToAddAlloc);
|
|
}
|
|
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint value mapping
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Returns true if the given value in the original function has a
|
|
/// corresponding adjoint value.
|
|
bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const {
|
|
assert(origBB->getParent() == &getOriginal());
|
|
assert(originalValue->getType().isObject());
|
|
return valueMap.count({origBB, originalValue});
|
|
}
|
|
|
|
/// Initializes the adjoint value for the original value. Asserts that the
|
|
/// original value does not already have an adjoint value.
|
|
void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
|
|
AdjointValue adjointValue) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Setting adjoint value for " << originalValue);
|
|
assert(origBB->getParent() == &getOriginal());
|
|
assert(originalValue->getType().isObject());
|
|
assert(getTangentValueCategory(originalValue) == SILValueCategory::Object);
|
|
assert(adjointValue.getType().isObject());
|
|
assert(originalValue->getFunction() == &getOriginal());
|
|
// The adjoint value must be in the tangent space.
|
|
assert(adjointValue.getType() ==
|
|
getRemappedTangentType(originalValue->getType()));
|
|
// Try to assign a debug variable.
|
|
if (auto debugInfo = findDebugLocationAndVariable(originalValue)) {
|
|
LLVM_DEBUG({
|
|
auto &s = getADDebugStream();
|
|
s << "Found debug variable: \"" << debugInfo->second.Name
|
|
<< "\"\nLocation: ";
|
|
debugInfo->first.getLocation().print(s, getASTContext().SourceMgr);
|
|
s << '\n';
|
|
});
|
|
adjointValue.setDebugInfo(*debugInfo);
|
|
} else {
|
|
LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n");
|
|
}
|
|
// Insert into dictionary.
|
|
auto insertion =
|
|
valueMap.try_emplace({origBB, originalValue}, adjointValue);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "The new adjoint value, replacing the existing one, is: "
|
|
<< insertion.first->getSecond() << '\n');
|
|
if (!insertion.second)
|
|
insertion.first->getSecond() = adjointValue;
|
|
}
|
|
|
|
/// Returns the adjoint value for a value in the original function.
|
|
///
|
|
/// This method first tries to find an existing entry in the adjoint value
|
|
/// mapping. If no entry exists, creates a zero adjoint value.
|
|
AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue) {
|
|
assert(origBB->getParent() == &getOriginal());
|
|
assert(originalValue->getType().isObject());
|
|
assert(getTangentValueCategory(originalValue) == SILValueCategory::Object);
|
|
assert(originalValue->getFunction() == &getOriginal());
|
|
auto insertion = valueMap.try_emplace(
|
|
{origBB, originalValue},
|
|
makeZeroAdjointValue(getRemappedTangentType(originalValue->getType())));
|
|
auto it = insertion.first;
|
|
return it->getSecond();
|
|
}
|
|
|
|
/// Adds `newAdjointValue` to the adjoint value for `originalValue` and sets
|
|
/// the sum as the new adjoint value.
|
|
void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
|
|
AdjointValue newAdjointValue, SILLocation loc) {
|
|
assert(origBB->getParent() == &getOriginal());
|
|
assert(originalValue->getType().isObject());
|
|
assert(newAdjointValue.getType().isObject());
|
|
assert(originalValue->getFunction() == &getOriginal());
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Adding adjoint value for " << originalValue);
|
|
// The adjoint value must be in the tangent space.
|
|
assert(newAdjointValue.getType() ==
|
|
getRemappedTangentType(originalValue->getType()));
|
|
// Try to assign a debug variable.
|
|
if (auto debugInfo = findDebugLocationAndVariable(originalValue)) {
|
|
LLVM_DEBUG({
|
|
auto &s = getADDebugStream();
|
|
s << "Found debug variable: \"" << debugInfo->second.Name
|
|
<< "\"\nLocation: ";
|
|
debugInfo->first.getLocation().print(s, getASTContext().SourceMgr);
|
|
s << '\n';
|
|
});
|
|
newAdjointValue.setDebugInfo(*debugInfo);
|
|
} else {
|
|
LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n");
|
|
}
|
|
auto insertion =
|
|
valueMap.try_emplace({origBB, originalValue}, newAdjointValue);
|
|
auto inserted = insertion.second;
|
|
if (inserted)
|
|
return;
|
|
// If adjoint already exists, accumulate the adjoint onto the existing
|
|
// adjoint.
|
|
auto it = insertion.first;
|
|
auto existingValue = it->getSecond();
|
|
valueMap.erase(it);
|
|
auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue, loc);
|
|
// If the original value is the `Array` result of an
|
|
// `array.uninitialized_intrinsic` application, accumulate adjoint buffers
|
|
// for the array element addresses.
|
|
accumulateArrayLiteralElementAddressAdjoints(origBB, originalValue, adjVal,
|
|
loc);
|
|
setAdjointValue(origBB, originalValue, adjVal);
|
|
}
|
|
|
|
/// Get the pullback block argument corresponding to the given original block
|
|
/// and active value.
|
|
SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB,
|
|
SILValue activeValue) {
|
|
assert(getTangentValueCategory(activeValue) == SILValueCategory::Object);
|
|
assert(origBB->getParent() == &getOriginal());
|
|
auto pullbackBBArg =
|
|
activeValuePullbackBBArgumentMap[{origBB, activeValue}];
|
|
assert(pullbackBBArg);
|
|
assert(pullbackBBArg->getParent() == getPullbackBlock(origBB));
|
|
return pullbackBBArg;
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint value accumulation
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Given two adjoint values, accumulates them and returns their sum.
|
|
AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs,
|
|
SILLocation loc);
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint buffer mapping
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// If the given original value is an address projection, returns a
|
|
/// corresponding adjoint projection to be used as its adjoint buffer.
|
|
///
|
|
/// Helper function for `getAdjointBuffer`.
|
|
SILValue getAdjointProjection(SILBasicBlock *origBB, SILValue originalValue);
|
|
|
|
/// Returns the adjoint buffer for the original value.
|
|
///
|
|
/// This method first tries to find an existing entry in the adjoint buffer
|
|
/// mapping. If no entry exists, creates a zero adjoint buffer.
|
|
SILValue getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue) {
|
|
assert(getTangentValueCategory(originalValue) == SILValueCategory::Address);
|
|
assert(originalValue->getFunction() == &getOriginal());
|
|
auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue());
|
|
if (!insertion.second) // not inserted
|
|
return insertion.first->getSecond();
|
|
|
|
// If the original buffer is a projection, return a corresponding projection
|
|
// into the adjoint buffer.
|
|
if (auto adjProj = getAdjointProjection(origBB, originalValue))
|
|
return (bufferMap[{origBB, originalValue}] = adjProj);
|
|
|
|
LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for "
|
|
<< originalValue
|
|
<< "in bb" << origBB->getDebugID() << '\n');
|
|
|
|
auto bufType = getRemappedTangentType(originalValue->getType());
|
|
// Set insertion point for local allocation builder: before the last local
|
|
// allocation, or at the start of the pullback function's entry if no local
|
|
// allocations exist yet.
|
|
auto debugInfo = findDebugLocationAndVariable(originalValue);
|
|
SILLocation loc = debugInfo ? debugInfo->first.getLocation()
|
|
: RegularLocation::getAutoGeneratedLocation();
|
|
llvm::SmallString<32> adjName;
|
|
auto *newBuf = createFunctionLocalAllocation(
|
|
bufType, loc, /*zeroInitialize*/ true,
|
|
swift::transform(debugInfo,
|
|
[&](AdjointValue::DebugInfo di) {
|
|
llvm::raw_svector_ostream adjNameStream(adjName);
|
|
SILDebugVariable &dv = di.second;
|
|
dv.ArgNo = 0;
|
|
adjNameStream << "derivative of '" << dv.Name << "'";
|
|
if (SILDebugLocation origBBLoc = origBB->front().getDebugLocation()) {
|
|
adjNameStream << " in scope at ";
|
|
origBBLoc.getLocation().print(adjNameStream, getASTContext().SourceMgr);
|
|
}
|
|
adjNameStream << " (scope #" << origBB->getDebugID() << ")";
|
|
dv.Name = adjName;
|
|
// We have no meaningful debug location, and the type is different.
|
|
dv.Scope = nullptr;
|
|
dv.Loc = {};
|
|
dv.Type = {};
|
|
dv.DIExpr = {};
|
|
return dv;
|
|
}));
|
|
return (insertion.first->getSecond() = newBuf);
|
|
}
|
|
|
|
/// Initializes the adjoint buffer for the original value. Asserts that the
|
|
/// original value does not already have an adjoint buffer.
|
|
void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue,
|
|
SILValue adjointBuffer) {
|
|
assert(getTangentValueCategory(originalValue) == SILValueCategory::Address);
|
|
auto insertion =
|
|
bufferMap.try_emplace({origBB, originalValue}, adjointBuffer);
|
|
assert(insertion.second && "Adjoint buffer already exists");
|
|
(void)insertion;
|
|
}
|
|
|
|
/// Accumulates `rhsAddress` into the adjoint buffer corresponding to the
|
|
/// original value.
|
|
void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue,
|
|
SILValue rhsAddress, SILLocation loc) {
|
|
assert(getTangentValueCategory(originalValue) ==
|
|
SILValueCategory::Address &&
|
|
rhsAddress->getType().isAddress());
|
|
assert(originalValue->getFunction() == &getOriginal());
|
|
assert(rhsAddress->getFunction() == &getPullback());
|
|
auto adjointBuffer = getAdjointBuffer(origBB, originalValue);
|
|
|
|
LLVM_DEBUG(getADDebugStream() << "Adding"
|
|
<< rhsAddress << "to adjoint ("
|
|
<< adjointBuffer << ") of "
|
|
<< originalValue
|
|
<< "in bb" << origBB->getDebugID() << '\n');
|
|
|
|
builder.emitInPlaceAdd(loc, adjointBuffer, rhsAddress);
|
|
}
|
|
|
|
/// Returns a next insertion point for creating a local allocation: either
|
|
/// before the previous local allocation, or at the start of the pullback
|
|
/// entry if no local allocations exist.
|
|
///
|
|
/// Helper for `createFunctionLocalAllocation`.
|
|
SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint() {
|
|
// If there are no local allocations, insert at the pullback entry start.
|
|
if (functionLocalAllocations.empty())
|
|
return getPullback().getEntryBlock()->begin();
|
|
// Otherwise, insert before the last local allocation. Inserting before
|
|
// rather than after ensures that allocation and zero initialization
|
|
// instructions are grouped together.
|
|
auto lastLocalAlloc = functionLocalAllocations.back();
|
|
return lastLocalAlloc->getDefiningInstruction()->getIterator();
|
|
}
|
|
|
|
/// Creates and returns a local allocation with the given type.
|
|
///
|
|
/// Local allocations are created uninitialized in the pullback entry and
|
|
/// deallocated in the pullback exit. All local allocations not in
|
|
/// `destroyedLocalAllocations` are also destroyed in the pullback exit.
|
|
///
|
|
/// Helper for `getAdjointBuffer`.
|
|
AllocStackInst *createFunctionLocalAllocation(
|
|
SILType type, SILLocation loc, bool zeroInitialize = false,
|
|
std::optional<SILDebugVariable> varInfo = std::nullopt) {
|
|
// Set insertion point for local allocation builder: before the last local
|
|
// allocation, or at the start of the pullback function's entry if no local
|
|
// allocations exist yet.
|
|
localAllocBuilder.setInsertionPoint(
|
|
getPullback().getEntryBlock(),
|
|
getNextFunctionLocalAllocationInsertionPoint());
|
|
// Create and return local allocation.
|
|
auto *alloc = localAllocBuilder.createAllocStack(loc, type, varInfo);
|
|
functionLocalAllocations.push_back(alloc);
|
|
// Zero-initialize if requested.
|
|
if (zeroInitialize)
|
|
localAllocBuilder.emitZeroIntoBuffer(loc, alloc, IsInitialization);
|
|
return alloc;
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Optional differentiation
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Given a `wrappedAdjoint` value of type `T.TangentVector` and `Optional<T>`
|
|
/// type, creates an `Optional<T>.TangentVector` buffer from it.
|
|
///
|
|
/// `wrappedAdjoint` may be an object or address value, both cases are
|
|
/// handled.
|
|
AllocStackInst *createOptionalAdjoint(SILBasicBlock *bb,
|
|
SILValue wrappedAdjoint,
|
|
SILType optionalTy);
|
|
|
|
/// Accumulate adjoint of `wrappedAdjoint` into optionalBuffer.
|
|
void accumulateAdjointForOptionalBuffer(SILBasicBlock *bb,
|
|
SILValue optionalBuffer,
|
|
SILValue wrappedAdjoint);
|
|
|
|
/// Accumulate adjoint of `wrappedAdjoint` into optionalValue.
|
|
void accumulateAdjointValueForOptional(SILBasicBlock *bb,
|
|
SILValue optionalValue,
|
|
SILValue wrappedAdjoint);
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Array literal initialization differentiation
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Given the adjoint value of an array initialized from an
|
|
/// `array.uninitialized_intrinsic` application and an array element index,
|
|
/// returns an `alloc_stack` containing the adjoint value of the array element
|
|
/// at the given index by applying `Array.TangentVector.subscript`.
|
|
AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint,
|
|
int eltIndex, SILLocation loc);
|
|
|
|
/// Given the adjoint value of an array initialized from an
|
|
/// `array.uninitialized_intrinsic` application, accumulates the adjoint
|
|
/// value's elements into the adjoint buffers of its element addresses.
|
|
void accumulateArrayLiteralElementAddressAdjoints(
|
|
SILBasicBlock *origBB, SILValue originalValue,
|
|
AdjointValue arrayAdjointValue, SILLocation loc);
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// CFG mapping
|
|
//--------------------------------------------------------------------------//
|
|
|
|
SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) {
|
|
return pullbackBBMap.lookup(originalBlock);
|
|
}
|
|
|
|
SILBasicBlock *getPullbackTrampolineBlock(SILBasicBlock *originalBlock,
|
|
SILBasicBlock *successorBlock) {
|
|
return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock});
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Debug info
|
|
//--------------------------------------------------------------------------//
|
|
|
|
const SILDebugScope *remapScope(const SILDebugScope *DS) {
|
|
return scopeCloner.getOrCreateClonedScope(DS);
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Debugging utilities
|
|
//--------------------------------------------------------------------------//
|
|
|
|
void printAdjointValueMapping() {
|
|
// Group original/adjoint values by basic block.
|
|
llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, AdjointValue>> tmp;
|
|
for (auto pair : valueMap) {
|
|
auto origPair = pair.first;
|
|
auto *origBB = origPair.first;
|
|
auto origValue = origPair.second;
|
|
auto adjValue = pair.second;
|
|
tmp[origBB].insert({origValue, adjValue});
|
|
}
|
|
// Print original/adjoint values per basic block.
|
|
auto &s = getADDebugStream() << "Adjoint value mapping:\n";
|
|
for (auto &origBB : getOriginal()) {
|
|
if (!pullbackBBMap.count(&origBB))
|
|
continue;
|
|
auto bbValueMap = tmp[&origBB];
|
|
s << "bb" << origBB.getDebugID();
|
|
s << " (size " << bbValueMap.size() << "):\n";
|
|
for (auto valuePair : bbValueMap) {
|
|
auto origValue = valuePair.first;
|
|
auto adjValue = valuePair.second;
|
|
s << "ORIG: " << origValue;
|
|
s << "ADJ: " << adjValue << '\n';
|
|
}
|
|
s << '\n';
|
|
}
|
|
}
|
|
|
|
void printAdjointBufferMapping() {
|
|
// Group original/adjoint buffers by basic block.
|
|
llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, SILValue>> tmp;
|
|
for (auto pair : bufferMap) {
|
|
auto origPair = pair.first;
|
|
auto *origBB = origPair.first;
|
|
auto origBuf = origPair.second;
|
|
auto adjBuf = pair.second;
|
|
tmp[origBB][origBuf] = adjBuf;
|
|
}
|
|
// Print original/adjoint buffers per basic block.
|
|
auto &s = getADDebugStream() << "Adjoint buffer mapping:\n";
|
|
for (auto &origBB : getOriginal()) {
|
|
if (!pullbackBBMap.count(&origBB))
|
|
continue;
|
|
auto bbBufferMap = tmp[&origBB];
|
|
s << "bb" << origBB.getDebugID();
|
|
s << " (size " << bbBufferMap.size() << "):\n";
|
|
for (auto valuePair : bbBufferMap) {
|
|
auto origBuf = valuePair.first;
|
|
auto adjBuf = valuePair.second;
|
|
s << "ORIG: " << origBuf;
|
|
s << "ADJ: " << adjBuf << '\n';
|
|
}
|
|
s << '\n';
|
|
}
|
|
}
|
|
|
|
public:
|
|
//--------------------------------------------------------------------------//
|
|
// Entry point
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Performs pullback generation on the empty pullback function. Returns true
|
|
/// if any error occurs.
|
|
bool run();
|
|
|
|
/// Performs pullback generation on the empty pullback function, given that
|
|
/// the original function is a "semantic member accessor".
|
|
///
|
|
/// "Semantic member accessors" are attached to member properties that have a
|
|
/// corresponding tangent stored property in the parent `TangentVector` type.
|
|
/// These accessors have special-case pullback generation based on their
|
|
/// semantic behavior.
|
|
///
|
|
/// Returns true if any error occurs.
|
|
bool runForSemanticMemberAccessor();
|
|
bool runForSemanticMemberGetter();
|
|
bool runForSemanticMemberSetter();
|
|
|
|
/// If original result is non-varied, it will always have a zero derivative.
|
|
/// Skip full pullback generation and simply emit zero derivatives for wrt
|
|
/// parameters.
|
|
void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult);
|
|
|
|
/// Public helper so that our users can get the underlying newly created
|
|
/// function.
|
|
SILFunction &getPullback() const { return vjpCloner.getPullback(); }
|
|
|
|
using TrampolineBlockSet = SmallPtrSet<SILBasicBlock *, 4>;
|
|
|
|
/// Determines the pullback successor block for a given original block and one
|
|
/// of its predecessors. When a trampoline block is necessary, emits code into
|
|
/// the trampoline block to trampoline the original block's active value's
|
|
/// adjoint values.
|
|
///
|
|
/// Populates `pullbackTrampolineBlockMap`, which maps active values' adjoint
|
|
/// values to the pullback successor blocks in which they are used. This
|
|
/// allows us to release those values in pullback successor blocks that do not
|
|
/// use them.
|
|
SILBasicBlock *
|
|
buildPullbackSuccessor(SILBasicBlock *origBB, SILBasicBlock *origPredBB,
|
|
llvm::SmallDenseMap<SILValue, TrampolineBlockSet>
|
|
&pullbackTrampolineBlockMap);
|
|
|
|
/// Emits pullback code in the corresponding pullback block.
|
|
void visitSILBasicBlock(SILBasicBlock *bb);
|
|
|
|
void visit(SILInstruction *inst) {
|
|
if (errorOccurred)
|
|
return;
|
|
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "PullbackCloner visited:\n[ORIG]" << *inst);
|
|
#ifndef NDEBUG
|
|
auto beforeInsertion = std::prev(builder.getInsertionPoint());
|
|
#endif
|
|
SILInstructionVisitor::visit(inst);
|
|
LLVM_DEBUG({
|
|
auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback (pb bb" <<
|
|
builder.getInsertionBB()->getDebugID() << "):\n";
|
|
auto afterInsertion = builder.getInsertionPoint();
|
|
for (auto it = ++beforeInsertion; it != afterInsertion; ++it)
|
|
s << *it;
|
|
});
|
|
}
|
|
|
|
/// Fallback instruction visitor for unhandled instructions.
|
|
/// Emit a general non-differentiability diagnostic.
|
|
void visitSILInstruction(SILInstruction *inst) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Unhandled instruction in PullbackCloner: " << *inst);
|
|
getContext().emitNondifferentiabilityError(
|
|
inst, getInvoker(), diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
}
|
|
|
|
/// Handle `apply` instruction.
|
|
/// Original: (y0, y1, ...) = apply @fn (x0, x1, ...)
|
|
/// Adjoint: (adj[x0], adj[x1], ...) += apply @fn_pullback (adj[y0], ...)
|
|
void visitApplyInst(ApplyInst *ai) {
|
|
assert(getPullbackInfo().shouldDifferentiateApplySite(ai));
|
|
|
|
// Skip `array.uninitialized_intrinsic` applications, which have special
|
|
// `store` and `copy_addr` support.
|
|
if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC))
|
|
return;
|
|
|
|
auto loc = ai->getLoc();
|
|
auto *bb = ai->getParent();
|
|
// Handle `array.finalize_intrinsic` applications.
|
|
// `array.finalize_intrinsic` semantically behaves like an identity
|
|
// function.
|
|
if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) {
|
|
assert(ai->getNumArguments() == 1 &&
|
|
"Expected intrinsic to have one operand");
|
|
// Accumulate result's adjoint into argument's adjoint.
|
|
auto adjResult = getAdjointValue(bb, ai);
|
|
auto origArg = ai->getArgumentsWithoutIndirectResults().front();
|
|
addAdjointValue(bb, origArg, adjResult, loc);
|
|
return;
|
|
}
|
|
|
|
buildPullbackCall(ai);
|
|
}
|
|
|
|
void buildPullbackCall(FullApplySite fai) {
|
|
auto loc = fai->getLoc();
|
|
auto *bb = fai->getParent();
|
|
|
|
// Replace a call to a function with a call to its pullback.
|
|
auto &nestedApplyInfo = getContext().getNestedApplyInfo();
|
|
auto applyInfoLookup = nestedApplyInfo.find(fai);
|
|
// If no `NestedApplyInfo` was found, then this task doesn't need to be
|
|
// differentiated.
|
|
if (applyInfoLookup == nestedApplyInfo.end()) {
|
|
// Must not be active.
|
|
// TODO: Do we need to check token result for begin_apply?
|
|
SILValue result = fai.getResult();
|
|
assert(!result || !getActivityInfo().isActive(result, getConfig()));
|
|
return;
|
|
}
|
|
auto &applyInfo = applyInfoLookup->getSecond();
|
|
|
|
// Get the original result of the `apply` instruction.
|
|
const auto &conv = fai.getSubstCalleeConv();
|
|
SmallVector<SILValue, 8> origDirectResults;
|
|
forEachApplyDirectResult(fai, [&](SILValue directResult) {
|
|
origDirectResults.push_back(directResult);
|
|
});
|
|
SmallVector<SILValue, 8> origAllResults;
|
|
collectAllActualResultsInTypeOrder(fai, origDirectResults, origAllResults);
|
|
// Append semantic result arguments after original results.
|
|
for (auto paramIdx : applyInfo.config.parameterIndices->getIndices()) {
|
|
unsigned argIdx = fai.getNumIndirectSILResults() + paramIdx;
|
|
auto paramInfo = conv.getParamInfoForSILArg(argIdx);
|
|
if (!paramInfo.isAutoDiffSemanticResult())
|
|
continue;
|
|
origAllResults.push_back(
|
|
fai.getArgumentsWithoutIndirectResults()[paramIdx]);
|
|
}
|
|
|
|
// Get callee pullback arguments.
|
|
SmallVector<SILValue, 8> args;
|
|
|
|
// Handle callee pullback indirect results.
|
|
// Create local allocations for these and destroy them after the call.
|
|
auto pullback = getPullbackTupleElement(fai);
|
|
auto pullbackType =
|
|
remapType(pullback->getType()).castTo<SILFunctionType>();
|
|
|
|
auto actualPullbackType = applyInfo.originalPullbackType
|
|
? *applyInfo.originalPullbackType
|
|
: pullbackType;
|
|
actualPullbackType = actualPullbackType->getUnsubstitutedType(getModule());
|
|
SmallVector<AllocStackInst *, 4> pullbackIndirectResults;
|
|
for (auto indRes : actualPullbackType->getIndirectFormalResults()) {
|
|
auto *alloc = builder.createAllocStack(
|
|
loc, remapType(indRes.getSILStorageInterfaceType()));
|
|
pullbackIndirectResults.push_back(alloc);
|
|
args.push_back(alloc);
|
|
}
|
|
|
|
// Collect callee pullback formal arguments.
|
|
unsigned firstSemanticParamResultIdx = conv.getResults().size();
|
|
unsigned firstYieldResultIndex = firstSemanticParamResultIdx +
|
|
conv.getNumAutoDiffSemanticResultParameters();
|
|
for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) {
|
|
if (resultIndex >= firstYieldResultIndex)
|
|
continue;
|
|
assert(resultIndex < origAllResults.size());
|
|
auto origResult = origAllResults[resultIndex];
|
|
// Get the seed (i.e. adjoint value of the original result).
|
|
SILValue seed;
|
|
switch (getTangentValueCategory(origResult)) {
|
|
case SILValueCategory::Object:
|
|
seed = materializeAdjointDirect(getAdjointValue(bb, origResult), loc);
|
|
break;
|
|
case SILValueCategory::Address:
|
|
seed = getAdjointBuffer(bb, origResult);
|
|
break;
|
|
}
|
|
args.push_back(seed);
|
|
}
|
|
|
|
// If callee pullback was reabstracted in VJP, reabstract callee pullback.
|
|
if (applyInfo.originalPullbackType) {
|
|
auto toType = *applyInfo.originalPullbackType;
|
|
SILOptFunctionBuilder fb(getContext().getTransform());
|
|
if (toType->isCoroutine())
|
|
pullback = reabstractCoroutine(
|
|
builder, fb, loc, pullback, toType,
|
|
[this](SubstitutionMap subs) -> SubstitutionMap {
|
|
return this->remapSubstitutionMap(subs);
|
|
});
|
|
else
|
|
pullback = reabstractFunction(
|
|
builder, fb, loc, pullback, toType,
|
|
[this](SubstitutionMap subs) -> SubstitutionMap {
|
|
return this->remapSubstitutionMap(subs);
|
|
});
|
|
}
|
|
|
|
// Call the callee pullback.
|
|
FullApplySite pullbackCall;
|
|
SmallVector<SILValue, 8> dirResults;
|
|
if (actualPullbackType->isCoroutine()) {
|
|
pullbackCall = builder.createBeginApply(loc, pullback, SubstitutionMap(),
|
|
args);
|
|
// Record pullback and begin_apply token: the pullback will be consumed
|
|
// after end_apply.
|
|
applyInfo.pullback = pullback;
|
|
applyInfo.beginApplyToken = cast<BeginApplyInst>(pullbackCall)->getTokenResult();
|
|
} else {
|
|
pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
|
|
args);
|
|
builder.emitDestroyValueOperation(loc, pullback);
|
|
// Extract all results from `pullbackCall`.
|
|
extractAllElements(cast<ApplyInst>(pullbackCall), builder, dirResults);
|
|
}
|
|
|
|
// Get all results in type-defined order.
|
|
SmallVector<SILValue, 8> allResults;
|
|
collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults);
|
|
|
|
LLVM_DEBUG({
|
|
auto &s = getADDebugStream();
|
|
s << "All results of the nested pullback call:\n";
|
|
llvm::for_each(allResults, [&](SILValue v) { s << v; });
|
|
});
|
|
|
|
// Accumulate adjoints for original differentiation parameters.
|
|
auto allResultsIt = allResults.begin();
|
|
for (unsigned i : applyInfo.config.parameterIndices->getIndices()) {
|
|
unsigned argIdx = fai.getNumIndirectSILResults() + i;
|
|
auto origArg = fai.getArgument(argIdx);
|
|
// Skip adjoint accumulation for semantic results arguments.
|
|
auto paramInfo = fai.getSubstCalleeConv().getParamInfoForSILArg(argIdx);
|
|
if (paramInfo.isAutoDiffSemanticResult())
|
|
continue;
|
|
auto tan = *allResultsIt++;
|
|
if (tan->getType().isAddress()) {
|
|
addToAdjointBuffer(bb, origArg, tan, loc);
|
|
} else {
|
|
if (origArg->getType().isAddress()) {
|
|
auto *tmpBuf = builder.createAllocStack(loc, tan->getType());
|
|
builder.emitStoreValueOperation(loc, tan, tmpBuf,
|
|
StoreOwnershipQualifier::Init);
|
|
addToAdjointBuffer(bb, origArg, tmpBuf, loc);
|
|
builder.emitDestroyAddrAndFold(loc, tmpBuf);
|
|
builder.createDeallocStack(loc, tmpBuf);
|
|
} else {
|
|
recordTemporary(tan);
|
|
addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Propagate adjoints for yields
|
|
if (actualPullbackType->isCoroutine()) {
|
|
auto originalYields = cast<BeginApplyInst>(fai)->getYieldedValues();
|
|
auto pullbackYields = cast<BeginApplyInst>(pullbackCall)->getYieldedValues();
|
|
assert(originalYields.size() == pullbackYields.size());
|
|
|
|
for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) {
|
|
if (resultIndex < firstYieldResultIndex)
|
|
continue;
|
|
|
|
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
|
|
setAdjointBuffer(bb, originalYields[yieldResultIndex], pullbackYields[yieldResultIndex]);
|
|
}
|
|
}
|
|
|
|
// Destroy unused pullback direct results. Needed for pullback results from
|
|
// VJPs extracted from `@differentiable` function callees, where the
|
|
// `@differentiable` function's differentiation parameter indices are a
|
|
// superset of the active `apply` parameter indices.
|
|
while (allResultsIt != allResults.end()) {
|
|
auto unusedPullbackDirectResult = *allResultsIt++;
|
|
if (unusedPullbackDirectResult->getType().isAddress())
|
|
continue;
|
|
builder.emitDestroyValueOperation(loc, unusedPullbackDirectResult);
|
|
}
|
|
|
|
// Destroy and deallocate pullback indirect results.
|
|
for (auto *alloc : llvm::reverse(pullbackIndirectResults)) {
|
|
builder.emitDestroyAddrAndFold(loc, alloc);
|
|
builder.createDeallocStack(loc, alloc);
|
|
}
|
|
}
|
|
|
|
void visitAbortApplyInst(AbortApplyInst *aai) {
|
|
BeginApplyInst *bai = aai->getBeginApply();
|
|
assert(getPullbackInfo().shouldDifferentiateApplySite(bai));
|
|
|
|
// abort_apply differentiation is not yet supported.
|
|
getContext().emitNondifferentiabilityError(
|
|
bai, getInvoker(), diag::autodiff_coroutines_not_supported);
|
|
errorOccurred = true;
|
|
}
|
|
|
|
void visitEndApplyInst(EndApplyInst *eai) {
|
|
BeginApplyInst *bai = eai->getBeginApply();
|
|
assert(getPullbackInfo().shouldDifferentiateApplySite(bai));
|
|
|
|
// Replace a call to a function with a call to its pullback.
|
|
auto &nestedApplyInfo = getContext().getNestedApplyInfo();
|
|
auto applyInfoLookup = nestedApplyInfo.find(bai);
|
|
// If no `NestedApplyInfo` was found, then this task doesn't need to be
|
|
// differentiated.
|
|
if (applyInfoLookup == nestedApplyInfo.end()) {
|
|
// Must not be active.
|
|
assert(!getActivityInfo().isActive(bai->getTokenResult(), getConfig()));
|
|
assert(!getActivityInfo().isActive(eai, getConfig()));
|
|
return;
|
|
}
|
|
|
|
buildPullbackCall(bai);
|
|
}
|
|
|
|
void visitBeginApplyInst(BeginApplyInst *bai) {
|
|
assert(getPullbackInfo().shouldDifferentiateApplySite(bai));
|
|
|
|
auto &nestedApplyInfo = getContext().getNestedApplyInfo();
|
|
auto applyInfoLookup = nestedApplyInfo.find(bai);
|
|
// If no `NestedApplyInfo` was found, then this task doesn't need to be
|
|
// differentiated.
|
|
if (applyInfoLookup == nestedApplyInfo.end()) {
|
|
// Must not be active.
|
|
assert(!getActivityInfo().isActive(bai->getTokenResult(), getConfig()));
|
|
return;
|
|
}
|
|
auto applyInfo = applyInfoLookup->getSecond();
|
|
|
|
auto loc = bai->getLoc();
|
|
builder.createEndApply(loc, applyInfo.beginApplyToken,
|
|
SILType::getEmptyTupleType(getASTContext()));
|
|
builder.emitDestroyValueOperation(loc, applyInfo.pullback);
|
|
}
|
|
|
|
/// Handle `struct` instruction.
|
|
/// Original: y = struct (x0, x1, x2, ...)
|
|
/// Adjoint: adj[x0] += struct_extract adj[y], #x0
|
|
/// adj[x1] += struct_extract adj[y], #x1
|
|
/// adj[x2] += struct_extract adj[y], #x2
|
|
/// ...
|
|
void visitStructInst(StructInst *si) {
|
|
auto *bb = si->getParent();
|
|
auto loc = si->getLoc();
|
|
auto *structDecl = si->getStructDecl();
|
|
switch (getTangentValueCategory(si)) {
|
|
case SILValueCategory::Object: {
|
|
auto av = getAdjointValue(bb, si);
|
|
switch (av.getKind()) {
|
|
case AdjointValueKind::Zero: {
|
|
for (auto *field : structDecl->getStoredProperties()) {
|
|
auto fv = si->getFieldValue(field);
|
|
addAdjointValue(
|
|
bb, fv,
|
|
makeZeroAdjointValue(getRemappedTangentType(fv->getType())), loc);
|
|
}
|
|
break;
|
|
}
|
|
case AdjointValueKind::Concrete: {
|
|
auto adjStruct = materializeAdjointDirect(std::move(av), loc);
|
|
auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct);
|
|
|
|
// Find the struct `TangentVector` type.
|
|
auto structTy = remapType(si->getType()).getASTType();
|
|
#ifndef NDEBUG
|
|
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
|
|
assert(!getTypeLowering(tangentVectorTy).isAddressOnly());
|
|
assert(tangentVectorTy->getStructOrBoundGenericStruct());
|
|
#endif
|
|
|
|
// Accumulate adjoints for the fields of the `struct` operand.
|
|
unsigned fieldIndex = 0;
|
|
for (auto it = structDecl->getStoredProperties().begin();
|
|
it != structDecl->getStoredProperties().end();
|
|
++it, ++fieldIndex) {
|
|
VarDecl *field = *it;
|
|
if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
|
|
continue;
|
|
// Find the corresponding field in the tangent space.
|
|
auto *tanField = getTangentStoredProperty(
|
|
getContext(), field, structTy, loc, getInvoker());
|
|
if (!tanField) {
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
auto tanElt = dti->getResult(fieldIndex);
|
|
addAdjointValue(bb, si->getFieldValue(field),
|
|
makeConcreteAdjointValue(tanElt), si->getLoc());
|
|
}
|
|
break;
|
|
}
|
|
case AdjointValueKind::Aggregate: {
|
|
// Note: All user-called initializations go through the calls to the
|
|
// initializer, and synthesized initializers only have one level of
|
|
// struct formation which will not result into any aggregate adjoint
|
|
// values.
|
|
llvm_unreachable(
|
|
"Aggregate adjoint values should not occur for `struct` "
|
|
"instructions");
|
|
}
|
|
case AdjointValueKind::AddElement: {
|
|
llvm_unreachable(
|
|
"Adjoint of `StructInst` cannot be of kind `AddElement`");
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBuf = getAdjointBuffer(bb, si);
|
|
// Find the struct `TangentVector` type.
|
|
auto structTy = remapType(si->getType()).getASTType();
|
|
// Accumulate adjoints for the fields of the `struct` operand.
|
|
unsigned fieldIndex = 0;
|
|
for (auto it = structDecl->getStoredProperties().begin();
|
|
it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) {
|
|
VarDecl *field = *it;
|
|
if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
|
|
continue;
|
|
// Find the corresponding field in the tangent space.
|
|
auto *tanField = getTangentStoredProperty(getContext(), field, structTy,
|
|
loc, getInvoker());
|
|
if (!tanField) {
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
auto *adjFieldBuf =
|
|
builder.createStructElementAddr(loc, adjBuf, tanField);
|
|
auto fieldValue = si->getFieldValue(field);
|
|
switch (getTangentValueCategory(fieldValue)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjField = builder.emitLoadValueOperation(
|
|
loc, adjFieldBuf, LoadOwnershipQualifier::Copy);
|
|
recordTemporary(adjField);
|
|
addAdjointValue(bb, fieldValue, makeConcreteAdjointValue(adjField),
|
|
loc);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
addToAdjointBuffer(bb, fieldValue, adjFieldBuf, loc);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
} break;
|
|
}
|
|
}
|
|
|
|
/// Handle `struct_extract` instruction.
|
|
/// Original: y = struct_extract x, #field
|
|
/// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
|
|
/// ^~~~~~~
|
|
/// field in tangent space corresponding to #field
|
|
void visitStructExtractInst(StructExtractInst *sei) {
|
|
auto *bb = sei->getParent();
|
|
auto loc = getValidLocation(sei);
|
|
// Find the corresponding field in the tangent space.
|
|
auto structTy = remapType(sei->getOperand()->getType()).getASTType();
|
|
auto *tanField =
|
|
getTangentStoredProperty(getContext(), sei, structTy, getInvoker());
|
|
assert(tanField && "Invalid projections should have been diagnosed");
|
|
// Check the `struct_extract` operand's value tangent category.
|
|
switch (getTangentValueCategory(sei->getOperand())) {
|
|
case SILValueCategory::Object: {
|
|
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
|
|
auto tangentVectorSILTy =
|
|
SILType::getPrimitiveObjectType(tangentVectorTy);
|
|
auto eltAdj = getAdjointValue(bb, sei);
|
|
|
|
switch (eltAdj.getKind()) {
|
|
case AdjointValueKind::Zero: {
|
|
addAdjointValue(bb, sei->getOperand(),
|
|
makeZeroAdjointValue(tangentVectorSILTy), loc);
|
|
break;
|
|
}
|
|
case AdjointValueKind::Aggregate:
|
|
case AdjointValueKind::Concrete:
|
|
case AdjointValueKind::AddElement: {
|
|
auto baseAdj = makeZeroAdjointValue(tangentVectorSILTy);
|
|
addAdjointValue(bb, sei->getOperand(),
|
|
makeAddElementAdjointValue(baseAdj, eltAdj, tanField),
|
|
loc);
|
|
break;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBase = getAdjointBuffer(bb, sei->getOperand());
|
|
auto *adjBaseElt =
|
|
builder.createStructElementAddr(loc, adjBase, tanField);
|
|
// Check the `struct_extract`'s value tangent category.
|
|
switch (getTangentValueCategory(sei)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjElt = getAdjointValue(bb, sei);
|
|
auto concreteAdjElt = materializeAdjointDirect(adjElt, loc);
|
|
auto concreteAdjEltCopy =
|
|
builder.emitCopyValueOperation(loc, concreteAdjElt);
|
|
auto *alloc = builder.createAllocStack(loc, adjElt.getType());
|
|
builder.emitStoreValueOperation(loc, concreteAdjEltCopy, alloc,
|
|
StoreOwnershipQualifier::Init);
|
|
builder.emitInPlaceAdd(loc, adjBaseElt, alloc);
|
|
builder.createDestroyAddr(loc, alloc);
|
|
builder.createDeallocStack(loc, alloc);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjElt = getAdjointBuffer(bb, sei);
|
|
builder.emitInPlaceAdd(loc, adjBaseElt, adjElt);
|
|
break;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `ref_element_addr` instruction.
|
|
/// Original: y = ref_element_addr x, <n>
|
|
/// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
|
|
/// ^~~~~~~
|
|
/// field in tangent space corresponding to #field
|
|
void visitRefElementAddrInst(RefElementAddrInst *reai) {
|
|
auto *bb = reai->getParent();
|
|
auto loc = reai->getLoc();
|
|
auto adjBuf = getAdjointBuffer(bb, reai);
|
|
auto classOperand = reai->getOperand();
|
|
auto classType = remapType(reai->getOperand()->getType()).getASTType();
|
|
auto *tanField =
|
|
getTangentStoredProperty(getContext(), reai, classType, getInvoker());
|
|
assert(tanField && "Invalid projections should have been diagnosed");
|
|
switch (getTangentValueCategory(classOperand)) {
|
|
case SILValueCategory::Object: {
|
|
auto classTy = remapType(classOperand->getType()).getASTType();
|
|
auto tangentVectorTy = getTangentSpace(classTy)->getCanonicalType();
|
|
auto tangentVectorSILTy =
|
|
SILType::getPrimitiveObjectType(tangentVectorTy);
|
|
auto *tangentVectorDecl =
|
|
tangentVectorTy->getStructOrBoundGenericStruct();
|
|
// Accumulate adjoint for the `ref_element_addr` operand.
|
|
SmallVector<AdjointValue, 8> eltVals;
|
|
for (auto *field : tangentVectorDecl->getStoredProperties()) {
|
|
if (field == tanField) {
|
|
auto adjElt = builder.emitLoadValueOperation(
|
|
reai->getLoc(), adjBuf, LoadOwnershipQualifier::Copy);
|
|
eltVals.push_back(makeConcreteAdjointValue(adjElt));
|
|
recordTemporary(adjElt);
|
|
} else {
|
|
auto substMap = tangentVectorTy->getMemberSubstitutionMap(
|
|
field);
|
|
auto fieldTy = field->getInterfaceType().subst(substMap);
|
|
auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
|
|
assert(fieldSILTy.isObject());
|
|
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
|
|
}
|
|
}
|
|
addAdjointValue(bb, classOperand,
|
|
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
|
|
loc);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBufClass = getAdjointBuffer(bb, classOperand);
|
|
auto adjBufElt =
|
|
builder.createStructElementAddr(loc, adjBufClass, tanField);
|
|
builder.emitInPlaceAdd(loc, adjBufElt, adjBuf);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `tuple` instruction.
|
|
/// Original: y = tuple (x0, x1, x2, ...)
|
|
/// Adjoint: (adj[x0], adj[x1], adj[x2], ...) += destructure_tuple adj[y]
|
|
/// ^~~
|
|
/// excluding non-differentiable elements
|
|
void visitTupleInst(TupleInst *ti) {
|
|
auto *bb = ti->getParent();
|
|
auto loc = ti->getLoc();
|
|
switch (getTangentValueCategory(ti)) {
|
|
case SILValueCategory::Object: {
|
|
auto av = getAdjointValue(bb, ti);
|
|
switch (av.getKind()) {
|
|
case AdjointValueKind::Zero:
|
|
for (auto elt : ti->getElements()) {
|
|
if (!getTangentSpace(elt->getType().getASTType()))
|
|
continue;
|
|
addAdjointValue(
|
|
bb, elt,
|
|
makeZeroAdjointValue(getRemappedTangentType(elt->getType())),
|
|
loc);
|
|
}
|
|
break;
|
|
case AdjointValueKind::Concrete: {
|
|
auto adjVal = av.getConcreteValue();
|
|
auto adjValCopy = builder.emitCopyValueOperation(loc, adjVal);
|
|
SmallVector<SILValue, 4> adjElts;
|
|
if (!adjVal->getType().getAs<TupleType>()) {
|
|
recordTemporary(adjValCopy);
|
|
adjElts.push_back(adjValCopy);
|
|
} else {
|
|
auto *dti = builder.createDestructureTuple(loc, adjValCopy);
|
|
for (auto adjElt : dti->getResults())
|
|
recordTemporary(adjElt);
|
|
adjElts.append(dti->getResults().begin(), dti->getResults().end());
|
|
}
|
|
// Accumulate adjoints for `tuple` operands, skipping the
|
|
// non-`Differentiable` ones.
|
|
unsigned adjIndex = 0;
|
|
for (auto i : range(ti->getNumOperands())) {
|
|
if (!getTangentSpace(ti->getOperand(i)->getType().getASTType()))
|
|
continue;
|
|
auto adjElt = adjElts[adjIndex++];
|
|
addAdjointValue(bb, ti->getOperand(i),
|
|
makeConcreteAdjointValue(adjElt), loc);
|
|
}
|
|
break;
|
|
}
|
|
case AdjointValueKind::Aggregate: {
|
|
unsigned adjIndex = 0;
|
|
for (auto i : range(ti->getElements().size())) {
|
|
if (!getTangentSpace(ti->getElement(i)->getType().getASTType()))
|
|
continue;
|
|
addAdjointValue(bb, ti->getElement(i),
|
|
av.getAggregateElement(adjIndex++), loc);
|
|
}
|
|
break;
|
|
}
|
|
case AdjointValueKind::AddElement: {
|
|
llvm_unreachable(
|
|
"Adjoint of `TupleInst` cannot be of kind `AddElement`");
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBuf = getAdjointBuffer(bb, ti);
|
|
// Accumulate adjoints for `tuple` operands, skipping the
|
|
// non-`Differentiable` ones.
|
|
unsigned adjIndex = 0;
|
|
for (auto i : range(ti->getNumOperands())) {
|
|
if (!getTangentSpace(ti->getOperand(i)->getType().getASTType()))
|
|
continue;
|
|
auto adjBufElt =
|
|
builder.createTupleElementAddr(loc, adjBuf, adjIndex++);
|
|
auto adjElt = getAdjointBuffer(bb, ti->getOperand(i));
|
|
builder.emitInPlaceAdd(loc, adjElt, adjBufElt);
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `tuple_extract` instruction.
|
|
/// Original: y = tuple_extract x, <n>
|
|
/// Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0)
|
|
/// ^~~~~~
|
|
/// n'-th element, where n' is tuple tangent space
|
|
/// index corresponding to n
|
|
void visitTupleExtractInst(TupleExtractInst *tei) {
|
|
auto *bb = tei->getParent();
|
|
auto loc = tei->getLoc();
|
|
auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
|
|
auto eltAdj = getAdjointValue(bb, tei);
|
|
switch (eltAdj.getKind()) {
|
|
case AdjointValueKind::Zero: {
|
|
addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy),
|
|
loc);
|
|
break;
|
|
}
|
|
case AdjointValueKind::Aggregate:
|
|
case AdjointValueKind::Concrete:
|
|
case AdjointValueKind::AddElement: {
|
|
auto tupleTy = tei->getTupleType();
|
|
auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>();
|
|
if (!tupleTanTupleTy) {
|
|
addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
|
|
break;
|
|
}
|
|
|
|
unsigned elements = 0;
|
|
for (unsigned i : range(tupleTy->getNumElements())) {
|
|
if (!getTangentSpace(
|
|
tupleTy->getElement(i).getType()->getCanonicalType()))
|
|
continue;
|
|
elements++;
|
|
}
|
|
|
|
if (elements == 1) {
|
|
addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
|
|
} else {
|
|
auto baseAdj = makeZeroAdjointValue(tupleTanTy);
|
|
addAdjointValue(
|
|
bb, tei->getOperand(),
|
|
makeAddElementAdjointValue(baseAdj, eltAdj, tei->getFieldIndex()),
|
|
loc);
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `destructure_tuple` instruction.
|
|
/// Original: (y0, ..., yn) = destructure_tuple x
|
|
/// Adjoint: adj[x].0 += adj[y0]
|
|
/// ...
|
|
/// adj[x].n += adj[yn]
|
|
void visitDestructureTupleInst(DestructureTupleInst *dti) {
|
|
auto *bb = dti->getParent();
|
|
auto loc = dti->getLoc();
|
|
auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType());
|
|
// Check the `destructure_tuple` operand's value tangent category.
|
|
switch (getTangentValueCategory(dti->getOperand())) {
|
|
case SILValueCategory::Object: {
|
|
SmallVector<AdjointValue, 8> adjValues;
|
|
for (auto origElt : dti->getResults()) {
|
|
// Skip non-`Differentiable` tuple elements.
|
|
if (!getTangentSpace(remapType(origElt->getType()).getASTType()))
|
|
continue;
|
|
adjValues.push_back(getAdjointValue(bb, origElt));
|
|
}
|
|
// Handle tuple tangent type.
|
|
// Add adjoints for every tuple element that has a tangent space.
|
|
if (tupleTanTy.is<TupleType>()) {
|
|
assert(adjValues.size() > 1);
|
|
addAdjointValue(bb, dti->getOperand(),
|
|
makeAggregateAdjointValue(tupleTanTy, adjValues), loc);
|
|
}
|
|
// Handle non-tuple tangent type.
|
|
// Add adjoint for the single tuple element that has a tangent space.
|
|
else {
|
|
assert(adjValues.size() == 1);
|
|
addAdjointValue(bb, dti->getOperand(), adjValues.front(), loc);
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBuf = getAdjointBuffer(bb, dti->getOperand());
|
|
unsigned adjIndex = 0;
|
|
for (auto origElt : dti->getResults()) {
|
|
// Skip non-`Differentiable` tuple elements.
|
|
if (!getTangentSpace(remapType(origElt->getType()).getASTType()))
|
|
continue;
|
|
// Handle tuple tangent type.
|
|
// Add adjoints for every tuple element that has a tangent space.
|
|
if (tupleTanTy.is<TupleType>()) {
|
|
auto adjEltBuf = getAdjointBuffer(bb, origElt);
|
|
auto adjBufElt =
|
|
builder.createTupleElementAddr(loc, adjBuf, adjIndex);
|
|
builder.emitInPlaceAdd(loc, adjBufElt, adjEltBuf);
|
|
}
|
|
// Handle non-tuple tangent type.
|
|
// Add adjoint for the single tuple element that has a tangent space.
|
|
else {
|
|
auto adjEltBuf = getAdjointBuffer(bb, origElt);
|
|
addToAdjointBuffer(bb, dti->getOperand(), adjEltBuf, loc);
|
|
}
|
|
++adjIndex;
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `load` or `load_borrow` instruction
|
|
/// Original: y = load/load_borrow x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
void visitLoadOperation(SingleValueInstruction *inst) {
|
|
assert(isa<LoadInst>(inst) || isa<LoadBorrowInst>(inst));
|
|
auto *bb = inst->getParent();
|
|
auto loc = inst->getLoc();
|
|
switch (getTangentValueCategory(inst)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjVal = materializeAdjointDirect(getAdjointValue(bb, inst), loc);
|
|
// Allocate a local buffer and store the adjoint value. This buffer will
|
|
// be used for accumulation into the adjoint buffer.
|
|
auto adjBuf = builder.createAllocStack(loc, adjVal->getType(), {},
|
|
DoesNotHaveDynamicLifetime,
|
|
IsNotLexical, IsNotFromVarDecl,
|
|
DoesNotUseMoveableValueDebugInfo,
|
|
/* skipVarDeclAssert = */ true);
|
|
auto copy = builder.emitCopyValueOperation(loc, adjVal);
|
|
builder.emitStoreValueOperation(loc, copy, adjBuf,
|
|
StoreOwnershipQualifier::Init);
|
|
// Accumulate the adjoint value in the local buffer into the adjoint
|
|
// buffer.
|
|
addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc);
|
|
builder.emitDestroyAddr(loc, adjBuf);
|
|
builder.createDeallocStack(loc, adjBuf);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBuf = getAdjointBuffer(bb, inst);
|
|
addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
void visitLoadInst(LoadInst *li) { visitLoadOperation(li); }
|
|
void visitLoadBorrowInst(LoadBorrowInst *lbi) { visitLoadOperation(lbi); }
|
|
|
|
/// Handle `store` or `store_borrow` instruction.
|
|
/// Original: store/store_borrow x to y
|
|
/// Adjoint: adj[x] += load adj[y]; adj[y] = 0
|
|
void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc,
|
|
SILValue origDest) {
|
|
auto adjBuf = getAdjointBuffer(bb, origDest);
|
|
switch (getTangentValueCategory(origSrc)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjVal = builder.emitLoadValueOperation(
|
|
loc, adjBuf, LoadOwnershipQualifier::Take);
|
|
recordTemporary(adjVal);
|
|
addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc);
|
|
builder.emitZeroIntoBuffer(loc, adjBuf, IsInitialization);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
addToAdjointBuffer(bb, origSrc, adjBuf, loc);
|
|
builder.emitZeroIntoBuffer(loc, adjBuf, IsNotInitialization);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
void visitStoreInst(StoreInst *si) {
|
|
visitStoreOperation(si->getParent(), si->getLoc(), si->getSrc(),
|
|
si->getDest());
|
|
}
|
|
void visitStoreBorrowInst(StoreBorrowInst *sbi) {
|
|
visitStoreOperation(sbi->getParent(), sbi->getLoc(), sbi->getSrc(),
|
|
sbi);
|
|
}
|
|
|
|
/// Handle `copy_addr` instruction.
|
|
/// Original: copy_addr x to y
|
|
/// Adjoint: adj[x] += adj[y]; adj[y] = 0
|
|
void visitCopyAddrInst(CopyAddrInst *cai) {
|
|
auto *bb = cai->getParent();
|
|
auto adjDest = getAdjointBuffer(bb, cai->getDest());
|
|
addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc());
|
|
builder.emitZeroIntoBuffer(cai->getLoc(), adjDest, IsNotInitialization);
|
|
}
|
|
|
|
/// Handle any ownership instruction that deals with values: copy_value,
|
|
/// move_value, begin_borrow.
|
|
/// Original: y = copy_value x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
void visitValueOwnershipInst(SingleValueInstruction *svi,
|
|
bool needZeroResAdj = false) {
|
|
assert(svi->getNumOperands() == 1);
|
|
auto *bb = svi->getParent();
|
|
switch (getTangentValueCategory(svi)) {
|
|
case SILValueCategory::Object: {
|
|
auto adj = getAdjointValue(bb, svi);
|
|
addAdjointValue(bb, svi->getOperand(0), adj, svi->getLoc());
|
|
if (needZeroResAdj) {
|
|
assert(svi->getNumResults() == 1);
|
|
SILValue val = svi->getResult(0);
|
|
setAdjointValue(
|
|
bb, val,
|
|
makeZeroAdjointValue(getRemappedTangentType(val->getType())));
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjDest = getAdjointBuffer(bb, svi);
|
|
addToAdjointBuffer(bb, svi->getOperand(0), adjDest, svi->getLoc());
|
|
builder.emitZeroIntoBuffer(svi->getLoc(), adjDest, IsNotInitialization);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `copy_value` instruction.
|
|
/// Original: y = copy_value x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
void visitCopyValueInst(CopyValueInst *cvi) { visitValueOwnershipInst(cvi); }
|
|
|
|
/// Handle `begin_borrow` instruction.
|
|
/// Original: y = begin_borrow x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
void visitBeginBorrowInst(BeginBorrowInst *bbi) {
|
|
visitValueOwnershipInst(bbi);
|
|
}
|
|
|
|
/// Handle `move_value` instruction.
|
|
/// Original: y = move_value x
|
|
/// Adjoint: adj[x] += adj[y]; adj[y] = 0
|
|
void visitMoveValueInst(MoveValueInst *mvi) {
|
|
switch (getTangentValueCategory(mvi)) {
|
|
case SILValueCategory::Address:
|
|
LLVM_DEBUG(getADDebugStream() << "AutoDiff does not support move_value with "
|
|
"SILValueCategory::Address");
|
|
getContext().emitNondifferentiabilityError(
|
|
mvi, getInvoker(), diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return;
|
|
case SILValueCategory::Object:
|
|
visitValueOwnershipInst(mvi, /*needZeroResAdj=*/true);
|
|
}
|
|
}
|
|
|
|
void visitEndInitLetRefInst(EndInitLetRefInst *eir) { visitValueOwnershipInst(eir); }
|
|
|
|
/// Handle `begin_access` instruction.
|
|
/// Original: y = begin_access x
|
|
/// Adjoint: nothing
|
|
void visitBeginAccessInst(BeginAccessInst *bai) {
|
|
// Check for non-differentiable writes.
|
|
if (bai->getAccessKind() == SILAccessKind::Modify) {
|
|
if (isa<GlobalAddrInst>(bai->getSource())) {
|
|
getContext().emitNondifferentiabilityError(
|
|
bai, getInvoker(),
|
|
diag::autodiff_cannot_differentiate_writes_to_global_variables);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
if (isa<ProjectBoxInst>(bai->getSource())) {
|
|
getContext().emitNondifferentiabilityError(
|
|
bai, getInvoker(),
|
|
diag::autodiff_cannot_differentiate_writes_to_mutable_captures);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `unconditional_checked_cast_addr` instruction.
|
|
/// Original: y = unconditional_checked_cast_addr x
|
|
/// Adjoint: adj[x] += unconditional_checked_cast_addr adj[y]
|
|
void visitUnconditionalCheckedCastAddrInst(
|
|
UnconditionalCheckedCastAddrInst *uccai) {
|
|
auto *bb = uccai->getParent();
|
|
auto adjDest = getAdjointBuffer(bb, uccai->getDest());
|
|
auto adjSrc = getAdjointBuffer(bb, uccai->getSrc());
|
|
auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType());
|
|
builder.createUnconditionalCheckedCastAddr(
|
|
uccai->getLoc(), uccai->getCheckedCastOptions(),
|
|
adjDest, adjDest->getType().getASTType(), castBuf,
|
|
adjSrc->getType().getASTType());
|
|
addToAdjointBuffer(bb, uccai->getSrc(), castBuf, uccai->getLoc());
|
|
builder.emitDestroyAddrAndFold(uccai->getLoc(), castBuf);
|
|
builder.createDeallocStack(uccai->getLoc(), castBuf);
|
|
builder.emitZeroIntoBuffer(uccai->getLoc(), adjDest, IsInitialization);
|
|
}
|
|
|
|
/// Handle `enum` instruction.
|
|
/// Original: y = enum $Enum, #Enum.some!enumelt, x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
void visitEnumInst(EnumInst *ei) {
|
|
SILBasicBlock *bb = ei->getParent();
|
|
SILLocation loc = ei->getLoc();
|
|
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
|
|
|
|
// Only `Optional`-typed operands are supported for now. Diagnose all other
|
|
// enum operand types.
|
|
if (ei->getType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Unsupported enum type in PullbackCloner: " << *ei);
|
|
getContext().emitNondifferentiabilityError(
|
|
ei, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
|
|
auto adjOpt = getAdjointValue(bb, ei);
|
|
auto adjStruct = materializeAdjointDirect(adjOpt, loc);
|
|
|
|
VarDecl *adjOptVar =
|
|
getASTContext().getOptionalTanValueDecl(adjStruct->getType().getASTType());
|
|
auto *adjVal = builder.createStructExtract(loc, adjStruct, adjOptVar);
|
|
|
|
EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
|
|
auto *adjData = builder.createUncheckedEnumData(loc, adjVal, someElemDecl);
|
|
|
|
addAdjointValue(bb, ei->getOperand(), makeConcreteAdjointValue(adjData), loc);
|
|
}
|
|
|
|
/// Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
|
|
/// instructions.
|
|
///
|
|
/// Original: x = init_enum_data_addr y : $*Enum, #Enum.Case
|
|
/// inject_enum_addr y
|
|
///
|
|
/// Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]
|
|
void visitInjectEnumAddrInst(InjectEnumAddrInst *inject) {
|
|
SILBasicBlock *bb = inject->getParent();
|
|
SILValue origEnum = inject->getOperand();
|
|
|
|
// Only `Optional`-typed operands are supported for now. Diagnose all other
|
|
// enum operand types.
|
|
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
|
|
if (origEnum->getType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Unsupported enum type in PullbackCloner: " << *inject);
|
|
getContext().emitNondifferentiabilityError(
|
|
inject, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
|
|
// No associated value => no adjoint to propagate
|
|
if (!inject->getElement()->hasAssociatedValues())
|
|
return;
|
|
|
|
InitEnumDataAddrInst *origData = nullptr;
|
|
for (auto use : origEnum->getUses()) {
|
|
if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser())) {
|
|
// We need a more complicated analysis when init_enum_data_addr and
|
|
// inject_enum_addr are in different blocks, or there is more than one
|
|
// such instruction. Bail out for now.
|
|
if (origData || init->getParent() != bb) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Could not find a matching init_enum_data_addr for: "
|
|
<< *inject);
|
|
getContext().emitNondifferentiabilityError(
|
|
inject, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
|
|
origData = init;
|
|
}
|
|
}
|
|
|
|
SILValue adjDest = getAdjointBuffer(bb, origEnum);
|
|
VarDecl *adjOptVar =
|
|
getASTContext().getOptionalTanValueDecl(adjDest->getType().getASTType());
|
|
|
|
SILLocation loc = origData->getLoc();
|
|
StructElementAddrInst *adjOpt =
|
|
builder.createStructElementAddr(loc, adjDest, adjOptVar);
|
|
|
|
// unchecked_take_enum_data_addr is destructive, so copy
|
|
// Optional<T.TangentVector> to a new alloca.
|
|
AllocStackInst *adjOptCopy =
|
|
createFunctionLocalAllocation(adjOpt->getType(), loc);
|
|
builder.createCopyAddr(loc, adjOpt, adjOptCopy, IsNotTake,
|
|
IsInitialization);
|
|
// The Optional copy is invalidated, do not attempt to destroy it at the end
|
|
// of the pullback. The value returned from unchecked_take_enum_data_addr is
|
|
// destroyed in visitInitEnumDataAddrInst.
|
|
auto [_, inserted] = enumDataAdjCopies.try_emplace(origData, adjOptCopy);
|
|
assert(inserted && "expected single buffer");
|
|
|
|
EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
|
|
UncheckedTakeEnumDataAddrInst *adjData =
|
|
builder.createUncheckedTakeEnumDataAddr(loc, adjOptCopy, someElemDecl);
|
|
|
|
addToAdjointBuffer(bb, origData, adjData, loc);
|
|
}
|
|
|
|
/// Handle `init_enum_data_addr` instruction.
|
|
/// Destroy the value returned from `unchecked_take_enum_data_addr`.
|
|
void visitInitEnumDataAddrInst(InitEnumDataAddrInst *init) {
|
|
SILValue adjOptCopy = enumDataAdjCopies.at(init);
|
|
|
|
builder.emitDestroyAddr(init->getLoc(), adjOptCopy);
|
|
destroyedLocalAllocations.insert(adjOptCopy);
|
|
enumDataAdjCopies.erase(init);
|
|
}
|
|
|
|
/// Handle `unchecked_ref_cast` instruction.
|
|
/// Original: y = unchecked_ref_cast x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
/// (assuming adj[x] and adj[y] have the same type)
|
|
void visitUncheckedRefCastInst(UncheckedRefCastInst *urci) {
|
|
auto *bb = urci->getParent();
|
|
assert(urci->getOperand()->getType().isObject());
|
|
assert(getRemappedTangentType(urci->getOperand()->getType()) ==
|
|
getRemappedTangentType(urci->getType()) &&
|
|
"Operand/result must have the same `TangentVector` type");
|
|
switch (getTangentValueCategory(urci)) {
|
|
case SILValueCategory::Object: {
|
|
auto adj = getAdjointValue(bb, urci);
|
|
addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc());
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjDest = getAdjointBuffer(bb, urci);
|
|
addToAdjointBuffer(bb, urci->getOperand(), adjDest, urci->getLoc());
|
|
builder.emitZeroIntoBuffer(urci->getLoc(), adjDest, IsNotInitialization);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `upcast` instruction.
|
|
/// Original: y = upcast x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
/// (assuming adj[x] and adj[y] have the same type)
|
|
void visitUpcastInst(UpcastInst *ui) {
|
|
auto *bb = ui->getParent();
|
|
assert(ui->getOperand()->getType().isObject());
|
|
assert(getRemappedTangentType(ui->getOperand()->getType()) ==
|
|
getRemappedTangentType(ui->getType()) &&
|
|
"Operand/result must have the same `TangentVector` type");
|
|
switch (getTangentValueCategory(ui)) {
|
|
case SILValueCategory::Object: {
|
|
auto adj = getAdjointValue(bb, ui);
|
|
addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc());
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjDest = getAdjointBuffer(bb, ui);
|
|
addToAdjointBuffer(bb, ui->getOperand(), adjDest, ui->getLoc());
|
|
builder.emitZeroIntoBuffer(ui->getLoc(), adjDest, IsNotInitialization);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `unchecked_take_enum_data_addr` instruction.
|
|
/// Currently, only `Optional`-typed operands are supported.
|
|
/// Original: y = unchecked_take_enum_data_addr x : $*Enum, #Enum.Case
|
|
/// Adjoint: adj[x] += $Enum.TangentVector(adj[y])
|
|
void
|
|
visitUncheckedTakeEnumDataAddrInst(UncheckedTakeEnumDataAddrInst *utedai) {
|
|
auto *bb = utedai->getParent();
|
|
auto adjDest = getAdjointBuffer(bb, utedai);
|
|
auto enumTy = utedai->getOperand()->getType();
|
|
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
|
|
// Only `Optional`-typed operands are supported for now. Diagnose all other
|
|
// enum operand types.
|
|
if (enumTy.getASTType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Unhandled instruction in PullbackCloner: " << *utedai);
|
|
getContext().emitNondifferentiabilityError(
|
|
utedai, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
accumulateAdjointForOptionalBuffer(bb, utedai->getOperand(), adjDest);
|
|
builder.emitZeroIntoBuffer(utedai->getLoc(), adjDest, IsNotInitialization);
|
|
}
|
|
|
|
#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst);
|
|
#undef NOT_DIFFERENTIABLE
|
|
|
|
#define NO_ADJOINT(INST) \
|
|
void visit##INST##Inst(INST##Inst *inst) {}
|
|
// Terminators.
|
|
NO_ADJOINT(Return)
|
|
NO_ADJOINT(Branch)
|
|
NO_ADJOINT(CondBranch)
|
|
NO_ADJOINT(Yield)
|
|
|
|
// Address projections.
|
|
NO_ADJOINT(StructElementAddr)
|
|
NO_ADJOINT(TupleElementAddr)
|
|
|
|
// Array literal initialization address projections.
|
|
NO_ADJOINT(PointerToAddress)
|
|
NO_ADJOINT(IndexAddr)
|
|
|
|
// Memory allocation/access.
|
|
NO_ADJOINT(AllocStack)
|
|
NO_ADJOINT(DeallocStack)
|
|
NO_ADJOINT(EndAccess)
|
|
|
|
// Debugging/reference counting instructions.
|
|
NO_ADJOINT(DebugValue)
|
|
NO_ADJOINT(RetainValue)
|
|
NO_ADJOINT(RetainValueAddr)
|
|
NO_ADJOINT(ReleaseValue)
|
|
NO_ADJOINT(ReleaseValueAddr)
|
|
NO_ADJOINT(StrongRetain)
|
|
NO_ADJOINT(StrongRelease)
|
|
NO_ADJOINT(UnownedRetain)
|
|
NO_ADJOINT(UnownedRelease)
|
|
NO_ADJOINT(StrongRetainUnowned)
|
|
NO_ADJOINT(DestroyValue)
|
|
NO_ADJOINT(DestroyAddr)
|
|
|
|
// Value ownership.
|
|
NO_ADJOINT(EndBorrow)
|
|
#undef NO_ADJOINT
|
|
};
|
|
|
|
PullbackCloner::Implementation::Implementation(VJPCloner &vjpCloner)
|
|
: vjpCloner(vjpCloner), scopeCloner(getPullback()),
|
|
builder(getPullback(), getContext()),
|
|
localAllocBuilder(getPullback(), getContext()) {
|
|
// Get dominance and post-order info for the original function.
|
|
auto &passManager = getContext().getPassManager();
|
|
auto *domAnalysis = passManager.getAnalysis<DominanceAnalysis>();
|
|
auto *postDomAnalysis = passManager.getAnalysis<PostDominanceAnalysis>();
|
|
auto *postOrderAnalysis = passManager.getAnalysis<PostOrderAnalysis>();
|
|
auto *original = &vjpCloner.getOriginal();
|
|
domInfo = domAnalysis->get(original);
|
|
postDomInfo = postDomAnalysis->get(original);
|
|
postOrderInfo = postOrderAnalysis->get(original);
|
|
// Initialize `originalExitBlock`.
|
|
auto origExitIt = original->findReturnBB();
|
|
assert(origExitIt != original->end() &&
|
|
"Functions without returns must have been diagnosed");
|
|
originalExitBlock = &*origExitIt;
|
|
localAllocBuilder.setCurrentDebugScope(
|
|
remapScope(originalExitBlock->getTerminator()->getDebugScope()));
|
|
}
|
|
|
|
PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)
|
|
: impl(*new Implementation(vjpCloner)) {}
|
|
|
|
PullbackCloner::~PullbackCloner() { delete &impl; }
|
|
|
|
static SILValue getArrayValue(ApplyInst *ai) {
|
|
SILValue arrayValue;
|
|
for (auto use : ai->getUses()) {
|
|
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
|
|
if (!dti)
|
|
continue;
|
|
DEBUG_ASSERT(!arrayValue && "Array value already found");
|
|
// The first `destructure_tuple` result is the `Array` value.
|
|
arrayValue = dti->getResult(0);
|
|
#ifndef DEBUG_ASSERT_enabled
|
|
break;
|
|
#endif
|
|
}
|
|
ASSERT(arrayValue);
|
|
return arrayValue;
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Entry point
|
|
//--------------------------------------------------------------------------//
|
|
|
|
bool PullbackCloner::run() {
|
|
bool foundError = impl.run();
|
|
#ifndef NDEBUG
|
|
if (!foundError)
|
|
impl.getPullback().verify();
|
|
#endif
|
|
return foundError;
|
|
}
|
|
|
|
bool PullbackCloner::Implementation::run() {
|
|
PrettyStackTraceSILFunction trace("generating pullback for", &getOriginal());
|
|
auto &original = getOriginal();
|
|
auto &pullback = getPullback();
|
|
auto pbLoc = getPullback().getLocation();
|
|
LLVM_DEBUG(getADDebugStream() << "Running PullbackCloner on\n" << original);
|
|
|
|
// Collect original formal results.
|
|
SmallVector<SILValue, 8> origFormalResults;
|
|
collectAllFormalResultsInTypeOrder(original, origFormalResults);
|
|
for (auto resultIndex : getConfig().resultIndices->getIndices()) {
|
|
auto origResult = origFormalResults[resultIndex];
|
|
// If original result is non-varied, it will always have a zero derivative.
|
|
// Skip full pullback generation and simply emit zero derivatives for wrt
|
|
// parameters.
|
|
//
|
|
// NOTE(TF-876): This shortcut is currently necessary for functions
|
|
// returning non-varied result with >1 basic block where some basic blocks
|
|
// have no dominated active values; control flow differentiation does not
|
|
// handle this case. See TF-876 for context.
|
|
if (!getActivityInfo().isVaried(origResult, getConfig().parameterIndices)) {
|
|
emitZeroDerivativesForNonvariedResult(origResult);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Collect dominated active values in original basic blocks.
|
|
// Adjoint values of dominated active values are passed as pullback block
|
|
// arguments.
|
|
DominanceOrder domOrder(original.getEntryBlock(), domInfo);
|
|
// Keep track of visited values.
|
|
SmallPtrSet<SILValue, 8> visited;
|
|
while (auto *bb = domOrder.getNext()) {
|
|
auto &bbActiveValues = activeValues[bb];
|
|
// If the current block has an immediate dominator, append the immediate
|
|
// dominator block's active values to the current block's active values.
|
|
if (auto *domNode = domInfo->getNode(bb)->getIDom()) {
|
|
auto &domBBActiveValues = activeValues[domNode->getBlock()];
|
|
bbActiveValues.append(domBBActiveValues.begin(), domBBActiveValues.end());
|
|
}
|
|
// If `v` is active and has not been visited, records it as an active value
|
|
// in the original basic block.
|
|
// For active values unsupported by differentiation, emits a diagnostic and
|
|
// returns true. Otherwise, returns false.
|
|
auto recordValueIfActive = [&](SILValue v) -> bool {
|
|
// If value is not active, skip.
|
|
if (!getActivityInfo().isActive(v, getConfig()))
|
|
return false;
|
|
// If active value has already been visited, skip.
|
|
if (visited.count(v))
|
|
return false;
|
|
// Mark active value as visited.
|
|
visited.insert(v);
|
|
|
|
// Diagnose unsupported active values.
|
|
auto type = v->getType();
|
|
// Do not emit remaining activity-related diagnostics for semantic member
|
|
// accessors, which have special-case pullback generation.
|
|
if (isSemanticMemberAccessor(&original))
|
|
return false;
|
|
// Diagnose active enum values. Differentiation of enum values requires
|
|
// special adjoint value handling and is not yet supported. Diagnose
|
|
// only the first active enum value to prevent too many diagnostics.
|
|
//
|
|
// Do not diagnose `Optional`-typed values, which will have special-case
|
|
// differentiation support.
|
|
if (type.getEnumOrBoundGenericEnum()) {
|
|
if (!type.getASTType()->isOptional()) {
|
|
getContext().emitNondifferentiabilityError(
|
|
v, getInvoker(), diag::autodiff_enums_unsupported);
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
}
|
|
// Diagnose unsupported stored property projections.
|
|
if (isa<StructExtractInst>(v) || isa<RefElementAddrInst>(v) ||
|
|
isa<StructElementAddrInst>(v)) {
|
|
auto *inst = cast<SingleValueInstruction>(v);
|
|
assert(inst->getNumOperands() == 1);
|
|
auto baseType = remapType(inst->getOperand(0)->getType()).getASTType();
|
|
if (!getTangentStoredProperty(getContext(), inst, baseType,
|
|
getInvoker())) {
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
}
|
|
// Skip address projections.
|
|
// Address projections do not need their own adjoint buffers; they
|
|
// become projections into their adjoint base buffer.
|
|
if (Projection::isAddressProjection(v))
|
|
return false;
|
|
|
|
// Co-routines borrow adjoint buffers for yields
|
|
if (isa_and_nonnull<BeginApplyInst>(v.getDefiningInstruction()))
|
|
return false;
|
|
|
|
// Check that active values are differentiable. Otherwise we may crash
|
|
// later when tangent space is required, but not available.
|
|
if (!getTangentSpace(remapType(type).getASTType())) {
|
|
getContext().emitNondifferentiabilityError(
|
|
v, getInvoker(), diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
|
|
// Record active value.
|
|
bbActiveValues.push_back(v);
|
|
return false;
|
|
};
|
|
// Record all active values in the basic block.
|
|
for (auto *arg : bb->getArguments())
|
|
if (recordValueIfActive(arg))
|
|
return true;
|
|
for (auto &inst : *bb) {
|
|
for (auto op : inst.getOperandValues())
|
|
if (recordValueIfActive(op))
|
|
return true;
|
|
for (auto result : inst.getResults())
|
|
if (recordValueIfActive(result))
|
|
return true;
|
|
}
|
|
domOrder.pushChildren(bb);
|
|
}
|
|
|
|
// Create pullback blocks and arguments, visiting original blocks using BFS
|
|
// starting from the original exit block. Unvisited original basic blocks
|
|
// (e.g unreachable blocks) are not relevant for pullback generation and thus
|
|
// ignored.
|
|
// The original blocks in traversal order for pullback generation.
|
|
SmallVector<SILBasicBlock *, 8> originalBlocks;
|
|
// The workqueue used for bookkeeping during the breadth-first traversal.
|
|
BasicBlockWorkqueue workqueue = {originalExitBlock};
|
|
|
|
// Perform BFS from the original exit block.
|
|
{
|
|
while (auto *BB = workqueue.pop()) {
|
|
originalBlocks.push_back(BB);
|
|
|
|
for (auto *nextBB : BB->getPredecessorBlocks()) {
|
|
// If there is no linear map tuple for predecessor BB, then BB is
|
|
// unreachable from function entry. Do not run pullback cloner on it.
|
|
if (getPullbackInfo().getLinearMapTupleType(nextBB))
|
|
workqueue.pushIfNotVisited(nextBB);
|
|
}
|
|
}
|
|
}
|
|
|
|
for (auto *origBB : originalBlocks) {
|
|
auto *pullbackBB = pullback.createBasicBlock();
|
|
pullbackBBMap.insert({origBB, pullbackBB});
|
|
auto pbTupleLoweredType =
|
|
remapType(getPullbackInfo().getLinearMapTupleLoweredType(origBB));
|
|
// If the BB is the original exit, then the pullback block that we just
|
|
// created must be the pullback function's entry. For the pullback entry,
|
|
// create entry arguments and continue to the next block.
|
|
if (origBB == originalExitBlock) {
|
|
assert(pullbackBB->isEntry());
|
|
createEntryArguments(&pullback);
|
|
auto *origTerm = originalExitBlock->getTerminator();
|
|
builder.setCurrentDebugScope(remapScope(origTerm->getDebugScope()));
|
|
builder.setInsertionPoint(pullbackBB);
|
|
// Obtain the context object, if any, and the top-level subcontext, i.e.
|
|
// the main pullback struct.
|
|
if (getPullbackInfo().hasHeapAllocatedContext()) {
|
|
// The last argument is the context object (`Builtin.NativeObject`).
|
|
contextValue = pullbackBB->getArguments().back();
|
|
assert(contextValue->getType() ==
|
|
SILType::getNativeObjectType(getASTContext()));
|
|
// Load the pullback context.
|
|
auto subcontextAddr = emitProjectTopLevelSubcontext(
|
|
builder, pbLoc, contextValue, pbTupleLoweredType);
|
|
SILValue mainPullbackTuple = builder.createLoad(
|
|
pbLoc, subcontextAddr,
|
|
pbTupleLoweredType.isTrivial(getPullback()) ?
|
|
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy);
|
|
auto *dsi = builder.createDestructureTuple(pbLoc, mainPullbackTuple);
|
|
initializePullbackTupleElements(origBB, dsi->getAllResults());
|
|
} else {
|
|
// Obtain and destructure pullback struct elements.
|
|
unsigned numVals = pbTupleLoweredType.getAs<TupleType>()->getNumElements();
|
|
initializePullbackTupleElements(origBB,
|
|
pullbackBB->getArguments().take_back(numVals));
|
|
}
|
|
|
|
continue;
|
|
}
|
|
|
|
// Get all active values in the original block.
|
|
// If the original block has no active values, continue.
|
|
auto &bbActiveValues = activeValues[origBB];
|
|
if (bbActiveValues.empty())
|
|
continue;
|
|
|
|
// Otherwise, if the original block has active values:
|
|
// - For each active buffer in the original block, allocate a new local
|
|
// buffer in the pullback entry. (All adjoint buffers are allocated in
|
|
// the pullback entry and deallocated in the pullback exit.)
|
|
// - For each active value in the original block, add adjoint value
|
|
// arguments to the pullback block.
|
|
for (auto activeValue : bbActiveValues) {
|
|
// Handle the active value based on its value category.
|
|
switch (getTangentValueCategory(activeValue)) {
|
|
case SILValueCategory::Address: {
|
|
// Allocate and zero initialize a new local buffer using
|
|
// `getAdjointBuffer`.
|
|
builder.setCurrentDebugScope(
|
|
remapScope(originalExitBlock->getTerminator()->getDebugScope()));
|
|
builder.setInsertionPoint(pullback.getEntryBlock());
|
|
getAdjointBuffer(origBB, activeValue);
|
|
break;
|
|
}
|
|
case SILValueCategory::Object: {
|
|
// Create and register pullback block argument for the active value.
|
|
auto *pullbackArg = pullbackBB->createPhiArgument(
|
|
getRemappedTangentType(activeValue->getType()),
|
|
OwnershipKind::Owned);
|
|
activeValuePullbackBBArgumentMap[{origBB, activeValue}] = pullbackArg;
|
|
recordTemporary(pullbackArg);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
// Add a pullback tuple argument.
|
|
auto *pbTupleArg = pullbackBB->createPhiArgument(pbTupleLoweredType,
|
|
OwnershipKind::Owned);
|
|
// Destructure the pullback struct to get the elements.
|
|
builder.setCurrentDebugScope(
|
|
remapScope(origBB->getTerminator()->getDebugScope()));
|
|
builder.setInsertionPoint(pullbackBB);
|
|
auto *dsi = builder.createDestructureTuple(pbLoc, pbTupleArg);
|
|
initializePullbackTupleElements(origBB, dsi->getResults());
|
|
|
|
// - Create pullback trampoline blocks for each successor block of the
|
|
// original block. Pullback trampoline blocks only have a pullback
|
|
// struct argument. They branch from a pullback successor block to the
|
|
// pullback original block, passing adjoint values of active values.
|
|
for (auto *succBB : origBB->getSuccessorBlocks()) {
|
|
// Skip generating pullback block for original unreachable blocks.
|
|
if (!workqueue.isVisited(succBB))
|
|
continue;
|
|
auto *pullbackTrampolineBB = pullback.createBasicBlockBefore(pullbackBB);
|
|
pullbackTrampolineBBMap.insert({{origBB, succBB}, pullbackTrampolineBB});
|
|
// Get the enum element type (i.e. the pullback struct type). The enum
|
|
// element type may be boxed if the enum is indirect.
|
|
auto enumLoweredTy =
|
|
getPullbackInfo().getBranchingTraceEnumLoweredType(succBB);
|
|
auto *enumEltDecl =
|
|
getPullbackInfo().lookUpBranchingTraceEnumElement(origBB, succBB);
|
|
auto enumEltType = remapType(enumLoweredTy.getEnumElementType(
|
|
enumEltDecl, getModule(), TypeExpansionContext::minimal()));
|
|
pullbackTrampolineBB->createPhiArgument(enumEltType,
|
|
OwnershipKind::Owned);
|
|
}
|
|
}
|
|
|
|
auto *pullbackEntry = pullback.getEntryBlock();
|
|
auto pbTupleLoweredType =
|
|
remapType(getPullbackInfo().getLinearMapTupleLoweredType(originalExitBlock));
|
|
unsigned numVals = (getPullbackInfo().hasHeapAllocatedContext() ?
|
|
1 : pbTupleLoweredType.getAs<TupleType>()->getNumElements());
|
|
(void)numVals;
|
|
|
|
// The pullback function has type:
|
|
// `(seed0, seed1, ..., (exit_pb_tuple_el0, ..., )|context_obj) -> (d_arg0, ..., d_argn)`.
|
|
auto conv = getOriginal().getConventions();
|
|
auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults();
|
|
assert(getConfig().resultIndices->getNumIndices() - conv.getNumYields() == pbParamArgs.size() - numVals &&
|
|
pbParamArgs.size() >= 1);
|
|
// Assign adjoints for original result.
|
|
builder.setCurrentDebugScope(
|
|
remapScope(originalExitBlock->getTerminator()->getDebugScope()));
|
|
builder.setInsertionPoint(pullbackEntry,
|
|
getNextFunctionLocalAllocationInsertionPoint());
|
|
unsigned seedIndex = 0;
|
|
unsigned firstSemanticParamResultIdx = conv.getResults().size();
|
|
unsigned firstYieldResultIndex = firstSemanticParamResultIdx +
|
|
conv.getNumAutoDiffSemanticResultParameters();
|
|
for (auto resultIndex : getConfig().resultIndices->getIndices()) {
|
|
// Yields seed buffers are only to be touched in yield BB and required
|
|
// special handling
|
|
if (resultIndex >= firstYieldResultIndex)
|
|
continue;
|
|
|
|
auto origResult = origFormalResults[resultIndex];
|
|
auto *seed = pbParamArgs[seedIndex];
|
|
if (seed->getType().isAddress()) {
|
|
// If the seed argument is an `inout` parameter, assign it directly as
|
|
// the adjoint buffer of the original result.
|
|
auto seedParamInfo =
|
|
pullback.getLoweredFunctionType()->getParameters()[seedIndex];
|
|
|
|
if (seedParamInfo.isIndirectInOut()) {
|
|
setAdjointBuffer(originalExitBlock, origResult, seed);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Assigned seed buffer " << *seed
|
|
<< " as the adjoint of original indirect result "
|
|
<< origResult);
|
|
}
|
|
// Otherwise, assign a copy of the seed argument as the adjoint buffer of
|
|
// the original result.
|
|
else {
|
|
auto *seedBufCopy =
|
|
createFunctionLocalAllocation(seed->getType(), pbLoc);
|
|
builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake,
|
|
IsInitialization);
|
|
setAdjointBuffer(originalExitBlock, origResult, seedBufCopy);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Assigned seed buffer " << *seedBufCopy
|
|
<< " as the adjoint of original indirect result "
|
|
<< origResult);
|
|
}
|
|
} else {
|
|
addAdjointValue(originalExitBlock, origResult, makeConcreteAdjointValue(seed),
|
|
pbLoc);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Assigned seed " << *seed
|
|
<< " as the adjoint of original result " << origResult);
|
|
}
|
|
++seedIndex;
|
|
}
|
|
|
|
// If the original function is an accessor with special-case pullback
|
|
// generation logic, do special-case generation.
|
|
if (isSemanticMemberAccessor(&original)) {
|
|
if (runForSemanticMemberAccessor())
|
|
return true;
|
|
}
|
|
// Otherwise, perform standard pullback generation.
|
|
// Visit original blocks in post-order and perform differentiation
|
|
// in corresponding pullback blocks. If errors occurred, back out.
|
|
else {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Begin search for adjoints of loop-local active values\n");
|
|
llvm::DenseMap<const SILLoop *, llvm::DenseSet<SILValue>>
|
|
loopLocalActiveValues;
|
|
for (auto *bb : originalBlocks) {
|
|
const SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb);
|
|
if (loop == nullptr)
|
|
continue;
|
|
SILBasicBlock *loopHeader = loop->getHeader();
|
|
SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Original bb" << bb->getDebugID()
|
|
<< " belongs to a loop, original header bb"
|
|
<< loopHeader->getDebugID() << ", pullback header bb"
|
|
<< pbLoopHeader->getDebugID() << '\n');
|
|
builder.setInsertionPoint(pbLoopHeader);
|
|
auto bbActiveValuesIt = activeValues.find(bb);
|
|
if (bbActiveValuesIt == activeValues.end())
|
|
continue;
|
|
const auto &bbActiveValues = bbActiveValuesIt->second;
|
|
for (SILValue bbActiveValue : bbActiveValues) {
|
|
if (vjpCloner.getLoopInfo()->getLoopFor(
|
|
bbActiveValue->getParentBlock()) != loop) {
|
|
LLVM_DEBUG(
|
|
getADDebugStream()
|
|
<< "The following active value is NOT loop-local, skipping: "
|
|
<< bbActiveValue);
|
|
continue;
|
|
}
|
|
|
|
auto [_, wasInserted] =
|
|
loopLocalActiveValues[loop].insert(bbActiveValue);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "The following active value is loop-local, ");
|
|
if (!wasInserted) {
|
|
LLVM_DEBUG(llvm::dbgs() << "but it was already processed, skipping: "
|
|
<< bbActiveValue);
|
|
continue;
|
|
}
|
|
|
|
if (getTangentValueCategory(bbActiveValue) ==
|
|
SILValueCategory::Object) {
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< "zeroing its adjoint value in loop header: "
|
|
<< bbActiveValue);
|
|
setAdjointValue(bb, bbActiveValue,
|
|
makeZeroAdjointValue(getRemappedTangentType(
|
|
bbActiveValue->getType())));
|
|
continue;
|
|
}
|
|
|
|
ASSERT(getTangentValueCategory(bbActiveValue) ==
|
|
SILValueCategory::Address);
|
|
|
|
// getAdjointProjection might call materializeAdjointDirect which
|
|
// writes to debug output, emit \n.
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< "checking if it's adjoint is a projection\n");
|
|
|
|
if (!getAdjointProjection(bb, bbActiveValue)) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Adjoint for the following value is NOT a projection, "
|
|
"zeroing its adjoint buffer in loop header: "
|
|
<< bbActiveValue);
|
|
|
|
// All adjoint buffers are allocated in the pullback entry and
|
|
// deallocated in the pullback exit. So, use IsNotInitialization to
|
|
// emit destroy_addr before zeroing the buffer.
|
|
ASSERT(bufferMap.contains({bb, bbActiveValue}));
|
|
builder.emitZeroIntoBuffer(pbLoc, getAdjointBuffer(bb, bbActiveValue),
|
|
IsNotInitialization);
|
|
|
|
continue;
|
|
}
|
|
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Adjoint for the following value is a projection, ");
|
|
|
|
// If Projection::isAddressProjection(v) is true for a value v, it
|
|
// is not added to active values list (see recordValueIfActive).
|
|
//
|
|
// Ensure that only the following value types conforming to
|
|
// getAdjointProjection but not conforming to
|
|
// Projection::isAddressProjection can go here.
|
|
//
|
|
// Instructions conforming to Projection::isAddressProjection and
|
|
// thus never corresponding to an active value do not need any
|
|
// handling, because only active values can have adjoints from
|
|
// previous iterations propagated via BB arguments.
|
|
do {
|
|
// Consider '%X = begin_access [modify] [static] %Y'.
|
|
// 1. If %Y is loop-local, it's adjoint buffer will
|
|
// be zeroed, and we'll have zero adjoint projection to it.
|
|
// 2. Otherwise, we do not need to zero the projection buffer.
|
|
// Thus, we can just skip.
|
|
if (dyn_cast<BeginAccessInst>(bbActiveValue)) {
|
|
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
|
|
break;
|
|
}
|
|
|
|
// Consider the following sequence:
|
|
// %1 = function_ref @allocUninitArray
|
|
// %2 = apply %1<Float>(%0)
|
|
// (%3, %4) = destructure_tuple %2
|
|
// %5 = mark_dependence %4 on %3
|
|
// %6 = pointer_to_address %6 to [strict] $*Float
|
|
// Since %6 is active, %3 (which is an array) must also be active.
|
|
// Thus, adjoint for %3 will be zeroed if needed. Ensure that expected
|
|
// invariants hold and then skip.
|
|
if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress(
|
|
bbActiveValue)) {
|
|
ASSERT(isa<PointerToAddressInst>(bbActiveValue));
|
|
SILValue arrayValue = getArrayValue(ai);
|
|
ASSERT(llvm::find(bbActiveValues, arrayValue) !=
|
|
bbActiveValues.end());
|
|
ASSERT(vjpCloner.getLoopInfo()->getLoopFor(
|
|
arrayValue->getParentBlock()) == loop);
|
|
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
|
|
break;
|
|
}
|
|
|
|
ASSERT(false);
|
|
} while (false);
|
|
}
|
|
}
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "End search for adjoints of loop-local active values\n");
|
|
|
|
for (auto *bb : originalBlocks) {
|
|
visitSILBasicBlock(bb);
|
|
if (errorOccurred)
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// Prepare and emit a `return` in the pullback exit block.
|
|
auto *origEntry = getOriginal().getEntryBlock();
|
|
auto *pbExit = getPullbackBlock(origEntry);
|
|
builder.setCurrentDebugScope(pbExit->back().getDebugScope());
|
|
builder.setInsertionPoint(pbExit);
|
|
|
|
// This vector will contain all the materialized return elements.
|
|
SmallVector<SILValue, 8> retElts;
|
|
// This vector will contain all indirect parameter adjoint buffers.
|
|
SmallVector<SILValue, 4> indParamAdjoints;
|
|
// This vector will identify the locations where initialization is needed.
|
|
SmallBitVector outputsToInitialize;
|
|
|
|
auto origParams = getOriginal().getArgumentsWithoutIndirectResults();
|
|
|
|
// Materializes the return element corresponding to the parameter
|
|
// `parameterIndex` into the `retElts` vector.
|
|
auto addRetElt = [&](unsigned parameterIndex) -> void {
|
|
auto origParam = origParams[parameterIndex];
|
|
switch (getTangentValueCategory(origParam)) {
|
|
case SILValueCategory::Object: {
|
|
auto pbVal = getAdjointValue(origEntry, origParam);
|
|
auto val = materializeAdjointDirect(pbVal, pbLoc);
|
|
auto newVal = builder.emitCopyValueOperation(pbLoc, val);
|
|
retElts.push_back(newVal);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBuf = getAdjointBuffer(origEntry, origParam);
|
|
indParamAdjoints.push_back(adjBuf);
|
|
outputsToInitialize.push_back(
|
|
!conv.getParameters()[parameterIndex].isIndirectMutating());
|
|
break;
|
|
}
|
|
}
|
|
};
|
|
SmallVector<SILArgument *, 4> pullbackIndirectResults(
|
|
getPullback().getIndirectResults().begin(),
|
|
getPullback().getIndirectResults().end());
|
|
|
|
// Collect differentiation parameter adjoints.
|
|
// Do a first pass to collect non-inout values.
|
|
for (auto i : getConfig().parameterIndices->getIndices()) {
|
|
if (!conv.getParameters()[i].isAutoDiffSemanticResult()) {
|
|
addRetElt(i);
|
|
}
|
|
}
|
|
|
|
// Do a second pass for all inout parameters, however this is only necessary
|
|
// for functions with multiple basic blocks. For functions with a single
|
|
// basic block adjoint accumulation for those parameters is already done by
|
|
// per-instruction visitors.
|
|
if (getOriginal().size() > 1) {
|
|
const auto &pullbackConv = pullback.getConventions();
|
|
SmallVector<SILArgument *, 1> pullbackInOutArgs;
|
|
for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults())) {
|
|
if (pullbackConv.getParameters()[pullbackArg.index()].isAutoDiffSemanticResult())
|
|
pullbackInOutArgs.push_back(pullbackArg.value());
|
|
}
|
|
|
|
unsigned pullbackInoutArgumentIdx = 0;
|
|
for (auto i : getConfig().parameterIndices->getIndices()) {
|
|
// Skip non-inout parameters.
|
|
if (!conv.getParameters()[i].isAutoDiffSemanticResult())
|
|
continue;
|
|
|
|
// For functions with multiple basic blocks, accumulation is needed
|
|
// for `inout` parameters because pullback basic blocks have different
|
|
// adjoint buffers.
|
|
pullbackIndirectResults.push_back(pullbackInOutArgs[pullbackInoutArgumentIdx++]);
|
|
addRetElt(i);
|
|
}
|
|
}
|
|
|
|
// Copy them to adjoint indirect results.
|
|
assert(indParamAdjoints.size() == pullbackIndirectResults.size() &&
|
|
"Indirect parameter adjoint count mismatch");
|
|
unsigned currentIndex = 0;
|
|
for (auto pair : zip(indParamAdjoints, pullbackIndirectResults)) {
|
|
auto source = std::get<0>(pair);
|
|
auto *dest = std::get<1>(pair);
|
|
if (outputsToInitialize[currentIndex]) {
|
|
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
|
|
} else {
|
|
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsNotInitialization);
|
|
}
|
|
currentIndex++;
|
|
// Prevent source buffer from being deallocated, since the underlying
|
|
// value is moved.
|
|
destroyedLocalAllocations.insert(source);
|
|
}
|
|
|
|
// Emit cleanups for all local values.
|
|
cleanUpTemporariesForBlock(pbExit, pbLoc);
|
|
// Deallocate local allocations.
|
|
for (auto alloc : functionLocalAllocations) {
|
|
// Assert that local allocations have at least one use.
|
|
// Buffers should not be allocated needlessly.
|
|
assert(!alloc->use_empty());
|
|
if (!destroyedLocalAllocations.count(alloc)) {
|
|
builder.emitDestroyAddrAndFold(pbLoc, alloc);
|
|
destroyedLocalAllocations.insert(alloc);
|
|
}
|
|
builder.createDeallocStack(pbLoc, alloc);
|
|
}
|
|
builder.createReturn(pbLoc, joinElements(retElts, builder, pbLoc));
|
|
|
|
#ifndef NDEBUG
|
|
bool leakFound = false;
|
|
// Ensure all temporaries have been cleaned up.
|
|
for (auto &bb : pullback) {
|
|
for (auto temp : blockTemporaries[&bb]) {
|
|
if (blockTemporaries[&bb].count(temp)) {
|
|
leakFound = true;
|
|
getADDebugStream() << "Found leaked temporary:\n" << temp;
|
|
}
|
|
}
|
|
}
|
|
// Ensure all enum adjoint copeis have been cleaned up
|
|
for (const auto &enumData : enumDataAdjCopies) {
|
|
leakFound = true;
|
|
getADDebugStream() << "Found leaked temporary:\n" << enumData.second;
|
|
}
|
|
|
|
// Ensure all local allocations have been cleaned up.
|
|
for (auto localAlloc : functionLocalAllocations) {
|
|
if (!destroyedLocalAllocations.count(localAlloc)) {
|
|
leakFound = true;
|
|
getADDebugStream() << "Found leaked local buffer:\n" << localAlloc;
|
|
}
|
|
}
|
|
assert(!leakFound && "Leaks found!");
|
|
#endif
|
|
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Generated pullback for " << original.getName() << ":\n"
|
|
<< pullback);
|
|
return errorOccurred;
|
|
}
|
|
|
|
void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult(
|
|
SILValue origNonvariedResult) {
|
|
auto &pullback = getPullback();
|
|
auto pbLoc = getPullback().getLocation();
|
|
/*
|
|
// TODO(TF-788): Re-enable non-varied result warning.
|
|
// Emit fixit if original non-varied result has a valid source location.
|
|
auto startLoc = origNonvariedResult.getLoc().getStartSourceLoc();
|
|
auto endLoc = origNonvariedResult.getLoc().getEndSourceLoc();
|
|
if (startLoc.isValid() && endLoc.isValid()) {
|
|
getContext().diagnose(startLoc, diag::autodiff_nonvaried_result_fixit)
|
|
.fixItInsert(startLoc, "withoutDerivative(at:")
|
|
.fixItInsertAfter(endLoc, ")");
|
|
}
|
|
*/
|
|
LLVM_DEBUG(getADDebugStream() << getOriginal().getName()
|
|
<< " has non-varied result, returning zero"
|
|
" for all pullback results\n");
|
|
auto *pullbackEntry = pullback.createBasicBlock();
|
|
createEntryArguments(&pullback);
|
|
builder.setCurrentDebugScope(
|
|
remapScope(originalExitBlock->getTerminator()->getDebugScope()));
|
|
builder.setInsertionPoint(pullbackEntry);
|
|
// Destroy all owned arguments.
|
|
for (auto *arg : pullbackEntry->getArguments())
|
|
if (arg->getOwnershipKind() == OwnershipKind::Owned)
|
|
builder.emitDestroyOperation(pbLoc, arg);
|
|
// Return zero for each result.
|
|
SmallVector<SILValue, 4> directResults;
|
|
auto indirectResultIt = pullback.getIndirectResults().begin();
|
|
for (auto resultInfo : pullback.getLoweredFunctionType()->getResults()) {
|
|
auto resultType =
|
|
pullback.mapTypeIntoContext(resultInfo.getInterfaceType())
|
|
->getCanonicalType();
|
|
if (resultInfo.isFormalDirect())
|
|
directResults.push_back(builder.emitZero(pbLoc, resultType));
|
|
else
|
|
builder.emitZeroIntoBuffer(pbLoc, *indirectResultIt++, IsInitialization);
|
|
}
|
|
builder.createReturn(pbLoc, joinElements(directResults, builder, pbLoc));
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Generated pullback for " << getOriginal().getName() << ":\n"
|
|
<< pullback);
|
|
}
|
|
|
|
AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint(
|
|
SILBasicBlock *bb, SILValue wrappedAdjoint, SILType optionalTy) {
|
|
auto pbLoc = getPullback().getLocation();
|
|
// `Optional<T>`
|
|
optionalTy = remapType(optionalTy);
|
|
assert(optionalTy.getASTType()->isOptional());
|
|
// `T`
|
|
auto wrappedType = optionalTy.getOptionalObjectType();
|
|
// `T.TangentVector`
|
|
auto wrappedTanType = remapType(wrappedAdjoint->getType());
|
|
// `Optional<T.TangentVector>`
|
|
auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType);
|
|
// `Optional<T>.TangentVector`
|
|
auto optionalTanTy = getRemappedTangentType(optionalTy);
|
|
// Look up the `Optional<T>.TangentVector.init` declaration.
|
|
ConstructorDecl *constructorDecl =
|
|
getASTContext().getOptionalTanInitDecl(optionalTanTy.getASTType());
|
|
|
|
// Allocate a local buffer for the `Optional` adjoint value.
|
|
auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy);
|
|
// Find `Optional<T.TangentVector>.some` EnumElementDecl.
|
|
auto someEltDecl = builder.getASTContext().getOptionalSomeDecl();
|
|
|
|
// Initialize an `Optional<T.TangentVector>` buffer from `wrappedAdjoint` as
|
|
// the input for `Optional<T>.TangentVector.init`.
|
|
auto *optArgBuf = builder.createAllocStack(pbLoc, optionalOfWrappedTanType);
|
|
if (optionalOfWrappedTanType.isObject()) {
|
|
// %enum = enum $Optional<T.TangentVector>, #Optional.some!enumelt,
|
|
// %wrappedAdjoint : $T
|
|
auto *enumInst = builder.createEnum(pbLoc, wrappedAdjoint, someEltDecl,
|
|
optionalOfWrappedTanType);
|
|
// store %enum to %optArgBuf
|
|
builder.emitStoreValueOperation(pbLoc, enumInst, optArgBuf,
|
|
StoreOwnershipQualifier::Init);
|
|
} else {
|
|
// %enumAddr = init_enum_data_addr %optArgBuf $Optional<T.TangentVector>,
|
|
// #Optional.some!enumelt
|
|
auto *enumAddr = builder.createInitEnumDataAddr(
|
|
pbLoc, optArgBuf, someEltDecl, wrappedTanType.getAddressType());
|
|
// copy_addr %wrappedAdjoint to [init] %enumAddr
|
|
builder.createCopyAddr(pbLoc, wrappedAdjoint, enumAddr, IsNotTake,
|
|
IsInitialization);
|
|
// inject_enum_addr %optArgBuf : $*Optional<T.TangentVector>,
|
|
// #Optional.some!enumelt
|
|
builder.createInjectEnumAddr(pbLoc, optArgBuf, someEltDecl);
|
|
}
|
|
|
|
// Apply `Optional<T>.TangentVector.init`.
|
|
SILOptFunctionBuilder fb(getContext().getTransform());
|
|
// %init_fn = function_ref @Optional<T>.TangentVector.init
|
|
auto *initFn = fb.getOrCreateFunction(pbLoc, SILDeclRef(constructorDecl),
|
|
NotForDefinition);
|
|
auto *initFnRef = builder.createFunctionRef(pbLoc, initFn);
|
|
auto *diffProto =
|
|
builder.getASTContext().getProtocol(KnownProtocolKind::Differentiable);
|
|
auto diffConf =
|
|
lookupConformance(wrappedType.getASTType(), diffProto);
|
|
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
|
|
auto subMap = SubstitutionMap::get(
|
|
initFn->getLoweredFunctionType()->getSubstGenericSignature(),
|
|
ArrayRef<Type>(wrappedType.getASTType()), {diffConf});
|
|
// %metatype = metatype $Optional<T>.TangentVector.Type
|
|
auto metatypeType = CanMetatypeType::get(optionalTanTy.getASTType(),
|
|
MetatypeRepresentation::Thin);
|
|
auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType);
|
|
auto metatype = builder.createMetatype(pbLoc, metatypeSILType);
|
|
// apply %init_fn(%optTanAdjBuf, %optArgBuf, %metatype)
|
|
builder.createApply(pbLoc, initFnRef, subMap,
|
|
{optTanAdjBuf, optArgBuf, metatype});
|
|
builder.createDeallocStack(pbLoc, optArgBuf);
|
|
return optTanAdjBuf;
|
|
}
|
|
|
|
// Accumulate adjoint for the incoming `Optional` buffer.
|
|
void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer(
|
|
SILBasicBlock *bb, SILValue optionalBuffer, SILValue wrappedAdjoint) {
|
|
assert(getTangentValueCategory(optionalBuffer) == SILValueCategory::Address);
|
|
auto pbLoc = getPullback().getLocation();
|
|
|
|
// Allocate and initialize Optional<Wrapped>.TangentVector from
|
|
// Wrapped.TangentVector
|
|
AllocStackInst *optTanAdjBuf =
|
|
createOptionalAdjoint(bb, wrappedAdjoint, optionalBuffer->getType());
|
|
|
|
// Accumulate into optionalBuffer
|
|
addToAdjointBuffer(bb, optionalBuffer, optTanAdjBuf, pbLoc);
|
|
builder.emitDestroyAddr(pbLoc, optTanAdjBuf);
|
|
builder.createDeallocStack(pbLoc, optTanAdjBuf);
|
|
}
|
|
|
|
// Accumulate adjoint for the incoming `Optional` value.
|
|
void PullbackCloner::Implementation::accumulateAdjointValueForOptional(
|
|
SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
|
|
assert(getTangentValueCategory(optionalValue) == SILValueCategory::Object);
|
|
auto pbLoc = getPullback().getLocation();
|
|
|
|
// Allocate and initialize Optional<Wrapped>.TangentVector from
|
|
// Wrapped.TangentVector
|
|
AllocStackInst *optTanAdjBuf =
|
|
createOptionalAdjoint(bb, wrappedAdjoint, optionalValue->getType());
|
|
|
|
auto optTanAdjVal = builder.emitLoadValueOperation(
|
|
pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
|
|
|
|
recordTemporary(optTanAdjVal);
|
|
builder.createDeallocStack(pbLoc, optTanAdjBuf);
|
|
|
|
addAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal), pbLoc);
|
|
}
|
|
|
|
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
|
|
SILBasicBlock *origBB, SILBasicBlock *origPredBB,
|
|
SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) {
|
|
// Get the pullback block and optional pullback trampoline block of the
|
|
// predecessor block.
|
|
auto *pullbackBB = getPullbackBlock(origPredBB);
|
|
auto *pullbackTrampolineBB = getPullbackTrampolineBlock(origPredBB, origBB);
|
|
// If the predecessor block does not have a corresponding pullback
|
|
// trampoline block, then the pullback successor is the pullback block.
|
|
if (!pullbackTrampolineBB)
|
|
return pullbackBB;
|
|
|
|
// Otherwise, the pullback successor is the pullback trampoline block,
|
|
// which branches to the pullback block and propagates adjoint values of
|
|
// active values.
|
|
assert(pullbackTrampolineBB->getNumArguments() == 1);
|
|
auto loc = origBB->getParent()->getLocation();
|
|
SmallVector<SILValue, 8> trampolineArguments;
|
|
|
|
// Propagate adjoint values/buffers of active values/buffers to
|
|
// predecessor blocks.
|
|
auto &predBBActiveValues = activeValues[origPredBB];
|
|
llvm::SmallSet<std::pair<SILValue, SILValue>, 32> propagatedAdjoints;
|
|
for (auto activeValue : predBBActiveValues) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Propagating adjoint of active value " << activeValue
|
|
<< "from bb" << origBB->getDebugID()
|
|
<< " to predecessors' (bb" << origPredBB->getDebugID()
|
|
<< ") pullback blocks\n");
|
|
switch (getTangentValueCategory(activeValue)) {
|
|
case SILValueCategory::Object: {
|
|
auto activeValueAdj = getAdjointValue(origBB, activeValue);
|
|
auto concreteActiveValueAdj =
|
|
materializeAdjointDirect(activeValueAdj, loc);
|
|
|
|
if (!pullbackTrampolineBlockMap.count(concreteActiveValueAdj)) {
|
|
concreteActiveValueAdj =
|
|
builder.emitCopyValueOperation(loc, concreteActiveValueAdj);
|
|
setAdjointValue(origBB, activeValue,
|
|
makeConcreteAdjointValue(concreteActiveValueAdj));
|
|
}
|
|
auto insertion = pullbackTrampolineBlockMap.try_emplace(
|
|
concreteActiveValueAdj, TrampolineBlockSet());
|
|
auto &blockSet = insertion.first->getSecond();
|
|
blockSet.insert(pullbackTrampolineBB);
|
|
trampolineArguments.push_back(concreteActiveValueAdj);
|
|
|
|
// If the pullback block does not yet have a registered adjoint
|
|
// value for the active value, set the adjoint value to the
|
|
// forwarded adjoint value argument.
|
|
// TODO: Hoist this logic out of loop over predecessor blocks to
|
|
// remove the `hasAdjointValue` check.
|
|
if (!hasAdjointValue(origPredBB, activeValue)) {
|
|
auto *pullbackBBArg =
|
|
getActiveValuePullbackBlockArgument(origPredBB, activeValue);
|
|
auto forwardedArgAdj = makeConcreteAdjointValue(pullbackBBArg);
|
|
setAdjointValue(origPredBB, activeValue, forwardedArgAdj);
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
// Propagate adjoint buffers using `copy_addr`.
|
|
auto adjBuf = getAdjointBuffer(origBB, activeValue);
|
|
auto predAdjBuf = getAdjointBuffer(origPredBB, activeValue);
|
|
if (propagatedAdjoints.insert({adjBuf, predAdjBuf}).second)
|
|
builder.createCopyAddr(loc, adjBuf, predAdjBuf, IsNotTake,
|
|
IsNotInitialization);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Propagate pullback struct argument.
|
|
TangentBuilder pullbackTrampolineBBBuilder(
|
|
pullbackTrampolineBB, getContext());
|
|
pullbackTrampolineBBBuilder.setCurrentDebugScope(
|
|
remapScope(origPredBB->getTerminator()->getDebugScope()));
|
|
|
|
auto *pullbackTrampolineBBArg = pullbackTrampolineBB->getArguments().front();
|
|
if (vjpCloner.getLoopInfo()->getLoopFor(origPredBB)) {
|
|
assert(pullbackTrampolineBBArg->getType() ==
|
|
SILType::getRawPointerType(getASTContext()));
|
|
auto pbTupleType =
|
|
remapType(getPullbackInfo().getLinearMapTupleLoweredType(origPredBB));
|
|
auto predPbTupleAddr = pullbackTrampolineBBBuilder.createPointerToAddress(
|
|
loc, pullbackTrampolineBBArg, pbTupleType.getAddressType(),
|
|
/*isStrict*/ true);
|
|
auto predPbStructVal = pullbackTrampolineBBBuilder.createLoad(
|
|
loc, predPbTupleAddr,
|
|
pbTupleType.isTrivial(getPullback()) ?
|
|
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy);
|
|
trampolineArguments.push_back(predPbStructVal);
|
|
} else {
|
|
trampolineArguments.push_back(pullbackTrampolineBBArg);
|
|
}
|
|
// Branch from pullback trampoline block to pullback block.
|
|
pullbackTrampolineBBBuilder.createBranch(loc, pullbackBB,
|
|
trampolineArguments);
|
|
return pullbackTrampolineBB;
|
|
}
|
|
|
|
void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
|
|
auto pbLoc = getPullback().getLocation();
|
|
// Get the corresponding pullback basic block.
|
|
auto *pbBB = getPullbackBlock(bb);
|
|
builder.setInsertionPoint(pbBB);
|
|
|
|
LLVM_DEBUG({
|
|
auto &s = getADDebugStream()
|
|
<< "Original bb" + std::to_string(bb->getDebugID())
|
|
<< ": To differentiate or not to differentiate?\n";
|
|
for (auto &inst : llvm::reverse(*bb)) {
|
|
s << (getPullbackInfo().shouldDifferentiateInstruction(&inst) ? "[x] "
|
|
: "[ ] ")
|
|
<< inst;
|
|
}
|
|
});
|
|
|
|
// Visit each instruction in reverse order.
|
|
for (auto &inst : llvm::reverse(*bb)) {
|
|
if (!getPullbackInfo().shouldDifferentiateInstruction(&inst))
|
|
continue;
|
|
// Differentiate instruction.
|
|
builder.setCurrentDebugScope(remapScope(inst.getDebugScope()));
|
|
visit(&inst);
|
|
if (errorOccurred)
|
|
return;
|
|
}
|
|
|
|
// Emit a branching terminator for the block.
|
|
// If the original block is the original entry, then the pullback block is
|
|
// the pullback exit. This is handled specially in
|
|
// `PullbackCloner::Implementation::run()`, so we leave the block
|
|
// non-terminated.
|
|
if (bb->isEntry())
|
|
return;
|
|
|
|
// If the original block is a resume yield destination, then we need to yield
|
|
// the adjoint buffer and do everything else in the resume destination. Unwind
|
|
// destination is unreachable as the co-routine can never be aborted.
|
|
if (auto *predBB = bb->getSinglePredecessorBlock()) {
|
|
if (auto *yield = dyn_cast<YieldInst>(predBB->getTerminator())) {
|
|
auto *resumeBB = pbBB->split(builder.getInsertionPoint());
|
|
auto *unwindBB = getPullback().createBasicBlock();
|
|
|
|
SmallVector<SILValue, 1> adjYields;
|
|
for (auto yieldedVal : yield->getYieldedValues())
|
|
adjYields.push_back(getAdjointBuffer(bb, yieldedVal));
|
|
|
|
builder.createYield(yield->getLoc(), adjYields, resumeBB, unwindBB);
|
|
builder.setInsertionPoint(unwindBB);
|
|
builder.createUnreachable(SILLocation::invalid());
|
|
|
|
pbBB = resumeBB;
|
|
builder.setInsertionPoint(pbBB);
|
|
}
|
|
}
|
|
|
|
// Otherwise, add a `switch_enum` terminator for non-exit
|
|
// pullback blocks.
|
|
// 1. Get the pullback struct pullback block argument.
|
|
// 2. Extract the predecessor enum value from the pullback struct value.
|
|
auto *predEnum = getPullbackInfo().getBranchingTraceDecl(bb);
|
|
(void)predEnum;
|
|
auto predEnumVal = getPullbackPredTupleElement(bb);
|
|
|
|
// Propagate adjoint values from active basic block arguments to
|
|
// incoming values (predecessor terminator operands).
|
|
for (auto *bbArg : bb->getArguments()) {
|
|
if (!getActivityInfo().isActive(bbArg, getConfig()))
|
|
continue;
|
|
LLVM_DEBUG(getADDebugStream() << "Propagating adjoint value for active bb"
|
|
<< bb->getDebugID() << " argument: "
|
|
<< *bbArg);
|
|
|
|
// Get predecessor terminator operands.
|
|
SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues;
|
|
if (bbArg->getSingleTerminatorOperands(incomingValues)) {
|
|
// Returns true if the given terminator instruction is a `switch_enum` on
|
|
// an `Optional`-typed value. `switch_enum` instructions require
|
|
// special-case adjoint value propagation for the operand.
|
|
auto isSwitchEnumInstOnOptional =
|
|
[&ctx = getASTContext()](TermInst *termInst) {
|
|
if (!termInst)
|
|
return false;
|
|
if (auto *sei = dyn_cast<SwitchEnumInst>(termInst)) {
|
|
auto operandTy = sei->getOperand()->getType();
|
|
return operandTy.getASTType()->isOptional();
|
|
}
|
|
return false;
|
|
};
|
|
|
|
// Check the tangent value category of the active basic block argument.
|
|
switch (getTangentValueCategory(bbArg)) {
|
|
// If argument has a loadable tangent value category: materialize adjoint
|
|
// value of the argument, create a copy, and set the copy as the adjoint
|
|
// value of incoming values.
|
|
case SILValueCategory::Object: {
|
|
auto bbArgAdj = getAdjointValue(bb, bbArg);
|
|
auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc);
|
|
auto concreteBBArgAdjCopy =
|
|
builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj);
|
|
for (auto pair : incomingValues) {
|
|
auto *predBB = std::get<0>(pair);
|
|
auto incomingValue = std::get<1>(pair);
|
|
// Handle `switch_enum` on `Optional`.
|
|
auto termInst = bbArg->getSingleTerminator();
|
|
if (isSwitchEnumInstOnOptional(termInst)) {
|
|
accumulateAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy);
|
|
} else {
|
|
blockTemporaries[getPullbackBlock(predBB)].insert(
|
|
concreteBBArgAdjCopy);
|
|
addAdjointValue(predBB, incomingValue,
|
|
makeConcreteAdjointValue(concreteBBArgAdjCopy), pbLoc);
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
// If argument has an address tangent value category: materialize adjoint
|
|
// value of the argument, create a copy, and set the copy as the adjoint
|
|
// value of incoming values.
|
|
case SILValueCategory::Address: {
|
|
auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg);
|
|
for (auto pair : incomingValues) {
|
|
auto incomingValue = std::get<1>(pair);
|
|
// Handle `switch_enum` on `Optional`.
|
|
auto termInst = bbArg->getSingleTerminator();
|
|
if (isSwitchEnumInstOnOptional(termInst))
|
|
accumulateAdjointForOptionalBuffer(bb, incomingValue, bbArgAdjBuf);
|
|
else
|
|
addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc);
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
} else {
|
|
LLVM_DEBUG(getADDebugStream() <<
|
|
"do not know how to handle this incoming bb argument");
|
|
if (auto term = bbArg->getSingleTerminator()) {
|
|
getContext().emitNondifferentiabilityError(term, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
} else {
|
|
// This will be a bit confusing, but still better than nothing.
|
|
getContext().emitNondifferentiabilityError(bbArg, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
}
|
|
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
}
|
|
|
|
// 3. Build the pullback successor cases for the `switch_enum`
|
|
// instruction. The pullback successors correspond to the predecessors
|
|
// of the current original block.
|
|
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4>
|
|
pullbackSuccessorCases;
|
|
// A map from active values' adjoint values to the trampoline blocks that
|
|
// are using them.
|
|
SmallDenseMap<SILValue, TrampolineBlockSet> pullbackTrampolineBlockMap;
|
|
SmallDenseMap<SILBasicBlock *, SILBasicBlock *> origPredpullbackSuccBBMap;
|
|
for (auto *predBB : bb->getPredecessorBlocks()) {
|
|
// If there is no linear map tuple for predecessor BB, then BB is
|
|
// unreachable from function entry. There is no branch tracing enum for it
|
|
// as well, so we should not create any branching to it in the pullback.
|
|
if (!getPullbackInfo().getLinearMapTupleType(predBB))
|
|
continue;
|
|
auto *pullbackSuccBB =
|
|
buildPullbackSuccessor(bb, predBB, pullbackTrampolineBlockMap);
|
|
origPredpullbackSuccBBMap[predBB] = pullbackSuccBB;
|
|
auto *enumEltDecl =
|
|
getPullbackInfo().lookUpBranchingTraceEnumElement(predBB, bb);
|
|
pullbackSuccessorCases.emplace_back(enumEltDecl, pullbackSuccBB);
|
|
}
|
|
// Values are trampolined by only a subset of pullback successor blocks.
|
|
// Other successors blocks should destroy the value.
|
|
for (auto pair : pullbackTrampolineBlockMap) {
|
|
auto value = pair.getFirst();
|
|
// The set of trampoline BBs that are users of `value`.
|
|
auto &userTrampolineBBSet = pair.getSecond();
|
|
// For each pullback successor block that does not trampoline the value,
|
|
// release the value.
|
|
for (auto origPredPbSuccPair : origPredpullbackSuccBBMap) {
|
|
auto *origPred = origPredPbSuccPair.getFirst();
|
|
auto *pbSucc = origPredPbSuccPair.getSecond();
|
|
if (userTrampolineBBSet.count(pbSucc))
|
|
continue;
|
|
TangentBuilder pullbackSuccBuilder(pbSucc->begin(), getContext());
|
|
pullbackSuccBuilder.setCurrentDebugScope(
|
|
remapScope(origPred->getTerminator()->getDebugScope()));
|
|
pullbackSuccBuilder.emitDestroyValueOperation(pbLoc, value);
|
|
}
|
|
}
|
|
// Emit cleanups for all block-local temporaries.
|
|
cleanUpTemporariesForBlock(pbBB, pbLoc);
|
|
// Branch to pullback successor blocks.
|
|
assert(pullbackSuccessorCases.size() == predEnum->getNumElements());
|
|
builder.createSwitchEnum(pbLoc, predEnumVal, /*DefaultBB*/ nullptr,
|
|
pullbackSuccessorCases, std::nullopt,
|
|
ProfileCounter(), OwnershipKind::Owned);
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Member accessor pullback generation
|
|
//--------------------------------------------------------------------------//
|
|
|
|
bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
|
|
auto &original = getOriginal();
|
|
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
|
|
switch (accessor->getAccessorKind()) {
|
|
case AccessorKind::Get:
|
|
case AccessorKind::DistributedGet:
|
|
return runForSemanticMemberGetter();
|
|
case AccessorKind::Set:
|
|
return runForSemanticMemberSetter();
|
|
// TODO(https://github.com/apple/swift/issues/55084): Support `modify` accessors.
|
|
default:
|
|
llvm_unreachable("Unsupported accessor kind; inconsistent with "
|
|
"`isSemanticMemberAccessor`?");
|
|
}
|
|
}
|
|
|
|
bool PullbackCloner::Implementation::runForSemanticMemberGetter() {
|
|
auto &original = getOriginal();
|
|
auto &pullback = getPullback();
|
|
auto pbLoc = getPullback().getLocation();
|
|
|
|
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
|
|
assert(accessor->getAccessorKind() == AccessorKind::Get);
|
|
|
|
auto *origEntry = original.getEntryBlock();
|
|
auto *pbEntry = pullback.getEntryBlock();
|
|
builder.setCurrentDebugScope(
|
|
remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
|
|
builder.setInsertionPoint(pbEntry);
|
|
|
|
// Get getter argument and result values.
|
|
// Getter type: $(Self) -> Result
|
|
// Pullback type: $(Result') -> Self'
|
|
assert(original.getLoweredFunctionType()->getNumParameters() == 1);
|
|
assert(pullback.getLoweredFunctionType()->getNumParameters() == 1);
|
|
assert(pullback.getLoweredFunctionType()->getNumResults() == 1);
|
|
SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();
|
|
|
|
SmallVector<SILValue, 8> origFormalResults;
|
|
collectAllFormalResultsInTypeOrder(original, origFormalResults);
|
|
assert(getConfig().resultIndices->getNumIndices() == 1 &&
|
|
"Getter should have one semantic result");
|
|
auto origResult = origFormalResults[*getConfig().resultIndices->begin()];
|
|
|
|
auto tangentVectorSILTy = pullback.getConventions().getResults().front()
|
|
.getSILStorageType(getModule(),
|
|
pullback.getLoweredFunctionType(),
|
|
TypeExpansionContext::minimal());
|
|
auto tangentVectorTy = tangentVectorSILTy.getASTType();
|
|
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
|
|
|
|
// Look up the corresponding field in the tangent space.
|
|
auto *origField = cast<VarDecl>(accessor->getStorage());
|
|
auto baseType = remapType(origSelf->getType()).getASTType();
|
|
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
|
|
pbLoc, getInvoker());
|
|
if (!tanField) {
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
|
|
// Switch based on the base tangent struct's value category.
|
|
switch (getTangentValueCategory(origSelf)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjResult = getAdjointValue(origEntry, origResult);
|
|
switch (adjResult.getKind()) {
|
|
case AdjointValueKind::Zero:
|
|
addAdjointValue(origEntry, origSelf,
|
|
makeZeroAdjointValue(tangentVectorSILTy), pbLoc);
|
|
break;
|
|
case AdjointValueKind::Concrete:
|
|
case AdjointValueKind::Aggregate: {
|
|
SmallVector<AdjointValue, 8> eltVals;
|
|
for (auto *field : tangentVectorDecl->getStoredProperties()) {
|
|
if (field == tanField) {
|
|
eltVals.push_back(adjResult);
|
|
} else {
|
|
auto substMap = tangentVectorTy->getMemberSubstitutionMap(field);
|
|
auto fieldTy = field->getInterfaceType().subst(substMap);
|
|
auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
|
|
assert(fieldSILTy.isObject());
|
|
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
|
|
}
|
|
}
|
|
addAdjointValue(origEntry, origSelf,
|
|
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
|
|
pbLoc);
|
|
|
|
break;
|
|
}
|
|
case AdjointValueKind::AddElement:
|
|
llvm_unreachable("Adjoint of an aggregate type's field cannot be of kind "
|
|
"`AddElement`");
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
assert(pullback.getIndirectResults().size() == 1);
|
|
auto pbIndRes = pullback.getIndirectResults().front();
|
|
auto *adjSelf = createFunctionLocalAllocation(
|
|
pbIndRes->getType().getObjectType(), pbLoc);
|
|
setAdjointBuffer(origEntry, origSelf, adjSelf);
|
|
for (auto *field : tangentVectorDecl->getStoredProperties()) {
|
|
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, field);
|
|
// Non-tangent fields get a zero.
|
|
if (field != tanField) {
|
|
builder.emitZeroIntoBuffer(pbLoc, adjSelfElt, IsInitialization);
|
|
continue;
|
|
}
|
|
// Switch based on the property's value category.
|
|
switch (getTangentValueCategory(origResult)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjResult = getAdjointValue(origEntry, origResult);
|
|
auto adjResultValue = materializeAdjointDirect(adjResult, pbLoc);
|
|
auto adjResultValueCopy =
|
|
builder.emitCopyValueOperation(pbLoc, adjResultValue);
|
|
builder.emitStoreValueOperation(pbLoc, adjResultValueCopy, adjSelfElt,
|
|
StoreOwnershipQualifier::Init);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjResult = getAdjointBuffer(origEntry, origResult);
|
|
builder.createCopyAddr(pbLoc, adjResult, adjSelfElt, IsTake,
|
|
IsInitialization);
|
|
destroyedLocalAllocations.insert(adjResult);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
|
|
auto &original = getOriginal();
|
|
auto &pullback = getPullback();
|
|
auto pbLoc = getPullback().getLocation();
|
|
|
|
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
|
|
assert(accessor->getAccessorKind() == AccessorKind::Set);
|
|
|
|
auto *origEntry = original.getEntryBlock();
|
|
auto *pbEntry = pullback.getEntryBlock();
|
|
builder.setCurrentDebugScope(
|
|
remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
|
|
builder.setInsertionPoint(pbEntry);
|
|
|
|
// Get setter argument values.
|
|
// Setter type: $(inout Self, Argument) -> ()
|
|
// Pullback type (wrt self): $(inout Self') -> ()
|
|
// Pullback type (wrt both): $(inout Self') -> Argument'
|
|
assert(original.getLoweredFunctionType()->getNumParameters() == 2);
|
|
assert(pullback.getLoweredFunctionType()->getNumParameters() == 1);
|
|
assert(pullback.getLoweredFunctionType()->getNumResults() == 0 ||
|
|
pullback.getLoweredFunctionType()->getNumResults() == 1);
|
|
|
|
SILValue origArg = original.getArgumentsWithoutIndirectResults()[0];
|
|
SILValue origSelf = original.getArgumentsWithoutIndirectResults()[1];
|
|
|
|
// Look up the corresponding field in the tangent space.
|
|
auto *origField = cast<VarDecl>(accessor->getStorage());
|
|
auto baseType = remapType(origSelf->getType()).getASTType();
|
|
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
|
|
pbLoc, getInvoker());
|
|
if (!tanField) {
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
|
|
auto adjSelf = getAdjointBuffer(origEntry, origSelf);
|
|
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
|
|
// Switch based on the property's value category.
|
|
switch (getTangentValueCategory(origArg)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt,
|
|
LoadOwnershipQualifier::Take);
|
|
setAdjointValue(origEntry, origArg, makeConcreteAdjointValue(adjArg));
|
|
blockTemporaries[pbEntry].insert(adjArg);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
addToAdjointBuffer(origEntry, origArg, adjSelfElt, pbLoc);
|
|
builder.emitDestroyOperation(pbLoc, adjSelfElt);
|
|
break;
|
|
}
|
|
}
|
|
builder.emitZeroIntoBuffer(pbLoc, adjSelfElt, IsInitialization);
|
|
|
|
return false;
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint buffer mapping
|
|
//--------------------------------------------------------------------------//
|
|
|
|
SILValue PullbackCloner::Implementation::getAdjointProjection(
|
|
SILBasicBlock *origBB, SILValue originalProjection) {
|
|
// Handle `struct_element_addr`.
|
|
// Adjoint projection: a `struct_element_addr` into the base adjoint buffer.
|
|
if (auto *seai = dyn_cast<StructElementAddrInst>(originalProjection)) {
|
|
assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
|
|
"`@noDerivative` struct projections should never be active");
|
|
auto adjSource = getAdjointBuffer(origBB, seai->getOperand());
|
|
auto structType = remapType(seai->getOperand()->getType()).getASTType();
|
|
auto *tanField =
|
|
getTangentStoredProperty(getContext(), seai, structType, getInvoker());
|
|
assert(tanField && "Invalid projections should have been diagnosed");
|
|
return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField);
|
|
}
|
|
// Handle `tuple_element_addr`.
|
|
// Adjoint projection: a `tuple_element_addr` into the base adjoint buffer.
|
|
if (auto *teai = dyn_cast<TupleElementAddrInst>(originalProjection)) {
|
|
auto source = teai->getOperand();
|
|
auto adjSource = getAdjointBuffer(origBB, source);
|
|
if (!adjSource->getType().is<TupleType>())
|
|
return adjSource;
|
|
auto origTupleTy = remapType(source->getType()).castTo<TupleType>();
|
|
unsigned adjIndex = 0;
|
|
for (unsigned i : range(teai->getFieldIndex())) {
|
|
if (getTangentSpace(
|
|
origTupleTy->getElement(i).getType()->getCanonicalType()))
|
|
++adjIndex;
|
|
}
|
|
return builder.createTupleElementAddr(teai->getLoc(), adjSource, adjIndex);
|
|
}
|
|
// Handle `ref_element_addr`.
|
|
// Adjoint projection: a local allocation initialized with the corresponding
|
|
// field value from the class's base adjoint value.
|
|
if (auto *reai = dyn_cast<RefElementAddrInst>(originalProjection)) {
|
|
assert(!reai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
|
|
"`@noDerivative` class projections should never be active");
|
|
auto loc = reai->getLoc();
|
|
// Get the class operand, stripping `begin_borrow`.
|
|
auto classOperand = stripBorrow(reai->getOperand());
|
|
auto classType = remapType(reai->getOperand()->getType()).getASTType();
|
|
auto *tanField =
|
|
getTangentStoredProperty(getContext(), reai->getField(), classType,
|
|
reai->getLoc(), getInvoker());
|
|
assert(tanField && "Invalid projections should have been diagnosed");
|
|
// Create a local allocation for the element adjoint buffer.
|
|
auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
|
|
auto eltTanSILType =
|
|
remapType(SILType::getPrimitiveAddressType(eltTanType));
|
|
auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc);
|
|
// Check the class operand's `TangentVector` value category.
|
|
switch (getTangentValueCategory(classOperand)) {
|
|
case SILValueCategory::Object: {
|
|
// Get the class operand's adjoint value. Currently, it must be a
|
|
// `TangentVector` struct.
|
|
auto adjClass =
|
|
materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc);
|
|
builder.emitScopedBorrowOperation(
|
|
loc, adjClass, [&](SILValue borrowedAdjClass) {
|
|
// Initialize the element adjoint buffer with the base adjoint
|
|
// value.
|
|
auto *adjElt =
|
|
builder.createStructExtract(loc, borrowedAdjClass, tanField);
|
|
auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt);
|
|
builder.emitStoreValueOperation(loc, adjEltCopy, eltAdjBuffer,
|
|
StoreOwnershipQualifier::Init);
|
|
});
|
|
return eltAdjBuffer;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
// Get the class operand's adjoint buffer. Currently, it must be a
|
|
// `TangentVector` struct.
|
|
auto adjClass = getAdjointBuffer(origBB, classOperand);
|
|
// Initialize the element adjoint buffer with the base adjoint buffer.
|
|
auto *adjElt = builder.createStructElementAddr(loc, adjClass, tanField);
|
|
builder.createCopyAddr(loc, adjElt, eltAdjBuffer, IsNotTake,
|
|
IsInitialization);
|
|
return eltAdjBuffer;
|
|
}
|
|
}
|
|
}
|
|
// Handle `begin_access`.
|
|
// Adjoint projection: the base adjoint buffer itself.
|
|
if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) {
|
|
auto adjBase = getAdjointBuffer(origBB, bai->getOperand());
|
|
if (errorOccurred)
|
|
return (bufferMap[{origBB, originalProjection}] = SILValue());
|
|
// Return the base buffer's adjoint buffer.
|
|
return adjBase;
|
|
}
|
|
// Handle `array.uninitialized_intrinsic` application element addresses.
|
|
// Adjoint projection: a local allocation initialized by applying
|
|
// `Array.TangentVector.subscript` to the base array's adjoint value.
|
|
auto *ai =
|
|
getAllocateUninitializedArrayIntrinsicElementAddress(originalProjection);
|
|
auto *definingInst = dyn_cast_or_null<SingleValueInstruction>(
|
|
originalProjection->getDefiningInstruction());
|
|
bool isAllocateUninitializedArrayIntrinsicElementAddress =
|
|
ai && definingInst &&
|
|
(isa<PointerToAddressInst>(definingInst) ||
|
|
isa<IndexAddrInst>(definingInst));
|
|
if (isAllocateUninitializedArrayIntrinsicElementAddress) {
|
|
// Get the array element index of the result address.
|
|
int eltIndex = 0;
|
|
if (auto *iai = dyn_cast<IndexAddrInst>(definingInst)) {
|
|
auto *ili = cast<IntegerLiteralInst>(iai->getIndex());
|
|
eltIndex = ili->getValue().getLimitedValue();
|
|
}
|
|
// Get the array adjoint value.
|
|
SILValue arrayValue = getArrayValue(ai);
|
|
SILValue arrayAdjoint = materializeAdjointDirect(
|
|
getAdjointValue(origBB, arrayValue), definingInst->getLoc());
|
|
// Apply `Array.TangentVector.subscript` to get array element adjoint value.
|
|
auto *eltAdjBuffer =
|
|
getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc());
|
|
return eltAdjBuffer;
|
|
}
|
|
return SILValue();
|
|
}
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Adjoint value accumulation
|
|
//----------------------------------------------------------------------------//
|
|
|
|
AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
|
|
AdjointValue lhs, AdjointValue rhs, SILLocation loc) {
|
|
LLVM_DEBUG(getADDebugStream() << "Accumulating adjoint directly.\nLHS: "
|
|
<< lhs << "\nRHS: " << rhs << '\n');
|
|
switch (lhs.getKind()) {
|
|
// x
|
|
case AdjointValueKind::Concrete: {
|
|
auto lhsVal = lhs.getConcreteValue();
|
|
switch (rhs.getKind()) {
|
|
// x + y
|
|
case AdjointValueKind::Concrete: {
|
|
auto rhsVal = rhs.getConcreteValue();
|
|
auto sum = recordTemporary(builder.emitAdd(loc, lhsVal, rhsVal));
|
|
return makeConcreteAdjointValue(sum);
|
|
}
|
|
// x + 0 => x
|
|
case AdjointValueKind::Zero:
|
|
return lhs;
|
|
// x + (y, z) => (x.0 + y, x.1 + z)
|
|
case AdjointValueKind::Aggregate: {
|
|
SmallVector<AdjointValue, 8> newElements;
|
|
auto lhsTy = lhsVal->getType().getASTType();
|
|
auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal);
|
|
if (lhsTy->is<TupleType>()) {
|
|
auto elts = builder.createDestructureTuple(loc, lhsValCopy);
|
|
llvm::for_each(elts->getResults(),
|
|
[this](SILValue result) { recordTemporary(result); });
|
|
for (auto i : indices(elts->getResults())) {
|
|
auto rhsElt = rhs.getAggregateElement(i);
|
|
newElements.push_back(accumulateAdjointsDirect(
|
|
makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
|
|
}
|
|
} else if (lhsTy->getStructOrBoundGenericStruct()) {
|
|
auto elts =
|
|
builder.createDestructureStruct(lhsVal.getLoc(), lhsValCopy);
|
|
llvm::for_each(elts->getResults(),
|
|
[this](SILValue result) { recordTemporary(result); });
|
|
for (unsigned i : indices(elts->getResults())) {
|
|
auto rhsElt = rhs.getAggregateElement(i);
|
|
newElements.push_back(accumulateAdjointsDirect(
|
|
makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
|
|
}
|
|
} else {
|
|
llvm_unreachable("Not an aggregate type");
|
|
}
|
|
return makeAggregateAdjointValue(lhsVal->getType(), newElements);
|
|
}
|
|
// x + (baseAdjoint, index, eltToAdd) => (x+baseAdjoint, index, eltToAdd)
|
|
case AdjointValueKind::AddElement: {
|
|
auto *addElementValue = rhs.getAddElementValue();
|
|
auto baseAdjoint = addElementValue->baseAdjoint;
|
|
auto eltToAdd = addElementValue->eltToAdd;
|
|
|
|
auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
|
|
return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
|
|
addElementValue->fieldLocator);
|
|
}
|
|
}
|
|
}
|
|
// 0
|
|
case AdjointValueKind::Zero:
|
|
// 0 + x => x
|
|
return rhs;
|
|
// (x, y)
|
|
case AdjointValueKind::Aggregate: {
|
|
switch (rhs.getKind()) {
|
|
// (x, y) + z => (z.0 + x, z.1 + y)
|
|
case AdjointValueKind::Concrete:
|
|
return accumulateAdjointsDirect(rhs, lhs, loc);
|
|
// x + 0 => x
|
|
case AdjointValueKind::Zero:
|
|
return lhs;
|
|
// (x, y) + (z, w) => (x + z, y + w)
|
|
case AdjointValueKind::Aggregate: {
|
|
SmallVector<AdjointValue, 8> newElements;
|
|
for (auto i : range(lhs.getNumAggregateElements()))
|
|
newElements.push_back(accumulateAdjointsDirect(
|
|
lhs.getAggregateElement(i), rhs.getAggregateElement(i), loc));
|
|
return makeAggregateAdjointValue(lhs.getType(), newElements);
|
|
}
|
|
// (x.0, ..., x.n) + (baseAdjoint, index, eltToAdd) => (x + baseAdjoint,
|
|
// index, eltToAdd)
|
|
case AdjointValueKind::AddElement: {
|
|
auto *addElementValue = rhs.getAddElementValue();
|
|
auto baseAdjoint = addElementValue->baseAdjoint;
|
|
auto eltToAdd = addElementValue->eltToAdd;
|
|
auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
|
|
|
|
return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
|
|
addElementValue->fieldLocator);
|
|
}
|
|
}
|
|
}
|
|
// (baseAdjoint, index, eltToAdd)
|
|
case AdjointValueKind::AddElement: {
|
|
switch (rhs.getKind()) {
|
|
case AdjointValueKind::Zero:
|
|
return lhs;
|
|
// (baseAdjoint, index, eltToAdd) + x => (x + baseAdjoint, index, eltToAdd)
|
|
case AdjointValueKind::Concrete:
|
|
// (baseAdjoint, index, eltToAdd) + (x.0, ..., x.n) => (x + baseAdjoint,
|
|
// index, eltToAdd)
|
|
case AdjointValueKind::Aggregate:
|
|
return accumulateAdjointsDirect(rhs, lhs, loc);
|
|
// (baseAdjoint1, index1, eltToAdd1) + (baseAdjoint2, index2, eltToAdd2)
|
|
// => ((baseAdjoint1 + baseAdjoint2, index1, eltToAdd1), index2, eltToAdd2)
|
|
case AdjointValueKind::AddElement: {
|
|
auto *addElementValueLhs = lhs.getAddElementValue();
|
|
auto baseAdjointLhs = addElementValueLhs->baseAdjoint;
|
|
auto eltToAddLhs = addElementValueLhs->eltToAdd;
|
|
|
|
auto *addElementValueRhs = rhs.getAddElementValue();
|
|
auto baseAdjointRhs = addElementValueRhs->baseAdjoint;
|
|
auto eltToAddRhs = addElementValueRhs->eltToAdd;
|
|
|
|
auto sumOfBaseAdjoints =
|
|
accumulateAdjointsDirect(baseAdjointLhs, baseAdjointRhs, loc);
|
|
auto newBaseAdjoint = makeAddElementAdjointValue(
|
|
sumOfBaseAdjoints, eltToAddLhs, addElementValueLhs->fieldLocator);
|
|
|
|
return makeAddElementAdjointValue(newBaseAdjoint, eltToAddRhs,
|
|
addElementValueRhs->fieldLocator);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
|
|
}
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Array literal initialization differentiation
|
|
//----------------------------------------------------------------------------//
|
|
|
|
void PullbackCloner::Implementation::
|
|
accumulateArrayLiteralElementAddressAdjoints(SILBasicBlock *origBB,
|
|
SILValue originalValue,
|
|
AdjointValue arrayAdjointValue,
|
|
SILLocation loc) {
|
|
// Return if the original value is not the `Array` result of an
|
|
// `array.uninitialized_intrinsic` application.
|
|
auto *dti = dyn_cast_or_null<DestructureTupleInst>(
|
|
originalValue->getDefiningInstruction());
|
|
if (!dti)
|
|
return;
|
|
if (!ArraySemanticsCall(dti->getOperand(),
|
|
semantics::ARRAY_UNINITIALIZED_INTRINSIC))
|
|
return;
|
|
if (originalValue != dti->getResult(0))
|
|
return;
|
|
// Accumulate the array's adjoint value into the adjoint buffers of its
|
|
// element addresses: `pointer_to_address` and (optionally) `index_addr`
|
|
// instructions.
|
|
// The input code looks like as follows:
|
|
// %17 = integer_literal $Builtin.Word, 1
|
|
// function_ref _allocateUninitializedArray<A>(_:)
|
|
// %18 = function_ref @$ss27_allocateUninitializedArrayySayxG_BptBwlF : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
|
|
// %19 = apply %18<Float>(%17) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
|
|
// (%20, %21) = destructure_tuple %19
|
|
// %22 = mark_dependence %21 on %20
|
|
// %23 = pointer_to_address %22 to [strict] $*Float
|
|
// store %0 to [trivial] %23
|
|
// function_ref _finalizeUninitializedArray<A>(_:)
|
|
// %25 = function_ref @$ss27_finalizeUninitializedArrayySayxGABnlF : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0>
|
|
// %26 = apply %25<Float>(%20) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> // user: %27
|
|
// Note that %20 and %21 are in some sense "aliases" for each other. Here our `originalValue` is %20 in the code above.
|
|
// We need to trace from %21 down to %23 and propagate (decomposed) adjoint of originalValue to adjoint of %23.
|
|
// Then the generic adjoint propagation code would do its job to propagate %23' to %0'.
|
|
// If we're initializing multiple values we're having additional `index_addr` instructions, but
|
|
// the handling is similar.
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Accumulating adjoint value for array literal into element "
|
|
"address adjoint buffers"
|
|
<< originalValue);
|
|
auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc);
|
|
builder.setCurrentDebugScope(remapScope(dti->getDebugScope()));
|
|
for (auto use : dti->getResult(1)->getUses()) {
|
|
auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser());
|
|
assert(mdi && "Expected mark_dependence user");
|
|
auto *ptai =
|
|
dyn_cast_or_null<PointerToAddressInst>(getSingleNonDebugUser(mdi));
|
|
assert(ptai && "Expected pointer_to_address user");
|
|
auto adjBuf = getAdjointBuffer(origBB, ptai);
|
|
auto *eltAdjBuf = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc);
|
|
builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
|
|
for (auto use : ptai->getUses()) {
|
|
if (auto *iai = dyn_cast<IndexAddrInst>(use->getUser())) {
|
|
auto *ili = cast<IntegerLiteralInst>(iai->getIndex());
|
|
auto eltIndex = ili->getValue().getLimitedValue();
|
|
auto adjBuf = getAdjointBuffer(origBB, iai);
|
|
auto *eltAdjBuf =
|
|
getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, loc);
|
|
builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
AllocStackInst *PullbackCloner::Implementation::getArrayAdjointElementBuffer(
|
|
SILValue arrayAdjoint, int eltIndex, SILLocation loc) {
|
|
auto &ctx = builder.getASTContext();
|
|
auto arrayTanType = cast<StructType>(arrayAdjoint->getType().getASTType());
|
|
auto arrayType = arrayTanType->getParent()->castTo<BoundGenericStructType>();
|
|
auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType();
|
|
auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType));
|
|
// Get `function_ref` and generic signature of
|
|
// `Array.TangentVector.subscript.getter`.
|
|
auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct();
|
|
auto subscriptLookup =
|
|
arrayTanStructDecl->lookupDirect(DeclBaseName::createSubscript());
|
|
SubscriptDecl *subscriptDecl = nullptr;
|
|
for (auto *candidate : subscriptLookup) {
|
|
auto candidateModule = candidate->getModuleContext();
|
|
if (candidateModule->getName() == ctx.Id_Differentiation ||
|
|
candidateModule->isStdlibModule()) {
|
|
assert(!subscriptDecl && "Multiple `Array.TangentVector.subscript`s");
|
|
subscriptDecl = cast<SubscriptDecl>(candidate);
|
|
#ifdef NDEBUG
|
|
break;
|
|
#endif
|
|
}
|
|
}
|
|
assert(subscriptDecl && "No `Array.TangentVector.subscript`");
|
|
auto *subscriptGetterDecl =
|
|
subscriptDecl->getOpaqueAccessor(AccessorKind::Get);
|
|
assert(subscriptGetterDecl && "No `Array.TangentVector.subscript` getter");
|
|
SILOptFunctionBuilder fb(getContext().getTransform());
|
|
auto *subscriptGetterFn = fb.getOrCreateFunction(
|
|
loc, SILDeclRef(subscriptGetterDecl), NotForDefinition);
|
|
// %subscript_fn = function_ref @Array.TangentVector<T>.subscript.getter
|
|
auto *subscriptFnRef = builder.createFunctionRef(loc, subscriptGetterFn);
|
|
auto subscriptFnGenSig =
|
|
subscriptGetterFn->getLoweredFunctionType()->getSubstGenericSignature();
|
|
// Apply `Array.TangentVector.subscript.getter` to get array element adjoint
|
|
// buffer.
|
|
// %index_literal = integer_literal $Builtin.IntXX, <index>
|
|
auto builtinIntType =
|
|
SILType::getPrimitiveObjectType(ctx.getIntDecl()
|
|
->getStoredProperties()
|
|
.front()
|
|
->getInterfaceType()
|
|
->getCanonicalType());
|
|
auto *eltIndexLiteral =
|
|
builder.createIntegerLiteral(loc, builtinIntType, eltIndex);
|
|
auto intType = SILType::getPrimitiveObjectType(
|
|
ctx.getIntType()->getCanonicalType());
|
|
// %index_int = struct $Int (%index_literal)
|
|
auto *eltIndexInt = builder.createStruct(loc, intType, {eltIndexLiteral});
|
|
auto *diffProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
|
|
auto diffConf = lookupConformance(eltTanType, diffProto);
|
|
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
|
|
auto *addArithProto = ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
|
|
auto addArithConf = lookupConformance(eltTanType, addArithProto);
|
|
assert(!addArithConf.isInvalid() &&
|
|
"Missing conformance to `AdditiveArithmetic`");
|
|
auto subMap = SubstitutionMap::get(subscriptFnGenSig, {eltTanType},
|
|
{addArithConf, diffConf});
|
|
// %elt_adj = alloc_stack $T.TangentVector
|
|
// Create and register a local allocation.
|
|
auto *eltAdjBuffer = createFunctionLocalAllocation(
|
|
eltTanSILType, loc, /*zeroInitialize*/ true);
|
|
// Immediately destroy the emitted zero value.
|
|
// NOTE: It is not efficient to emit a zero value then immediately destroy
|
|
// it. However, it was the easiest way to to avoid "lifetime mismatch in
|
|
// predecessors" memory lifetime verification errors for control flow
|
|
// differentiation.
|
|
// Perhaps we can avoid emitting a zero value if local allocations are created
|
|
// per pullback bb instead of all in the pullback entry: TF-1075.
|
|
builder.emitDestroyOperation(loc, eltAdjBuffer);
|
|
// apply %subscript_fn<T.TangentVector>(%elt_adj, %index_int, %array_adj)
|
|
builder.createApply(loc, subscriptFnRef, subMap,
|
|
{eltAdjBuffer, eltIndexInt, arrayAdjoint});
|
|
return eltAdjBuffer;
|
|
}
|
|
|
|
} // end namespace autodiff
|
|
} // end namespace swift
|