mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Some fixes for coroutines with normal results and `partial_apply` of coroutines were required. Fixes #55084
3876 lines
163 KiB
C++
3876 lines
163 KiB
C++
//===--- PullbackCloner.cpp - Pullback function generation ---*- C++ -*----===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines a helper class for generating pullback functions for
|
|
// automatic differentiation.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define DEBUG_TYPE "differentiation"
|
|
|
|
#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
|
|
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
|
|
#include "swift/SILOptimizer/Differentiation/ADContext.h"
|
|
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
|
|
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
|
|
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
|
|
#include "swift/SILOptimizer/Differentiation/Thunk.h"
|
|
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
|
|
|
|
#include "swift/AST/ConformanceLookup.h"
|
|
#include "swift/AST/Expr.h"
|
|
#include "swift/AST/PropertyWrappers.h"
|
|
#include "swift/AST/TypeCheckRequests.h"
|
|
#include "swift/Basic/Assertions.h"
|
|
#include "swift/Basic/STLExtras.h"
|
|
#include "swift/SIL/ApplySite.h"
|
|
#include "swift/SIL/InstructionUtils.h"
|
|
#include "swift/SIL/Projection.h"
|
|
#include "swift/SIL/TypeSubstCloner.h"
|
|
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
|
|
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/SmallSet.h"
|
|
|
|
namespace swift {
|
|
|
|
class SILDifferentiabilityWitness;
|
|
class SILBasicBlock;
|
|
class SILFunction;
|
|
class SILInstruction;
|
|
|
|
namespace autodiff {
|
|
|
|
class ADContext;
|
|
class VJPCloner;
|
|
|
|
/// The implementation class for `PullbackCloner`.
|
|
///
|
|
/// The implementation class is a `SILInstructionVisitor`. Effectively, it acts
|
|
/// as a `SILCloner` that visits basic blocks in post-order and that visits
|
|
/// instructions per basic block in reverse order. This visitation order is
|
|
/// necessary for generating pullback functions, whose control flow graph is
|
|
/// ~a transposed version of the original function's control flow graph.
|
|
class PullbackCloner::Implementation final
|
|
: public SILInstructionVisitor<PullbackCloner::Implementation> {
|
|
|
|
public:
|
|
explicit Implementation(VJPCloner &vjpCloner);
|
|
|
|
private:
|
|
/// The parent VJP cloner.
|
|
VJPCloner &vjpCloner;
|
|
|
|
/// Dominance info for the original function.
|
|
DominanceInfo *domInfo = nullptr;
|
|
|
|
/// Post-dominance info for the original function.
|
|
PostDominanceInfo *postDomInfo = nullptr;
|
|
|
|
/// Post-order info for the original function.
|
|
PostOrderFunctionInfo *postOrderInfo = nullptr;
|
|
|
|
/// Mapping from original basic blocks to corresponding pullback basic blocks.
|
|
/// Pullback basic blocks always have the predecessor as the single argument.
|
|
llvm::DenseMap<SILBasicBlock *, SILBasicBlock *> pullbackBBMap;
|
|
|
|
/// Mapping from original basic blocks and original values to corresponding
|
|
/// adjoint values.
|
|
llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, AdjointValue> valueMap;
|
|
|
|
/// Mapping from original basic blocks and original values to corresponding
|
|
/// adjoint buffers.
|
|
llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap;
|
|
|
|
/// Mapping from pullback struct field declarations to pullback struct
|
|
/// elements destructured from the linear map basic block argument. In the
|
|
/// beginning of each pullback basic block, the block's pullback struct is
|
|
/// destructured into individual elements stored here.
|
|
llvm::DenseMap<SILBasicBlock*, SmallVector<SILValue, 4>> pullbackTupleElements;
|
|
|
|
/// Mapping from original basic blocks and successor basic blocks to
|
|
/// corresponding pullback trampoline basic blocks. Trampoline basic blocks
|
|
/// take additional arguments in addition to the predecessor enum argument.
|
|
llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, SILBasicBlock *>
|
|
pullbackTrampolineBBMap;
|
|
|
|
/// Mapping from original basic blocks to dominated active values.
|
|
llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues;
|
|
|
|
/// Mapping from original basic blocks and original active values to
|
|
/// corresponding pullback block arguments.
|
|
llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *>
|
|
activeValuePullbackBBArgumentMap;
|
|
|
|
/// Mapping from original basic blocks to local temporary values to be cleaned
|
|
/// up. This is populated when pullback emission is run on one basic block and
|
|
/// cleaned before processing another basic block.
|
|
llvm::DenseMap<SILBasicBlock *, llvm::SmallSetVector<SILValue, 32>>
|
|
blockTemporaries;
|
|
|
|
/// The scope cloner.
|
|
ScopeCloner scopeCloner;
|
|
|
|
/// The main builder.
|
|
TangentBuilder builder;
|
|
|
|
/// An auxiliary local allocation builder.
|
|
TangentBuilder localAllocBuilder;
|
|
|
|
/// The original function's exit block.
|
|
SILBasicBlock *originalExitBlock = nullptr;
|
|
|
|
/// Stack buffers allocated for storing local adjoint values.
|
|
SmallVector<AllocStackInst *, 64> functionLocalAllocations;
|
|
|
|
/// Copies created to deal with destructive enum operations
|
|
/// (unchecked_take_enum_addr)
|
|
llvm::SmallDenseMap<InitEnumDataAddrInst*, SILValue> enumDataAdjCopies;
|
|
|
|
/// A set used to remember local allocations that were destroyed.
|
|
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;
|
|
|
|
/// The seed arguments of the pullback function.
|
|
SmallVector<SILArgument *, 4> seeds;
|
|
|
|
/// The `AutoDiffLinearMapContext` object, if any.
|
|
SILValue contextValue = nullptr;
|
|
|
|
llvm::BumpPtrAllocator allocator;
|
|
|
|
bool errorOccurred = false;
|
|
|
|
ADContext &getContext() const { return vjpCloner.getContext(); }
|
|
SILModule &getModule() const { return getContext().getModule(); }
|
|
ASTContext &getASTContext() const { return getPullback().getASTContext(); }
|
|
SILFunction &getOriginal() const { return vjpCloner.getOriginal(); }
|
|
SILDifferentiabilityWitness *getWitness() const {
|
|
return vjpCloner.getWitness();
|
|
}
|
|
DifferentiationInvoker getInvoker() const { return vjpCloner.getInvoker(); }
|
|
LinearMapInfo &getPullbackInfo() const { return vjpCloner.getPullbackInfo(); }
|
|
const AutoDiffConfig &getConfig() const { return vjpCloner.getConfig(); }
|
|
const DifferentiableActivityInfo &getActivityInfo() const {
|
|
return vjpCloner.getActivityInfo();
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Pullback struct mapping
|
|
//--------------------------------------------------------------------------//
|
|
|
|
void initializePullbackTupleElements(SILBasicBlock *origBB,
|
|
SILInstructionResultArray values) {
|
|
auto *pbTupleTyple = getPullbackInfo().getLinearMapTupleType(origBB);
|
|
assert(pbTupleTyple->getNumElements() == values.size() &&
|
|
"The number of pullback tuple fields must equal the number of "
|
|
"pullback tuple element values");
|
|
auto res = pullbackTupleElements.insert({origBB, { values.begin(), values.end() }});
|
|
(void)res;
|
|
assert(res.second && "A pullback tuple element already exists!");
|
|
}
|
|
|
|
void initializePullbackTupleElements(SILBasicBlock *origBB,
|
|
const llvm::ArrayRef<SILArgument *> &values) {
|
|
auto *pbTupleTyple = getPullbackInfo().getLinearMapTupleType(origBB);
|
|
assert(pbTupleTyple->getNumElements() == values.size() &&
|
|
"The number of pullback tuple fields must equal the number of "
|
|
"pullback tuple element values");
|
|
auto res = pullbackTupleElements.insert({origBB, { values.begin(), values.end() }});
|
|
(void)res;
|
|
assert(res.second && "A pullback struct element already exists!");
|
|
}
|
|
|
|
/// Returns the pullback tuple element value corresponding to the given
|
|
/// original block and apply inst.
|
|
SILValue getPullbackTupleElement(FullApplySite fai) {
|
|
unsigned idx = getPullbackInfo().lookUpLinearMapIndex(fai);
|
|
assert((idx > 0 || (idx == 0 && fai.getParent()->isEntry())) &&
|
|
"impossible linear map index");
|
|
auto values = pullbackTupleElements.lookup(fai.getParent());
|
|
assert(idx < values.size() &&
|
|
"pullback tuple element for this apply does not exist!");
|
|
return values[idx];
|
|
}
|
|
|
|
/// Returns the pullback tuple element value corresponding to the predecessor
|
|
/// for the given original block.
|
|
SILValue getPullbackPredTupleElement(SILBasicBlock *origBB) {
|
|
assert(!origBB->isEntry() && "no predecessors for entry block");
|
|
auto values = pullbackTupleElements.lookup(origBB);
|
|
assert(values.size() && "pullback tuple cannot be empty");
|
|
return values[0];
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Type transformer
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Get the type lowering for the given AST type.
|
|
///
|
|
/// Explicitly use minimal type expansion context: in general, differentiation
|
|
/// happens on function types, so it cannot know if the original function is
|
|
/// resilient or not.
|
|
const Lowering::TypeLowering &getTypeLowering(Type type) {
|
|
auto pbGenSig =
|
|
getPullback().getLoweredFunctionType()->getSubstGenericSignature();
|
|
Lowering::AbstractionPattern pattern(pbGenSig,
|
|
type->getReducedType(pbGenSig));
|
|
return getContext().getTypeConverter().getTypeLowering(
|
|
pattern, type, TypeExpansionContext::minimal());
|
|
}
|
|
|
|
/// Remap any archetypes into the current function's context.
|
|
SILType remapType(SILType ty) {
|
|
if (ty.hasArchetype())
|
|
ty = ty.mapTypeOutOfContext();
|
|
auto remappedType = ty.getASTType()->getReducedType(
|
|
getPullback().getLoweredFunctionType()->getSubstGenericSignature());
|
|
auto remappedSILType =
|
|
SILType::getPrimitiveType(remappedType, ty.getCategory());
|
|
// FIXME: Sometimes getPullback() doesn't have a generic environment, in which
|
|
// case callers are apparently happy to receive an interface type.
|
|
if (getPullback().getGenericEnvironment())
|
|
return getPullback().mapTypeIntoContext(remappedSILType);
|
|
return remappedSILType;
|
|
}
|
|
|
|
std::optional<TangentSpace> getTangentSpace(CanType type) {
|
|
// Use witness generic signature to remap types.
|
|
type =
|
|
getWitness()->getDerivativeGenericSignature().getReducedType(
|
|
type);
|
|
return type->getAutoDiffTangentSpace(
|
|
LookUpConformanceInModule());
|
|
}
|
|
|
|
/// Returns the tangent value category of the given value.
|
|
SILValueCategory getTangentValueCategory(SILValue v) {
|
|
// Tangent value category table:
|
|
//
|
|
// Let $L be a loadable type and $*A be an address-only type.
|
|
//
|
|
// Original type | Tangent type loadable? | Tangent value category and type
|
|
// --------------|------------------------|--------------------------------
|
|
// $L | loadable | object, $L' (no mismatch)
|
|
// $*A | loadable | address, $*L' (create a buffer)
|
|
// $L | address-only | address, $*A' (no alternative)
|
|
// $*A | address-only | address, $*A' (no alternative)
|
|
|
|
// TODO(https://github.com/apple/swift/issues/55523): Make "tangent value category" depend solely on whether the tangent type is loadable or address-only.
|
|
//
|
|
// For loadable tangent types, using symbolic adjoint values instead of
|
|
// concrete adjoint buffers is more efficient.
|
|
|
|
// Quick check: if the value has an address type, the tangent value category
|
|
// is currently always "address".
|
|
if (v->getType().isAddress())
|
|
return SILValueCategory::Address;
|
|
// If the value has an object type and the tangent type is not address-only,
|
|
// then the tangent value category is "object".
|
|
auto tanSpace = getTangentSpace(remapType(v->getType()).getASTType());
|
|
auto tanASTType = tanSpace->getCanonicalType();
|
|
if (v->getType().isObject() && getTypeLowering(tanASTType).isLoadable())
|
|
return SILValueCategory::Object;
|
|
// Otherwise, the tangent value category is "address".
|
|
return SILValueCategory::Address;
|
|
}
|
|
|
|
/// Assuming the given type conforms to `Differentiable` after remapping,
|
|
/// returns the associated tangent space type.
|
|
SILType getRemappedTangentType(SILType type) {
|
|
return SILType::getPrimitiveType(
|
|
getTangentSpace(remapType(type).getASTType())->getCanonicalType(),
|
|
type.getCategory());
|
|
}
|
|
|
|
/// Substitutes all replacement types of the given substitution map using the
|
|
/// pullback function's substitution map.
|
|
SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) {
|
|
return substMap.subst(getPullback().getForwardingSubstitutionMap());
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Temporary value management
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Record a temporary value for cleanup before its block's terminator.
|
|
SILValue recordTemporary(SILValue value) {
|
|
assert(value->getType().isObject());
|
|
assert(value->getFunction() == &getPullback());
|
|
auto inserted = blockTemporaries[value->getParentBlock()].insert(value);
|
|
(void)inserted;
|
|
LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value);
|
|
assert(inserted && "Temporary already recorded?");
|
|
return value;
|
|
}
|
|
|
|
/// Clean up all temporary values for the given pullback block.
|
|
void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc) {
|
|
assert(bb->getParent() == &getPullback());
|
|
LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for pullback bb"
|
|
<< bb->getDebugID() << '\n');
|
|
for (auto temp : blockTemporaries[bb])
|
|
builder.emitDestroyValueOperation(loc, temp);
|
|
blockTemporaries[bb].clear();
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint value factory methods
|
|
//--------------------------------------------------------------------------//
|
|
|
|
AdjointValue makeZeroAdjointValue(SILType type) {
|
|
return AdjointValue::createZero(allocator, remapType(type));
|
|
}
|
|
|
|
AdjointValue makeConcreteAdjointValue(SILValue value) {
|
|
return AdjointValue::createConcrete(allocator, value);
|
|
}
|
|
|
|
AdjointValue makeAggregateAdjointValue(SILType type,
|
|
ArrayRef<AdjointValue> elements) {
|
|
return AdjointValue::createAggregate(allocator, remapType(type), elements);
|
|
}
|
|
|
|
AdjointValue makeAddElementAdjointValue(AdjointValue baseAdjoint,
|
|
AdjointValue eltToAdd,
|
|
FieldLocator fieldLocator) {
|
|
auto *addElementValue =
|
|
new AddElementValue(baseAdjoint, eltToAdd, fieldLocator);
|
|
return AdjointValue::createAddElement(allocator, baseAdjoint.getType(),
|
|
addElementValue);
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint value materialization
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Materializes an adjoint value. The type of the given adjoint value must be
|
|
/// loadable.
|
|
SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc) {
|
|
assert(val.getType().isObject());
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Materializing adjoint for " << val << '\n');
|
|
SILValue result;
|
|
switch (val.getKind()) {
|
|
case AdjointValueKind::Zero:
|
|
result = recordTemporary(builder.emitZero(loc, val.getSwiftType()));
|
|
break;
|
|
case AdjointValueKind::Aggregate: {
|
|
SmallVector<SILValue, 8> elements;
|
|
for (auto i : range(val.getNumAggregateElements())) {
|
|
auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
|
|
elements.push_back(builder.emitCopyValueOperation(loc, eltVal));
|
|
}
|
|
if (val.getType().is<TupleType>())
|
|
result = recordTemporary(
|
|
builder.createTuple(loc, val.getType(), elements));
|
|
else
|
|
result = recordTemporary(
|
|
builder.createStruct(loc, val.getType(), elements));
|
|
break;
|
|
}
|
|
case AdjointValueKind::Concrete:
|
|
result = val.getConcreteValue();
|
|
break;
|
|
case AdjointValueKind::AddElement: {
|
|
auto adjointSILType = val.getAddElementValue()->baseAdjoint.getType();
|
|
auto *baseAdjAlloc = builder.createAllocStack(loc, adjointSILType);
|
|
materializeAdjointIndirect(val, baseAdjAlloc, loc);
|
|
|
|
auto baseAdjConcrete = recordTemporary(builder.emitLoadValueOperation(
|
|
loc, baseAdjAlloc, LoadOwnershipQualifier::Take));
|
|
|
|
builder.createDeallocStack(loc, baseAdjAlloc);
|
|
|
|
result = baseAdjConcrete;
|
|
break;
|
|
}
|
|
}
|
|
if (auto debugInfo = val.getDebugInfo())
|
|
builder.createDebugValue(
|
|
debugInfo->first.getLocation(), result, debugInfo->second);
|
|
return result;
|
|
}
|
|
|
|
/// Materializes an adjoint value indirectly to a SIL buffer.
|
|
void materializeAdjointIndirect(AdjointValue val, SILValue destAddress,
|
|
SILLocation loc) {
|
|
assert(destAddress->getType().isAddress());
|
|
switch (val.getKind()) {
|
|
/// If adjoint value is a symbolic zero, emit a call to
|
|
/// `AdditiveArithmetic.zero`.
|
|
case AdjointValueKind::Zero:
|
|
builder.emitZeroIntoBuffer(loc, destAddress, IsInitialization);
|
|
break;
|
|
/// If adjoint value is a symbolic aggregate (tuple or struct), recursively
|
|
/// materialize the symbolic tuple or struct, filling the
|
|
/// buffer.
|
|
case AdjointValueKind::Aggregate: {
|
|
if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
|
|
for (auto idx : range(val.getNumAggregateElements())) {
|
|
auto eltTy = SILType::getPrimitiveAddressType(
|
|
tupTy->getElementType(idx)->getCanonicalType());
|
|
auto *eltBuf =
|
|
builder.createTupleElementAddr(loc, destAddress, idx, eltTy);
|
|
materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc);
|
|
}
|
|
} else if (auto *structDecl =
|
|
val.getSwiftType()->getStructOrBoundGenericStruct()) {
|
|
auto fieldIt = structDecl->getStoredProperties().begin();
|
|
for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end();
|
|
++fieldIt, ++i) {
|
|
auto eltBuf =
|
|
builder.createStructElementAddr(loc, destAddress, *fieldIt);
|
|
materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc);
|
|
}
|
|
} else {
|
|
llvm_unreachable("Not an aggregate type");
|
|
}
|
|
break;
|
|
}
|
|
/// If adjoint value is concrete, it is already materialized. Store it in
|
|
/// the destination address.
|
|
case AdjointValueKind::Concrete: {
|
|
auto concreteVal = val.getConcreteValue();
|
|
auto copyOfConcreteVal = builder.emitCopyValueOperation(loc, concreteVal);
|
|
builder.emitStoreValueOperation(loc, copyOfConcreteVal, destAddress,
|
|
StoreOwnershipQualifier::Init);
|
|
break;
|
|
}
|
|
case AdjointValueKind::AddElement: {
|
|
auto baseAdjoint = val;
|
|
auto baseAdjointType = baseAdjoint.getType();
|
|
|
|
// Current adjoint may be made up of layers of `AddElement` adjoints.
|
|
// We can iteratively gather the list of elements to add instead of making
|
|
// recursive calls to `materializeAdjointIndirect`.
|
|
SmallVector<AddElementValue *, 4> addEltAdjValues;
|
|
|
|
do {
|
|
auto addElementValue = baseAdjoint.getAddElementValue();
|
|
addEltAdjValues.push_back(addElementValue);
|
|
baseAdjoint = addElementValue->baseAdjoint;
|
|
assert(baseAdjointType == baseAdjoint.getType());
|
|
} while (baseAdjoint.getKind() == AdjointValueKind::AddElement);
|
|
|
|
materializeAdjointIndirect(baseAdjoint, destAddress, loc);
|
|
|
|
for (auto *addElementValue : addEltAdjValues) {
|
|
auto eltToAdd = addElementValue->eltToAdd;
|
|
|
|
SILValue baseAdjEltAddr;
|
|
if (baseAdjoint.getType().is<TupleType>()) {
|
|
baseAdjEltAddr = builder.createTupleElementAddr(
|
|
loc, destAddress, addElementValue->getFieldIndex());
|
|
} else {
|
|
baseAdjEltAddr = builder.createStructElementAddr(
|
|
loc, destAddress, addElementValue->getFieldDecl());
|
|
}
|
|
|
|
auto eltToAddMaterialized = materializeAdjointDirect(eltToAdd, loc);
|
|
// Copy `eltToAddMaterialized` so we have a value with owned ownership
|
|
// semantics, required for using `eltToAddMaterialized` in a `store`
|
|
// instruction.
|
|
auto eltToAddMaterializedCopy =
|
|
builder.emitCopyValueOperation(loc, eltToAddMaterialized);
|
|
auto *eltToAddAlloc = builder.createAllocStack(loc, eltToAdd.getType());
|
|
builder.emitStoreValueOperation(loc, eltToAddMaterializedCopy,
|
|
eltToAddAlloc,
|
|
StoreOwnershipQualifier::Init);
|
|
|
|
builder.emitInPlaceAdd(loc, baseAdjEltAddr, eltToAddAlloc);
|
|
builder.createDestroyAddr(loc, eltToAddAlloc);
|
|
builder.createDeallocStack(loc, eltToAddAlloc);
|
|
}
|
|
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint value mapping
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Returns true if the given value in the original function has a
|
|
/// corresponding adjoint value.
|
|
bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const {
|
|
assert(origBB->getParent() == &getOriginal());
|
|
assert(originalValue->getType().isObject());
|
|
return valueMap.count({origBB, originalValue});
|
|
}
|
|
|
|
/// Initializes the adjoint value for the original value. Asserts that the
|
|
/// original value does not already have an adjoint value.
|
|
void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
|
|
AdjointValue adjointValue) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Setting adjoint value for " << originalValue);
|
|
assert(origBB->getParent() == &getOriginal());
|
|
assert(originalValue->getType().isObject());
|
|
assert(getTangentValueCategory(originalValue) == SILValueCategory::Object);
|
|
assert(adjointValue.getType().isObject());
|
|
assert(originalValue->getFunction() == &getOriginal());
|
|
// The adjoint value must be in the tangent space.
|
|
assert(adjointValue.getType() ==
|
|
getRemappedTangentType(originalValue->getType()));
|
|
// Try to assign a debug variable.
|
|
if (auto debugInfo = findDebugLocationAndVariable(originalValue)) {
|
|
LLVM_DEBUG({
|
|
auto &s = getADDebugStream();
|
|
s << "Found debug variable: \"" << debugInfo->second.Name
|
|
<< "\"\nLocation: ";
|
|
debugInfo->first.getLocation().print(s, getASTContext().SourceMgr);
|
|
s << '\n';
|
|
});
|
|
adjointValue.setDebugInfo(*debugInfo);
|
|
} else {
|
|
LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n");
|
|
}
|
|
// Insert into dictionary.
|
|
auto insertion =
|
|
valueMap.try_emplace({origBB, originalValue}, adjointValue);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "The new adjoint value, replacing the existing one, is: "
|
|
<< insertion.first->getSecond() << '\n');
|
|
if (!insertion.second)
|
|
insertion.first->getSecond() = adjointValue;
|
|
}
|
|
|
|
/// Returns the adjoint value for a value in the original function.
|
|
///
|
|
/// This method first tries to find an existing entry in the adjoint value
|
|
/// mapping. If no entry exists, creates a zero adjoint value.
|
|
AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue) {
|
|
assert(origBB->getParent() == &getOriginal());
|
|
assert(originalValue->getType().isObject());
|
|
assert(getTangentValueCategory(originalValue) == SILValueCategory::Object);
|
|
assert(originalValue->getFunction() == &getOriginal());
|
|
auto insertion = valueMap.try_emplace(
|
|
{origBB, originalValue},
|
|
makeZeroAdjointValue(getRemappedTangentType(originalValue->getType())));
|
|
auto it = insertion.first;
|
|
return it->getSecond();
|
|
}
|
|
|
|
/// Adds `newAdjointValue` to the adjoint value for `originalValue` and sets
|
|
/// the sum as the new adjoint value.
|
|
void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
|
|
AdjointValue newAdjointValue, SILLocation loc) {
|
|
assert(origBB->getParent() == &getOriginal());
|
|
assert(originalValue->getType().isObject());
|
|
assert(newAdjointValue.getType().isObject());
|
|
assert(originalValue->getFunction() == &getOriginal());
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Adding adjoint value for " << originalValue);
|
|
// The adjoint value must be in the tangent space.
|
|
assert(newAdjointValue.getType() ==
|
|
getRemappedTangentType(originalValue->getType()));
|
|
// Try to assign a debug variable.
|
|
if (auto debugInfo = findDebugLocationAndVariable(originalValue)) {
|
|
LLVM_DEBUG({
|
|
auto &s = getADDebugStream();
|
|
s << "Found debug variable: \"" << debugInfo->second.Name
|
|
<< "\"\nLocation: ";
|
|
debugInfo->first.getLocation().print(s, getASTContext().SourceMgr);
|
|
s << '\n';
|
|
});
|
|
newAdjointValue.setDebugInfo(*debugInfo);
|
|
} else {
|
|
LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n");
|
|
}
|
|
auto insertion =
|
|
valueMap.try_emplace({origBB, originalValue}, newAdjointValue);
|
|
auto inserted = insertion.second;
|
|
if (inserted)
|
|
return;
|
|
// If adjoint already exists, accumulate the adjoint onto the existing
|
|
// adjoint.
|
|
auto it = insertion.first;
|
|
auto existingValue = it->getSecond();
|
|
valueMap.erase(it);
|
|
auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue, loc);
|
|
// If the original value is the `Array` result of an
|
|
// `array.uninitialized_intrinsic` application, accumulate adjoint buffers
|
|
// for the array element addresses.
|
|
accumulateArrayLiteralElementAddressAdjoints(origBB, originalValue, adjVal,
|
|
loc);
|
|
setAdjointValue(origBB, originalValue, adjVal);
|
|
}
|
|
|
|
/// Get the pullback block argument corresponding to the given original block
|
|
/// and active value.
|
|
SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB,
|
|
SILValue activeValue) {
|
|
assert(getTangentValueCategory(activeValue) == SILValueCategory::Object);
|
|
assert(origBB->getParent() == &getOriginal());
|
|
auto pullbackBBArg =
|
|
activeValuePullbackBBArgumentMap[{origBB, activeValue}];
|
|
assert(pullbackBBArg);
|
|
assert(pullbackBBArg->getParent() == getPullbackBlock(origBB));
|
|
return pullbackBBArg;
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint value accumulation
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Given two adjoint values, accumulates them and returns their sum.
|
|
AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs,
|
|
SILLocation loc);
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint buffer mapping
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// If the given original value is an address projection, returns a
|
|
/// corresponding adjoint projection to be used as its adjoint buffer.
|
|
///
|
|
/// Helper function for `getAdjointBuffer`.
|
|
SILValue getAdjointProjection(SILBasicBlock *origBB, SILValue originalValue);
|
|
|
|
/// Returns the adjoint buffer for the original value.
|
|
///
|
|
/// This method first tries to find an existing entry in the adjoint buffer
|
|
/// mapping. If no entry exists, creates a zero adjoint buffer.
|
|
SILValue getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue) {
|
|
assert(getTangentValueCategory(originalValue) == SILValueCategory::Address);
|
|
assert(originalValue->getFunction() == &getOriginal());
|
|
auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue());
|
|
if (!insertion.second) // not inserted
|
|
return insertion.first->getSecond();
|
|
|
|
// If the original buffer is a projection, return a corresponding projection
|
|
// into the adjoint buffer.
|
|
if (auto adjProj = getAdjointProjection(origBB, originalValue))
|
|
return (bufferMap[{origBB, originalValue}] = adjProj);
|
|
|
|
LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for "
|
|
<< originalValue
|
|
<< "in bb" << origBB->getDebugID() << '\n');
|
|
|
|
auto bufType = getRemappedTangentType(originalValue->getType());
|
|
// Set insertion point for local allocation builder: before the last local
|
|
// allocation, or at the start of the pullback function's entry if no local
|
|
// allocations exist yet.
|
|
auto debugInfo = findDebugLocationAndVariable(originalValue);
|
|
SILLocation loc = debugInfo ? debugInfo->first.getLocation()
|
|
: RegularLocation::getAutoGeneratedLocation();
|
|
llvm::SmallString<32> adjName;
|
|
auto *newBuf = createFunctionLocalAllocation(
|
|
bufType, loc, /*zeroInitialize*/ true,
|
|
swift::transform(debugInfo,
|
|
[&](AdjointValue::DebugInfo di) {
|
|
llvm::raw_svector_ostream adjNameStream(adjName);
|
|
SILDebugVariable &dv = di.second;
|
|
dv.ArgNo = 0;
|
|
adjNameStream << "derivative of '" << dv.Name << "'";
|
|
if (SILDebugLocation origBBLoc = origBB->front().getDebugLocation()) {
|
|
adjNameStream << " in scope at ";
|
|
origBBLoc.getLocation().print(adjNameStream, getASTContext().SourceMgr);
|
|
}
|
|
adjNameStream << " (scope #" << origBB->getDebugID() << ")";
|
|
dv.Name = adjName;
|
|
// We have no meaningful debug location, and the type is different.
|
|
dv.Scope = nullptr;
|
|
dv.Loc = {};
|
|
dv.Type = {};
|
|
dv.DIExpr = {};
|
|
return dv;
|
|
}));
|
|
return (insertion.first->getSecond() = newBuf);
|
|
}
|
|
|
|
/// Initializes the adjoint buffer for the original value. Asserts that the
|
|
/// original value does not already have an adjoint buffer.
|
|
void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue,
|
|
SILValue adjointBuffer) {
|
|
assert(getTangentValueCategory(originalValue) == SILValueCategory::Address);
|
|
auto insertion =
|
|
bufferMap.try_emplace({origBB, originalValue}, adjointBuffer);
|
|
assert(insertion.second && "Adjoint buffer already exists");
|
|
(void)insertion;
|
|
}
|
|
|
|
/// Accumulates `rhsAddress` into the adjoint buffer corresponding to the
|
|
/// original value.
|
|
void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue,
|
|
SILValue rhsAddress, SILLocation loc) {
|
|
assert(getTangentValueCategory(originalValue) ==
|
|
SILValueCategory::Address &&
|
|
rhsAddress->getType().isAddress());
|
|
assert(originalValue->getFunction() == &getOriginal());
|
|
assert(rhsAddress->getFunction() == &getPullback());
|
|
auto adjointBuffer = getAdjointBuffer(origBB, originalValue);
|
|
|
|
LLVM_DEBUG(getADDebugStream() << "Adding"
|
|
<< rhsAddress << "to adjoint ("
|
|
<< adjointBuffer << ") of "
|
|
<< originalValue
|
|
<< "in bb" << origBB->getDebugID() << '\n');
|
|
|
|
builder.emitInPlaceAdd(loc, adjointBuffer, rhsAddress);
|
|
}
|
|
|
|
/// Returns a next insertion point for creating a local allocation: either
|
|
/// before the previous local allocation, or at the start of the pullback
|
|
/// entry if no local allocations exist.
|
|
///
|
|
/// Helper for `createFunctionLocalAllocation`.
|
|
SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint() {
|
|
// If there are no local allocations, insert at the pullback entry start.
|
|
if (functionLocalAllocations.empty())
|
|
return getPullback().getEntryBlock()->begin();
|
|
// Otherwise, insert before the last local allocation. Inserting before
|
|
// rather than after ensures that allocation and zero initialization
|
|
// instructions are grouped together.
|
|
auto lastLocalAlloc = functionLocalAllocations.back();
|
|
return lastLocalAlloc->getDefiningInstruction()->getIterator();
|
|
}
|
|
|
|
/// Creates and returns a local allocation with the given type.
|
|
///
|
|
/// Local allocations are created uninitialized in the pullback entry and
|
|
/// deallocated in the pullback exit. All local allocations not in
|
|
/// `destroyedLocalAllocations` are also destroyed in the pullback exit.
|
|
///
|
|
/// Helper for `getAdjointBuffer`.
|
|
AllocStackInst *createFunctionLocalAllocation(
|
|
SILType type, SILLocation loc, bool zeroInitialize = false,
|
|
std::optional<SILDebugVariable> varInfo = std::nullopt) {
|
|
// Set insertion point for local allocation builder: before the last local
|
|
// allocation, or at the start of the pullback function's entry if no local
|
|
// allocations exist yet.
|
|
localAllocBuilder.setInsertionPoint(
|
|
getPullback().getEntryBlock(),
|
|
getNextFunctionLocalAllocationInsertionPoint());
|
|
// Create and return local allocation.
|
|
auto *alloc = localAllocBuilder.createAllocStack(loc, type, varInfo);
|
|
functionLocalAllocations.push_back(alloc);
|
|
// Zero-initialize if requested.
|
|
if (zeroInitialize)
|
|
localAllocBuilder.emitZeroIntoBuffer(loc, alloc, IsInitialization);
|
|
return alloc;
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Optional differentiation
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Given a `wrappedAdjoint` value of type `T.TangentVector` and `Optional<T>`
|
|
/// type, creates an `Optional<T>.TangentVector` buffer from it.
|
|
///
|
|
/// `wrappedAdjoint` may be an object or address value, both cases are
|
|
/// handled.
|
|
AllocStackInst *createOptionalAdjoint(SILBasicBlock *bb,
|
|
SILValue wrappedAdjoint,
|
|
SILType optionalTy);
|
|
|
|
/// Accumulate adjoint of `wrappedAdjoint` into optionalBuffer.
|
|
void accumulateAdjointForOptionalBuffer(SILBasicBlock *bb,
|
|
SILValue optionalBuffer,
|
|
SILValue wrappedAdjoint);
|
|
|
|
/// Accumulate adjoint of `wrappedAdjoint` into optionalValue.
|
|
void accumulateAdjointValueForOptional(SILBasicBlock *bb,
|
|
SILValue optionalValue,
|
|
SILValue wrappedAdjoint);
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Array literal initialization differentiation
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Given the adjoint value of an array initialized from an
|
|
/// `array.uninitialized_intrinsic` application and an array element index,
|
|
/// returns an `alloc_stack` containing the adjoint value of the array element
|
|
/// at the given index by applying `Array.TangentVector.subscript`.
|
|
AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint,
|
|
int eltIndex, SILLocation loc);
|
|
|
|
/// Given the adjoint value of an array initialized from an
|
|
/// `array.uninitialized_intrinsic` application, accumulates the adjoint
|
|
/// value's elements into the adjoint buffers of its element addresses.
|
|
void accumulateArrayLiteralElementAddressAdjoints(
|
|
SILBasicBlock *origBB, SILValue originalValue,
|
|
AdjointValue arrayAdjointValue, SILLocation loc);
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// CFG mapping
|
|
//--------------------------------------------------------------------------//
|
|
|
|
SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) {
|
|
return pullbackBBMap.lookup(originalBlock);
|
|
}
|
|
|
|
SILBasicBlock *getPullbackTrampolineBlock(SILBasicBlock *originalBlock,
|
|
SILBasicBlock *successorBlock) {
|
|
return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock});
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Debug info
|
|
//--------------------------------------------------------------------------//
|
|
|
|
const SILDebugScope *remapScope(const SILDebugScope *DS) {
|
|
return scopeCloner.getOrCreateClonedScope(DS);
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Debugging utilities
|
|
//--------------------------------------------------------------------------//
|
|
|
|
void printAdjointValueMapping() {
|
|
// Group original/adjoint values by basic block.
|
|
llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, AdjointValue>> tmp;
|
|
for (auto pair : valueMap) {
|
|
auto origPair = pair.first;
|
|
auto *origBB = origPair.first;
|
|
auto origValue = origPair.second;
|
|
auto adjValue = pair.second;
|
|
tmp[origBB].insert({origValue, adjValue});
|
|
}
|
|
// Print original/adjoint values per basic block.
|
|
auto &s = getADDebugStream() << "Adjoint value mapping:\n";
|
|
for (auto &origBB : getOriginal()) {
|
|
if (!pullbackBBMap.count(&origBB))
|
|
continue;
|
|
auto bbValueMap = tmp[&origBB];
|
|
s << "bb" << origBB.getDebugID();
|
|
s << " (size " << bbValueMap.size() << "):\n";
|
|
for (auto valuePair : bbValueMap) {
|
|
auto origValue = valuePair.first;
|
|
auto adjValue = valuePair.second;
|
|
s << "ORIG: " << origValue;
|
|
s << "ADJ: " << adjValue << '\n';
|
|
}
|
|
s << '\n';
|
|
}
|
|
}
|
|
|
|
void printAdjointBufferMapping() {
|
|
// Group original/adjoint buffers by basic block.
|
|
llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, SILValue>> tmp;
|
|
for (auto pair : bufferMap) {
|
|
auto origPair = pair.first;
|
|
auto *origBB = origPair.first;
|
|
auto origBuf = origPair.second;
|
|
auto adjBuf = pair.second;
|
|
tmp[origBB][origBuf] = adjBuf;
|
|
}
|
|
// Print original/adjoint buffers per basic block.
|
|
auto &s = getADDebugStream() << "Adjoint buffer mapping:\n";
|
|
for (auto &origBB : getOriginal()) {
|
|
if (!pullbackBBMap.count(&origBB))
|
|
continue;
|
|
auto bbBufferMap = tmp[&origBB];
|
|
s << "bb" << origBB.getDebugID();
|
|
s << " (size " << bbBufferMap.size() << "):\n";
|
|
for (auto valuePair : bbBufferMap) {
|
|
auto origBuf = valuePair.first;
|
|
auto adjBuf = valuePair.second;
|
|
s << "ORIG: " << origBuf;
|
|
s << "ADJ: " << adjBuf << '\n';
|
|
}
|
|
s << '\n';
|
|
}
|
|
}
|
|
|
|
public:
|
|
//--------------------------------------------------------------------------//
|
|
// Entry point
|
|
//--------------------------------------------------------------------------//
|
|
|
|
/// Performs pullback generation on the empty pullback function. Returns true
|
|
/// if any error occurs.
|
|
bool run();
|
|
|
|
/// Performs pullback generation on the empty pullback function, given that
|
|
/// the original function is a "semantic member accessor".
|
|
///
|
|
/// "Semantic member accessors" are attached to member properties that have a
|
|
/// corresponding tangent stored property in the parent `TangentVector` type.
|
|
/// These accessors have special-case pullback generation based on their
|
|
/// semantic behavior.
|
|
///
|
|
/// Returns true if any error occurs.
|
|
bool runForSemanticMemberAccessor();
|
|
bool runForSemanticMemberGetter();
|
|
bool runForSemanticMemberSetter();
|
|
bool runForSemanticMemberModify();
|
|
|
|
/// If original result is non-varied, it will always have a zero derivative.
|
|
/// Skip full pullback generation and simply emit zero derivatives for wrt
|
|
/// parameters.
|
|
void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult);
|
|
|
|
/// Public helper so that our users can get the underlying newly created
|
|
/// function.
|
|
SILFunction &getPullback() const { return vjpCloner.getPullback(); }
|
|
|
|
using TrampolineBlockSet = SmallPtrSet<SILBasicBlock *, 4>;
|
|
|
|
/// Determines the pullback successor block for a given original block and one
|
|
/// of its predecessors. When a trampoline block is necessary, emits code into
|
|
/// the trampoline block to trampoline the original block's active value's
|
|
/// adjoint values.
|
|
///
|
|
/// Populates `pullbackTrampolineBlockMap`, which maps active values' adjoint
|
|
/// values to the pullback successor blocks in which they are used. This
|
|
/// allows us to release those values in pullback successor blocks that do not
|
|
/// use them.
|
|
SILBasicBlock *
|
|
buildPullbackSuccessor(SILBasicBlock *origBB, SILBasicBlock *origPredBB,
|
|
llvm::SmallDenseMap<SILValue, TrampolineBlockSet>
|
|
&pullbackTrampolineBlockMap);
|
|
|
|
/// Emits pullback code in the corresponding pullback block.
|
|
void visitSILBasicBlock(SILBasicBlock *bb);
|
|
|
|
void visit(SILInstruction *inst) {
|
|
if (errorOccurred)
|
|
return;
|
|
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "PullbackCloner visited:\n[ORIG]" << *inst);
|
|
#ifndef NDEBUG
|
|
auto beforeInsertion = std::prev(builder.getInsertionPoint());
|
|
#endif
|
|
SILInstructionVisitor::visit(inst);
|
|
LLVM_DEBUG({
|
|
auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback (pb bb" <<
|
|
builder.getInsertionBB()->getDebugID() << "):\n";
|
|
auto afterInsertion = builder.getInsertionPoint();
|
|
for (auto it = ++beforeInsertion; it != afterInsertion; ++it)
|
|
s << *it;
|
|
});
|
|
}
|
|
|
|
/// Fallback instruction visitor for unhandled instructions.
|
|
/// Emit a general non-differentiability diagnostic.
|
|
void visitSILInstruction(SILInstruction *inst) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Unhandled instruction in PullbackCloner: " << *inst);
|
|
getContext().emitNondifferentiabilityError(
|
|
inst, getInvoker(), diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
}
|
|
|
|
/// Handle `apply` instruction.
|
|
/// Original: (y0, y1, ...) = apply @fn (x0, x1, ...)
|
|
/// Adjoint: (adj[x0], adj[x1], ...) += apply @fn_pullback (adj[y0], ...)
|
|
void visitApplyInst(ApplyInst *ai) {
|
|
assert(getPullbackInfo().shouldDifferentiateApplySite(ai));
|
|
|
|
// Skip `array.uninitialized_intrinsic` applications, which have special
|
|
// `store` and `copy_addr` support.
|
|
if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC))
|
|
return;
|
|
|
|
auto loc = ai->getLoc();
|
|
auto *bb = ai->getParent();
|
|
// Handle `array.finalize_intrinsic` applications.
|
|
// `array.finalize_intrinsic` semantically behaves like an identity
|
|
// function.
|
|
if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) {
|
|
assert(ai->getNumArguments() == 1 &&
|
|
"Expected intrinsic to have one operand");
|
|
// Accumulate result's adjoint into argument's adjoint.
|
|
auto adjResult = getAdjointValue(bb, ai);
|
|
auto origArg = ai->getArgumentsWithoutIndirectResults().front();
|
|
addAdjointValue(bb, origArg, adjResult, loc);
|
|
return;
|
|
}
|
|
|
|
buildPullbackCall(ai);
|
|
}
|
|
|
|
void buildPullbackCall(FullApplySite fai) {
|
|
auto loc = fai->getLoc();
|
|
auto *bb = fai->getParent();
|
|
|
|
// Replace a call to a function with a call to its pullback.
|
|
auto &nestedApplyInfo = getContext().getNestedApplyInfo();
|
|
auto applyInfoLookup = nestedApplyInfo.find(fai);
|
|
// If no `NestedApplyInfo` was found, then this task doesn't need to be
|
|
// differentiated.
|
|
if (applyInfoLookup == nestedApplyInfo.end()) {
|
|
// Must not be active.
|
|
// TODO: Do we need to check token result for begin_apply?
|
|
SILValue result = fai.getResult();
|
|
assert(!result || !getActivityInfo().isActive(result, getConfig()));
|
|
return;
|
|
}
|
|
auto &applyInfo = applyInfoLookup->getSecond();
|
|
|
|
// Get the original result of the `apply` instruction.
|
|
const auto &conv = fai.getSubstCalleeConv();
|
|
SmallVector<SILValue, 8> origDirectResults;
|
|
forEachApplyDirectResult(fai, [&](SILValue directResult) {
|
|
origDirectResults.push_back(directResult);
|
|
});
|
|
SmallVector<SILValue, 8> origAllResults;
|
|
collectAllActualResultsInTypeOrder(fai, origDirectResults, origAllResults);
|
|
// Append semantic result arguments after original results.
|
|
for (auto paramIdx : applyInfo.config.parameterIndices->getIndices()) {
|
|
unsigned argIdx = fai.getNumIndirectSILResults() + paramIdx;
|
|
auto paramInfo = conv.getParamInfoForSILArg(argIdx);
|
|
if (!paramInfo.isAutoDiffSemanticResult())
|
|
continue;
|
|
origAllResults.push_back(
|
|
fai.getArgumentsWithoutIndirectResults()[paramIdx]);
|
|
}
|
|
|
|
// Get callee pullback arguments.
|
|
SmallVector<SILValue, 8> args;
|
|
|
|
// Handle callee pullback indirect results.
|
|
// Create local allocations for these and destroy them after the call.
|
|
auto pullback = getPullbackTupleElement(fai);
|
|
auto pullbackType =
|
|
remapType(pullback->getType()).castTo<SILFunctionType>();
|
|
|
|
auto actualPullbackType = applyInfo.originalPullbackType
|
|
? *applyInfo.originalPullbackType
|
|
: pullbackType;
|
|
actualPullbackType = actualPullbackType->getUnsubstitutedType(getModule());
|
|
SmallVector<AllocStackInst *, 4> pullbackIndirectResults;
|
|
for (auto indRes : actualPullbackType->getIndirectFormalResults()) {
|
|
auto *alloc = builder.createAllocStack(
|
|
loc, remapType(indRes.getSILStorageInterfaceType()));
|
|
pullbackIndirectResults.push_back(alloc);
|
|
args.push_back(alloc);
|
|
}
|
|
|
|
// Collect callee pullback formal arguments.
|
|
unsigned firstSemanticParamResultIdx = conv.getResults().size();
|
|
unsigned firstYieldResultIndex = firstSemanticParamResultIdx +
|
|
conv.getNumAutoDiffSemanticResultParameters();
|
|
for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) {
|
|
if (resultIndex >= firstYieldResultIndex)
|
|
continue;
|
|
assert(resultIndex < origAllResults.size());
|
|
auto origResult = origAllResults[resultIndex];
|
|
// Get the seed (i.e. adjoint value of the original result).
|
|
SILValue seed;
|
|
switch (getTangentValueCategory(origResult)) {
|
|
case SILValueCategory::Object:
|
|
seed = materializeAdjointDirect(getAdjointValue(bb, origResult), loc);
|
|
break;
|
|
case SILValueCategory::Address:
|
|
seed = getAdjointBuffer(bb, origResult);
|
|
break;
|
|
}
|
|
args.push_back(seed);
|
|
}
|
|
|
|
// If callee pullback was reabstracted in VJP, reabstract callee pullback.
|
|
if (applyInfo.originalPullbackType) {
|
|
auto toType = *applyInfo.originalPullbackType;
|
|
SILOptFunctionBuilder fb(getContext().getTransform());
|
|
if (toType->isCoroutine())
|
|
pullback = reabstractCoroutine(
|
|
builder, fb, loc, pullback, toType,
|
|
[this](SubstitutionMap subs) -> SubstitutionMap {
|
|
return this->remapSubstitutionMap(subs);
|
|
});
|
|
else
|
|
pullback = reabstractFunction(
|
|
builder, fb, loc, pullback, toType,
|
|
[this](SubstitutionMap subs) -> SubstitutionMap {
|
|
return this->remapSubstitutionMap(subs);
|
|
});
|
|
}
|
|
|
|
// Call the callee pullback.
|
|
FullApplySite pullbackCall;
|
|
SmallVector<SILValue, 8> dirResults;
|
|
if (actualPullbackType->isCoroutine()) {
|
|
pullbackCall = builder.createBeginApply(loc, pullback, SubstitutionMap(),
|
|
args);
|
|
// Record pullback and begin_apply token: the pullback will be consumed
|
|
// after end_apply.
|
|
applyInfo.pullback = pullback;
|
|
applyInfo.beginApplyToken = cast<BeginApplyInst>(pullbackCall)->getTokenResult();
|
|
} else {
|
|
pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
|
|
args);
|
|
builder.emitDestroyValueOperation(loc, pullback);
|
|
// Extract all results from `pullbackCall`.
|
|
extractAllElements(cast<ApplyInst>(pullbackCall), builder, dirResults);
|
|
}
|
|
|
|
// Get all results in type-defined order.
|
|
SmallVector<SILValue, 8> allResults;
|
|
collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults);
|
|
|
|
LLVM_DEBUG({
|
|
auto &s = getADDebugStream();
|
|
s << "All results of the nested pullback call:\n";
|
|
llvm::for_each(allResults, [&](SILValue v) { s << v; });
|
|
});
|
|
|
|
// Accumulate adjoints for original differentiation parameters.
|
|
auto allResultsIt = allResults.begin();
|
|
for (unsigned i : applyInfo.config.parameterIndices->getIndices()) {
|
|
unsigned argIdx = fai.getNumIndirectSILResults() + i;
|
|
auto origArg = fai.getArgument(argIdx);
|
|
// Skip adjoint accumulation for semantic results arguments.
|
|
auto paramInfo = fai.getSubstCalleeConv().getParamInfoForSILArg(argIdx);
|
|
if (paramInfo.isAutoDiffSemanticResult())
|
|
continue;
|
|
auto tan = *allResultsIt++;
|
|
if (tan->getType().isAddress()) {
|
|
addToAdjointBuffer(bb, origArg, tan, loc);
|
|
} else {
|
|
if (origArg->getType().isAddress()) {
|
|
auto *tmpBuf = builder.createAllocStack(loc, tan->getType());
|
|
builder.emitStoreValueOperation(loc, tan, tmpBuf,
|
|
StoreOwnershipQualifier::Init);
|
|
addToAdjointBuffer(bb, origArg, tmpBuf, loc);
|
|
builder.emitDestroyAddrAndFold(loc, tmpBuf);
|
|
builder.createDeallocStack(loc, tmpBuf);
|
|
} else {
|
|
recordTemporary(tan);
|
|
addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Propagate adjoints for yields
|
|
if (actualPullbackType->isCoroutine()) {
|
|
auto originalYields = cast<BeginApplyInst>(fai)->getYieldedValues();
|
|
auto pullbackYields = cast<BeginApplyInst>(pullbackCall)->getYieldedValues();
|
|
assert(originalYields.size() == pullbackYields.size());
|
|
|
|
for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) {
|
|
if (resultIndex < firstYieldResultIndex)
|
|
continue;
|
|
|
|
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
|
|
setAdjointBuffer(bb, originalYields[yieldResultIndex], pullbackYields[yieldResultIndex]);
|
|
}
|
|
}
|
|
|
|
// Destroy unused pullback direct results. Needed for pullback results from
|
|
// VJPs extracted from `@differentiable` function callees, where the
|
|
// `@differentiable` function's differentiation parameter indices are a
|
|
// superset of the active `apply` parameter indices.
|
|
while (allResultsIt != allResults.end()) {
|
|
auto unusedPullbackDirectResult = *allResultsIt++;
|
|
if (unusedPullbackDirectResult->getType().isAddress())
|
|
continue;
|
|
builder.emitDestroyValueOperation(loc, unusedPullbackDirectResult);
|
|
}
|
|
|
|
// Destroy and deallocate pullback indirect results.
|
|
for (auto *alloc : llvm::reverse(pullbackIndirectResults)) {
|
|
builder.emitDestroyAddrAndFold(loc, alloc);
|
|
builder.createDeallocStack(loc, alloc);
|
|
}
|
|
}
|
|
|
|
void visitAbortApplyInst(AbortApplyInst *aai) {
|
|
BeginApplyInst *bai = aai->getBeginApply();
|
|
assert(getPullbackInfo().shouldDifferentiateApplySite(bai));
|
|
|
|
// abort_apply differentiation is not yet supported.
|
|
getContext().emitNondifferentiabilityError(
|
|
bai, getInvoker(), diag::autodiff_coroutines_not_supported);
|
|
errorOccurred = true;
|
|
}
|
|
|
|
void visitEndApplyInst(EndApplyInst *eai) {
|
|
BeginApplyInst *bai = eai->getBeginApply();
|
|
assert(getPullbackInfo().shouldDifferentiateApplySite(bai));
|
|
|
|
// Replace a call to a function with a call to its pullback.
|
|
auto &nestedApplyInfo = getContext().getNestedApplyInfo();
|
|
auto applyInfoLookup = nestedApplyInfo.find(bai);
|
|
// If no `NestedApplyInfo` was found, then this task doesn't need to be
|
|
// differentiated.
|
|
if (applyInfoLookup == nestedApplyInfo.end()) {
|
|
// Must not be active.
|
|
assert(!getActivityInfo().isActive(bai->getTokenResult(), getConfig()));
|
|
assert(!getActivityInfo().isActive(eai, getConfig()));
|
|
return;
|
|
}
|
|
|
|
buildPullbackCall(bai);
|
|
}
|
|
|
|
void visitBeginApplyInst(BeginApplyInst *bai) {
|
|
assert(getPullbackInfo().shouldDifferentiateApplySite(bai));
|
|
|
|
auto &nestedApplyInfo = getContext().getNestedApplyInfo();
|
|
auto applyInfoLookup = nestedApplyInfo.find(bai);
|
|
// If no `NestedApplyInfo` was found, then this task doesn't need to be
|
|
// differentiated.
|
|
if (applyInfoLookup == nestedApplyInfo.end()) {
|
|
// Must not be active.
|
|
assert(!getActivityInfo().isActive(bai->getTokenResult(), getConfig()));
|
|
return;
|
|
}
|
|
auto applyInfo = applyInfoLookup->getSecond();
|
|
|
|
auto loc = bai->getLoc();
|
|
builder.createEndApply(loc, applyInfo.beginApplyToken,
|
|
SILType::getEmptyTupleType(getASTContext()));
|
|
builder.emitDestroyValueOperation(loc, applyInfo.pullback);
|
|
}
|
|
|
|
/// Handle `struct` instruction.
|
|
/// Original: y = struct (x0, x1, x2, ...)
|
|
/// Adjoint: adj[x0] += struct_extract adj[y], #x0
|
|
/// adj[x1] += struct_extract adj[y], #x1
|
|
/// adj[x2] += struct_extract adj[y], #x2
|
|
/// ...
|
|
void visitStructInst(StructInst *si) {
|
|
auto *bb = si->getParent();
|
|
auto loc = si->getLoc();
|
|
auto *structDecl = si->getStructDecl();
|
|
switch (getTangentValueCategory(si)) {
|
|
case SILValueCategory::Object: {
|
|
auto av = getAdjointValue(bb, si);
|
|
switch (av.getKind()) {
|
|
case AdjointValueKind::Zero: {
|
|
for (auto *field : structDecl->getStoredProperties()) {
|
|
auto fv = si->getFieldValue(field);
|
|
addAdjointValue(
|
|
bb, fv,
|
|
makeZeroAdjointValue(getRemappedTangentType(fv->getType())), loc);
|
|
}
|
|
break;
|
|
}
|
|
case AdjointValueKind::Concrete: {
|
|
auto adjStruct = materializeAdjointDirect(std::move(av), loc);
|
|
auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct);
|
|
|
|
// Find the struct `TangentVector` type.
|
|
auto structTy = remapType(si->getType()).getASTType();
|
|
#ifndef NDEBUG
|
|
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
|
|
assert(!getTypeLowering(tangentVectorTy).isAddressOnly());
|
|
assert(tangentVectorTy->getStructOrBoundGenericStruct());
|
|
#endif
|
|
|
|
// Accumulate adjoints for the fields of the `struct` operand.
|
|
unsigned fieldIndex = 0;
|
|
for (auto it = structDecl->getStoredProperties().begin();
|
|
it != structDecl->getStoredProperties().end();
|
|
++it, ++fieldIndex) {
|
|
VarDecl *field = *it;
|
|
if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
|
|
continue;
|
|
// Find the corresponding field in the tangent space.
|
|
auto *tanField = getTangentStoredProperty(
|
|
getContext(), field, structTy, loc, getInvoker());
|
|
if (!tanField) {
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
auto tanElt = dti->getResult(fieldIndex);
|
|
addAdjointValue(bb, si->getFieldValue(field),
|
|
makeConcreteAdjointValue(tanElt), si->getLoc());
|
|
}
|
|
break;
|
|
}
|
|
case AdjointValueKind::Aggregate: {
|
|
// Note: All user-called initializations go through the calls to the
|
|
// initializer, and synthesized initializers only have one level of
|
|
// struct formation which will not result into any aggregate adjoint
|
|
// values.
|
|
llvm_unreachable(
|
|
"Aggregate adjoint values should not occur for `struct` "
|
|
"instructions");
|
|
}
|
|
case AdjointValueKind::AddElement: {
|
|
llvm_unreachable(
|
|
"Adjoint of `StructInst` cannot be of kind `AddElement`");
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBuf = getAdjointBuffer(bb, si);
|
|
// Find the struct `TangentVector` type.
|
|
auto structTy = remapType(si->getType()).getASTType();
|
|
// Accumulate adjoints for the fields of the `struct` operand.
|
|
unsigned fieldIndex = 0;
|
|
for (auto it = structDecl->getStoredProperties().begin();
|
|
it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) {
|
|
VarDecl *field = *it;
|
|
if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
|
|
continue;
|
|
// Find the corresponding field in the tangent space.
|
|
auto *tanField = getTangentStoredProperty(getContext(), field, structTy,
|
|
loc, getInvoker());
|
|
if (!tanField) {
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
auto *adjFieldBuf =
|
|
builder.createStructElementAddr(loc, adjBuf, tanField);
|
|
auto fieldValue = si->getFieldValue(field);
|
|
switch (getTangentValueCategory(fieldValue)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjField = builder.emitLoadValueOperation(
|
|
loc, adjFieldBuf, LoadOwnershipQualifier::Copy);
|
|
recordTemporary(adjField);
|
|
addAdjointValue(bb, fieldValue, makeConcreteAdjointValue(adjField),
|
|
loc);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
addToAdjointBuffer(bb, fieldValue, adjFieldBuf, loc);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
} break;
|
|
}
|
|
}
|
|
|
|
/// Handle `struct_extract` instruction.
|
|
/// Original: y = struct_extract x, #field
|
|
/// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
|
|
/// ^~~~~~~
|
|
/// field in tangent space corresponding to #field
|
|
void visitStructExtractInst(StructExtractInst *sei) {
|
|
auto *bb = sei->getParent();
|
|
auto loc = getValidLocation(sei);
|
|
// Find the corresponding field in the tangent space.
|
|
auto structTy = remapType(sei->getOperand()->getType()).getASTType();
|
|
auto *tanField =
|
|
getTangentStoredProperty(getContext(), sei, structTy, getInvoker());
|
|
assert(tanField && "Invalid projections should have been diagnosed");
|
|
// Check the `struct_extract` operand's value tangent category.
|
|
switch (getTangentValueCategory(sei->getOperand())) {
|
|
case SILValueCategory::Object: {
|
|
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
|
|
auto tangentVectorSILTy =
|
|
SILType::getPrimitiveObjectType(tangentVectorTy);
|
|
auto eltAdj = getAdjointValue(bb, sei);
|
|
|
|
switch (eltAdj.getKind()) {
|
|
case AdjointValueKind::Zero: {
|
|
addAdjointValue(bb, sei->getOperand(),
|
|
makeZeroAdjointValue(tangentVectorSILTy), loc);
|
|
break;
|
|
}
|
|
case AdjointValueKind::Aggregate:
|
|
case AdjointValueKind::Concrete:
|
|
case AdjointValueKind::AddElement: {
|
|
auto baseAdj = makeZeroAdjointValue(tangentVectorSILTy);
|
|
addAdjointValue(bb, sei->getOperand(),
|
|
makeAddElementAdjointValue(baseAdj, eltAdj, tanField),
|
|
loc);
|
|
break;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBase = getAdjointBuffer(bb, sei->getOperand());
|
|
auto *adjBaseElt =
|
|
builder.createStructElementAddr(loc, adjBase, tanField);
|
|
// Check the `struct_extract`'s value tangent category.
|
|
switch (getTangentValueCategory(sei)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjElt = getAdjointValue(bb, sei);
|
|
auto concreteAdjElt = materializeAdjointDirect(adjElt, loc);
|
|
auto concreteAdjEltCopy =
|
|
builder.emitCopyValueOperation(loc, concreteAdjElt);
|
|
auto *alloc = builder.createAllocStack(loc, adjElt.getType());
|
|
builder.emitStoreValueOperation(loc, concreteAdjEltCopy, alloc,
|
|
StoreOwnershipQualifier::Init);
|
|
builder.emitInPlaceAdd(loc, adjBaseElt, alloc);
|
|
builder.createDestroyAddr(loc, alloc);
|
|
builder.createDeallocStack(loc, alloc);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjElt = getAdjointBuffer(bb, sei);
|
|
builder.emitInPlaceAdd(loc, adjBaseElt, adjElt);
|
|
break;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `ref_element_addr` instruction.
|
|
/// Original: y = ref_element_addr x, <n>
|
|
/// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
|
|
/// ^~~~~~~
|
|
/// field in tangent space corresponding to #field
|
|
void visitRefElementAddrInst(RefElementAddrInst *reai) {
|
|
auto *bb = reai->getParent();
|
|
auto loc = reai->getLoc();
|
|
auto adjBuf = getAdjointBuffer(bb, reai);
|
|
auto classOperand = reai->getOperand();
|
|
auto classType = remapType(reai->getOperand()->getType()).getASTType();
|
|
auto *tanField =
|
|
getTangentStoredProperty(getContext(), reai, classType, getInvoker());
|
|
assert(tanField && "Invalid projections should have been diagnosed");
|
|
switch (getTangentValueCategory(classOperand)) {
|
|
case SILValueCategory::Object: {
|
|
auto classTy = remapType(classOperand->getType()).getASTType();
|
|
auto tangentVectorTy = getTangentSpace(classTy)->getCanonicalType();
|
|
auto tangentVectorSILTy =
|
|
SILType::getPrimitiveObjectType(tangentVectorTy);
|
|
auto *tangentVectorDecl =
|
|
tangentVectorTy->getStructOrBoundGenericStruct();
|
|
// Accumulate adjoint for the `ref_element_addr` operand.
|
|
SmallVector<AdjointValue, 8> eltVals;
|
|
for (auto *field : tangentVectorDecl->getStoredProperties()) {
|
|
if (field == tanField) {
|
|
auto adjElt = builder.emitLoadValueOperation(
|
|
reai->getLoc(), adjBuf, LoadOwnershipQualifier::Copy);
|
|
eltVals.push_back(makeConcreteAdjointValue(adjElt));
|
|
recordTemporary(adjElt);
|
|
} else {
|
|
auto substMap = tangentVectorTy->getMemberSubstitutionMap(
|
|
field);
|
|
auto fieldTy = field->getInterfaceType().subst(substMap);
|
|
auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
|
|
assert(fieldSILTy.isObject());
|
|
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
|
|
}
|
|
}
|
|
addAdjointValue(bb, classOperand,
|
|
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
|
|
loc);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBufClass = getAdjointBuffer(bb, classOperand);
|
|
auto adjBufElt =
|
|
builder.createStructElementAddr(loc, adjBufClass, tanField);
|
|
builder.emitInPlaceAdd(loc, adjBufElt, adjBuf);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `tuple` instruction.
|
|
/// Original: y = tuple (x0, x1, x2, ...)
|
|
/// Adjoint: (adj[x0], adj[x1], adj[x2], ...) += destructure_tuple adj[y]
|
|
/// ^~~
|
|
/// excluding non-differentiable elements
|
|
void visitTupleInst(TupleInst *ti) {
|
|
auto *bb = ti->getParent();
|
|
auto loc = ti->getLoc();
|
|
switch (getTangentValueCategory(ti)) {
|
|
case SILValueCategory::Object: {
|
|
auto av = getAdjointValue(bb, ti);
|
|
switch (av.getKind()) {
|
|
case AdjointValueKind::Zero:
|
|
for (auto elt : ti->getElements()) {
|
|
if (!getTangentSpace(elt->getType().getASTType()))
|
|
continue;
|
|
addAdjointValue(
|
|
bb, elt,
|
|
makeZeroAdjointValue(getRemappedTangentType(elt->getType())),
|
|
loc);
|
|
}
|
|
break;
|
|
case AdjointValueKind::Concrete: {
|
|
auto adjVal = av.getConcreteValue();
|
|
auto adjValCopy = builder.emitCopyValueOperation(loc, adjVal);
|
|
SmallVector<SILValue, 4> adjElts;
|
|
if (!adjVal->getType().getAs<TupleType>()) {
|
|
recordTemporary(adjValCopy);
|
|
adjElts.push_back(adjValCopy);
|
|
} else {
|
|
auto *dti = builder.createDestructureTuple(loc, adjValCopy);
|
|
for (auto adjElt : dti->getResults())
|
|
recordTemporary(adjElt);
|
|
adjElts.append(dti->getResults().begin(), dti->getResults().end());
|
|
}
|
|
// Accumulate adjoints for `tuple` operands, skipping the
|
|
// non-`Differentiable` ones.
|
|
unsigned adjIndex = 0;
|
|
for (auto i : range(ti->getNumOperands())) {
|
|
if (!getTangentSpace(ti->getOperand(i)->getType().getASTType()))
|
|
continue;
|
|
auto adjElt = adjElts[adjIndex++];
|
|
addAdjointValue(bb, ti->getOperand(i),
|
|
makeConcreteAdjointValue(adjElt), loc);
|
|
}
|
|
break;
|
|
}
|
|
case AdjointValueKind::Aggregate: {
|
|
unsigned adjIndex = 0;
|
|
for (auto i : range(ti->getElements().size())) {
|
|
if (!getTangentSpace(ti->getElement(i)->getType().getASTType()))
|
|
continue;
|
|
addAdjointValue(bb, ti->getElement(i),
|
|
av.getAggregateElement(adjIndex++), loc);
|
|
}
|
|
break;
|
|
}
|
|
case AdjointValueKind::AddElement: {
|
|
llvm_unreachable(
|
|
"Adjoint of `TupleInst` cannot be of kind `AddElement`");
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBuf = getAdjointBuffer(bb, ti);
|
|
// Accumulate adjoints for `tuple` operands, skipping the
|
|
// non-`Differentiable` ones.
|
|
unsigned adjIndex = 0;
|
|
for (auto i : range(ti->getNumOperands())) {
|
|
if (!getTangentSpace(ti->getOperand(i)->getType().getASTType()))
|
|
continue;
|
|
auto adjBufElt =
|
|
builder.createTupleElementAddr(loc, adjBuf, adjIndex++);
|
|
auto adjElt = getAdjointBuffer(bb, ti->getOperand(i));
|
|
builder.emitInPlaceAdd(loc, adjElt, adjBufElt);
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `tuple_extract` instruction.
|
|
/// Original: y = tuple_extract x, <n>
|
|
/// Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0)
|
|
/// ^~~~~~
|
|
/// n'-th element, where n' is tuple tangent space
|
|
/// index corresponding to n
|
|
void visitTupleExtractInst(TupleExtractInst *tei) {
|
|
auto *bb = tei->getParent();
|
|
auto loc = tei->getLoc();
|
|
auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
|
|
auto eltAdj = getAdjointValue(bb, tei);
|
|
switch (eltAdj.getKind()) {
|
|
case AdjointValueKind::Zero: {
|
|
addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy),
|
|
loc);
|
|
break;
|
|
}
|
|
case AdjointValueKind::Aggregate:
|
|
case AdjointValueKind::Concrete:
|
|
case AdjointValueKind::AddElement: {
|
|
auto tupleTy = tei->getTupleType();
|
|
auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>();
|
|
if (!tupleTanTupleTy) {
|
|
addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
|
|
break;
|
|
}
|
|
|
|
unsigned elements = 0;
|
|
for (unsigned i : range(tupleTy->getNumElements())) {
|
|
if (!getTangentSpace(
|
|
tupleTy->getElement(i).getType()->getCanonicalType()))
|
|
continue;
|
|
elements++;
|
|
}
|
|
|
|
if (elements == 1) {
|
|
addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
|
|
} else {
|
|
auto baseAdj = makeZeroAdjointValue(tupleTanTy);
|
|
addAdjointValue(
|
|
bb, tei->getOperand(),
|
|
makeAddElementAdjointValue(baseAdj, eltAdj, tei->getFieldIndex()),
|
|
loc);
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `destructure_tuple` instruction.
|
|
/// Original: (y0, ..., yn) = destructure_tuple x
|
|
/// Adjoint: adj[x].0 += adj[y0]
|
|
/// ...
|
|
/// adj[x].n += adj[yn]
|
|
void visitDestructureTupleInst(DestructureTupleInst *dti) {
|
|
auto *bb = dti->getParent();
|
|
auto loc = dti->getLoc();
|
|
auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType());
|
|
// Check the `destructure_tuple` operand's value tangent category.
|
|
switch (getTangentValueCategory(dti->getOperand())) {
|
|
case SILValueCategory::Object: {
|
|
SmallVector<AdjointValue, 8> adjValues;
|
|
for (auto origElt : dti->getResults()) {
|
|
// Skip non-`Differentiable` tuple elements.
|
|
if (!getTangentSpace(remapType(origElt->getType()).getASTType()))
|
|
continue;
|
|
adjValues.push_back(getAdjointValue(bb, origElt));
|
|
}
|
|
// Handle tuple tangent type.
|
|
// Add adjoints for every tuple element that has a tangent space.
|
|
if (tupleTanTy.is<TupleType>()) {
|
|
assert(adjValues.size() > 1);
|
|
addAdjointValue(bb, dti->getOperand(),
|
|
makeAggregateAdjointValue(tupleTanTy, adjValues), loc);
|
|
}
|
|
// Handle non-tuple tangent type.
|
|
// Add adjoint for the single tuple element that has a tangent space.
|
|
else {
|
|
assert(adjValues.size() == 1);
|
|
addAdjointValue(bb, dti->getOperand(), adjValues.front(), loc);
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBuf = getAdjointBuffer(bb, dti->getOperand());
|
|
unsigned adjIndex = 0;
|
|
for (auto origElt : dti->getResults()) {
|
|
// Skip non-`Differentiable` tuple elements.
|
|
if (!getTangentSpace(remapType(origElt->getType()).getASTType()))
|
|
continue;
|
|
// Handle tuple tangent type.
|
|
// Add adjoints for every tuple element that has a tangent space.
|
|
if (tupleTanTy.is<TupleType>()) {
|
|
auto adjEltBuf = getAdjointBuffer(bb, origElt);
|
|
auto adjBufElt =
|
|
builder.createTupleElementAddr(loc, adjBuf, adjIndex);
|
|
builder.emitInPlaceAdd(loc, adjBufElt, adjEltBuf);
|
|
}
|
|
// Handle non-tuple tangent type.
|
|
// Add adjoint for the single tuple element that has a tangent space.
|
|
else {
|
|
auto adjEltBuf = getAdjointBuffer(bb, origElt);
|
|
addToAdjointBuffer(bb, dti->getOperand(), adjEltBuf, loc);
|
|
}
|
|
++adjIndex;
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `load` or `load_borrow` instruction
|
|
/// Original: y = load/load_borrow x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
void visitLoadOperation(SingleValueInstruction *inst) {
|
|
assert(isa<LoadInst>(inst) || isa<LoadBorrowInst>(inst));
|
|
auto *bb = inst->getParent();
|
|
auto loc = inst->getLoc();
|
|
switch (getTangentValueCategory(inst)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjVal = materializeAdjointDirect(getAdjointValue(bb, inst), loc);
|
|
// Allocate a local buffer and store the adjoint value. This buffer will
|
|
// be used for accumulation into the adjoint buffer.
|
|
auto adjBuf = builder.createAllocStack(loc, adjVal->getType(), {},
|
|
DoesNotHaveDynamicLifetime,
|
|
IsNotLexical, IsNotFromVarDecl,
|
|
DoesNotUseMoveableValueDebugInfo,
|
|
/* skipVarDeclAssert = */ true);
|
|
auto copy = builder.emitCopyValueOperation(loc, adjVal);
|
|
builder.emitStoreValueOperation(loc, copy, adjBuf,
|
|
StoreOwnershipQualifier::Init);
|
|
// Accumulate the adjoint value in the local buffer into the adjoint
|
|
// buffer.
|
|
addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc);
|
|
builder.emitDestroyAddr(loc, adjBuf);
|
|
builder.createDeallocStack(loc, adjBuf);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBuf = getAdjointBuffer(bb, inst);
|
|
addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
void visitLoadInst(LoadInst *li) { visitLoadOperation(li); }
|
|
void visitLoadBorrowInst(LoadBorrowInst *lbi) { visitLoadOperation(lbi); }
|
|
|
|
/// Handle `store` or `store_borrow` instruction.
|
|
/// Original: store/store_borrow x to y
|
|
/// Adjoint: adj[x] += load adj[y]; adj[y] = 0
|
|
void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc,
|
|
SILValue origDest) {
|
|
auto adjBuf = getAdjointBuffer(bb, origDest);
|
|
switch (getTangentValueCategory(origSrc)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjVal = builder.emitLoadValueOperation(
|
|
loc, adjBuf, LoadOwnershipQualifier::Take);
|
|
recordTemporary(adjVal);
|
|
addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc);
|
|
builder.emitZeroIntoBuffer(loc, adjBuf, IsInitialization);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
addToAdjointBuffer(bb, origSrc, adjBuf, loc);
|
|
builder.emitZeroIntoBuffer(loc, adjBuf, IsNotInitialization);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
void visitStoreInst(StoreInst *si) {
|
|
visitStoreOperation(si->getParent(), si->getLoc(), si->getSrc(),
|
|
si->getDest());
|
|
}
|
|
void visitStoreBorrowInst(StoreBorrowInst *sbi) {
|
|
visitStoreOperation(sbi->getParent(), sbi->getLoc(), sbi->getSrc(),
|
|
sbi);
|
|
}
|
|
|
|
/// Handle `copy_addr` instruction.
|
|
/// Original: copy_addr x to y
|
|
/// Adjoint: adj[x] += adj[y]; adj[y] = 0
|
|
void visitCopyAddrInst(CopyAddrInst *cai) {
|
|
auto *bb = cai->getParent();
|
|
auto adjDest = getAdjointBuffer(bb, cai->getDest());
|
|
addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc());
|
|
builder.emitZeroIntoBuffer(cai->getLoc(), adjDest, IsNotInitialization);
|
|
}
|
|
|
|
/// Handle any ownership instruction that deals with values: copy_value,
|
|
/// move_value, begin_borrow.
|
|
/// Original: y = copy_value x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
void visitValueOwnershipInst(SingleValueInstruction *svi,
|
|
bool needZeroResAdj = false) {
|
|
assert(svi->getNumOperands() == 1);
|
|
auto *bb = svi->getParent();
|
|
switch (getTangentValueCategory(svi)) {
|
|
case SILValueCategory::Object: {
|
|
auto adj = getAdjointValue(bb, svi);
|
|
addAdjointValue(bb, svi->getOperand(0), adj, svi->getLoc());
|
|
if (needZeroResAdj) {
|
|
assert(svi->getNumResults() == 1);
|
|
SILValue val = svi->getResult(0);
|
|
setAdjointValue(
|
|
bb, val,
|
|
makeZeroAdjointValue(getRemappedTangentType(val->getType())));
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjDest = getAdjointBuffer(bb, svi);
|
|
addToAdjointBuffer(bb, svi->getOperand(0), adjDest, svi->getLoc());
|
|
builder.emitZeroIntoBuffer(svi->getLoc(), adjDest, IsNotInitialization);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `copy_value` instruction.
|
|
/// Original: y = copy_value x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
void visitCopyValueInst(CopyValueInst *cvi) { visitValueOwnershipInst(cvi); }
|
|
|
|
/// Handle `begin_borrow` instruction.
|
|
/// Original: y = begin_borrow x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
void visitBeginBorrowInst(BeginBorrowInst *bbi) {
|
|
visitValueOwnershipInst(bbi);
|
|
}
|
|
|
|
/// Handle `move_value` instruction.
|
|
/// Original: y = move_value x
|
|
/// Adjoint: adj[x] += adj[y]; adj[y] = 0
|
|
void visitMoveValueInst(MoveValueInst *mvi) {
|
|
switch (getTangentValueCategory(mvi)) {
|
|
case SILValueCategory::Address:
|
|
LLVM_DEBUG(getADDebugStream() << "AutoDiff does not support move_value with "
|
|
"SILValueCategory::Address");
|
|
getContext().emitNondifferentiabilityError(
|
|
mvi, getInvoker(), diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return;
|
|
case SILValueCategory::Object:
|
|
visitValueOwnershipInst(mvi, /*needZeroResAdj=*/true);
|
|
}
|
|
}
|
|
|
|
void visitEndInitLetRefInst(EndInitLetRefInst *eir) { visitValueOwnershipInst(eir); }
|
|
|
|
/// Handle `begin_access` instruction.
|
|
/// Original: y = begin_access x
|
|
/// Adjoint: nothing
|
|
void visitBeginAccessInst(BeginAccessInst *bai) {
|
|
// Check for non-differentiable writes.
|
|
if (bai->getAccessKind() == SILAccessKind::Modify) {
|
|
if (isa<GlobalAddrInst>(bai->getSource())) {
|
|
getContext().emitNondifferentiabilityError(
|
|
bai, getInvoker(),
|
|
diag::autodiff_cannot_differentiate_writes_to_global_variables);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
if (isa<ProjectBoxInst>(bai->getSource())) {
|
|
getContext().emitNondifferentiabilityError(
|
|
bai, getInvoker(),
|
|
diag::autodiff_cannot_differentiate_writes_to_mutable_captures);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `unconditional_checked_cast_addr` instruction.
|
|
/// Original: y = unconditional_checked_cast_addr x
|
|
/// Adjoint: adj[x] += unconditional_checked_cast_addr adj[y]
|
|
void visitUnconditionalCheckedCastAddrInst(
|
|
UnconditionalCheckedCastAddrInst *uccai) {
|
|
auto *bb = uccai->getParent();
|
|
auto adjDest = getAdjointBuffer(bb, uccai->getDest());
|
|
auto adjSrc = getAdjointBuffer(bb, uccai->getSrc());
|
|
auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType());
|
|
builder.createUnconditionalCheckedCastAddr(
|
|
uccai->getLoc(), uccai->getCheckedCastOptions(),
|
|
adjDest, adjDest->getType().getASTType(), castBuf,
|
|
adjSrc->getType().getASTType());
|
|
addToAdjointBuffer(bb, uccai->getSrc(), castBuf, uccai->getLoc());
|
|
builder.emitDestroyAddrAndFold(uccai->getLoc(), castBuf);
|
|
builder.createDeallocStack(uccai->getLoc(), castBuf);
|
|
builder.emitZeroIntoBuffer(uccai->getLoc(), adjDest, IsInitialization);
|
|
}
|
|
|
|
/// Handle `enum` instruction.
|
|
/// Original: y = enum $Enum, #Enum.some!enumelt, x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
void visitEnumInst(EnumInst *ei) {
|
|
SILBasicBlock *bb = ei->getParent();
|
|
SILLocation loc = ei->getLoc();
|
|
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
|
|
|
|
// Only `Optional`-typed operands are supported for now. Diagnose all other
|
|
// enum operand types.
|
|
if (ei->getType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Unsupported enum type in PullbackCloner: " << *ei);
|
|
getContext().emitNondifferentiabilityError(
|
|
ei, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
|
|
auto adjOpt = getAdjointValue(bb, ei);
|
|
auto adjStruct = materializeAdjointDirect(adjOpt, loc);
|
|
|
|
VarDecl *adjOptVar =
|
|
getASTContext().getOptionalTanValueDecl(adjStruct->getType().getASTType());
|
|
auto *adjVal = builder.createStructExtract(loc, adjStruct, adjOptVar);
|
|
|
|
EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
|
|
auto *adjData = builder.createUncheckedEnumData(loc, adjVal, someElemDecl);
|
|
|
|
addAdjointValue(bb, ei->getOperand(), makeConcreteAdjointValue(adjData), loc);
|
|
}
|
|
|
|
/// Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
|
|
/// instructions.
|
|
///
|
|
/// Original: x = init_enum_data_addr y : $*Enum, #Enum.Case
|
|
/// inject_enum_addr y
|
|
///
|
|
/// Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y]
|
|
void visitInjectEnumAddrInst(InjectEnumAddrInst *inject) {
|
|
SILBasicBlock *bb = inject->getParent();
|
|
SILValue origEnum = inject->getOperand();
|
|
|
|
// Only `Optional`-typed operands are supported for now. Diagnose all other
|
|
// enum operand types.
|
|
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
|
|
if (origEnum->getType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Unsupported enum type in PullbackCloner: " << *inject);
|
|
getContext().emitNondifferentiabilityError(
|
|
inject, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
|
|
// No associated value => no adjoint to propagate
|
|
if (!inject->getElement()->hasAssociatedValues())
|
|
return;
|
|
|
|
InitEnumDataAddrInst *origData = nullptr;
|
|
for (auto use : origEnum->getUses()) {
|
|
if (auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser())) {
|
|
// We need a more complicated analysis when init_enum_data_addr and
|
|
// inject_enum_addr are in different blocks, or there is more than one
|
|
// such instruction. Bail out for now.
|
|
if (origData || init->getParent() != bb) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Could not find a matching init_enum_data_addr for: "
|
|
<< *inject);
|
|
getContext().emitNondifferentiabilityError(
|
|
inject, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
|
|
origData = init;
|
|
}
|
|
}
|
|
|
|
SILValue adjDest = getAdjointBuffer(bb, origEnum);
|
|
VarDecl *adjOptVar =
|
|
getASTContext().getOptionalTanValueDecl(adjDest->getType().getASTType());
|
|
|
|
SILLocation loc = origData->getLoc();
|
|
StructElementAddrInst *adjOpt =
|
|
builder.createStructElementAddr(loc, adjDest, adjOptVar);
|
|
|
|
// unchecked_take_enum_data_addr is destructive, so copy
|
|
// Optional<T.TangentVector> to a new alloca.
|
|
AllocStackInst *adjOptCopy =
|
|
createFunctionLocalAllocation(adjOpt->getType(), loc);
|
|
builder.createCopyAddr(loc, adjOpt, adjOptCopy, IsNotTake,
|
|
IsInitialization);
|
|
// The Optional copy is invalidated, do not attempt to destroy it at the end
|
|
// of the pullback. The value returned from unchecked_take_enum_data_addr is
|
|
// destroyed in visitInitEnumDataAddrInst.
|
|
auto [_, inserted] = enumDataAdjCopies.try_emplace(origData, adjOptCopy);
|
|
assert(inserted && "expected single buffer");
|
|
|
|
EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
|
|
UncheckedTakeEnumDataAddrInst *adjData =
|
|
builder.createUncheckedTakeEnumDataAddr(loc, adjOptCopy, someElemDecl);
|
|
|
|
addToAdjointBuffer(bb, origData, adjData, loc);
|
|
}
|
|
|
|
/// Handle `init_enum_data_addr` instruction.
|
|
/// Destroy the value returned from `unchecked_take_enum_data_addr`.
|
|
void visitInitEnumDataAddrInst(InitEnumDataAddrInst *init) {
|
|
SILValue adjOptCopy = enumDataAdjCopies.at(init);
|
|
|
|
builder.emitDestroyAddr(init->getLoc(), adjOptCopy);
|
|
destroyedLocalAllocations.insert(adjOptCopy);
|
|
enumDataAdjCopies.erase(init);
|
|
}
|
|
|
|
/// Handle `unchecked_ref_cast` instruction.
|
|
/// Original: y = unchecked_ref_cast x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
/// (assuming adj[x] and adj[y] have the same type)
|
|
void visitUncheckedRefCastInst(UncheckedRefCastInst *urci) {
|
|
auto *bb = urci->getParent();
|
|
assert(urci->getOperand()->getType().isObject());
|
|
assert(getRemappedTangentType(urci->getOperand()->getType()) ==
|
|
getRemappedTangentType(urci->getType()) &&
|
|
"Operand/result must have the same `TangentVector` type");
|
|
switch (getTangentValueCategory(urci)) {
|
|
case SILValueCategory::Object: {
|
|
auto adj = getAdjointValue(bb, urci);
|
|
addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc());
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjDest = getAdjointBuffer(bb, urci);
|
|
addToAdjointBuffer(bb, urci->getOperand(), adjDest, urci->getLoc());
|
|
builder.emitZeroIntoBuffer(urci->getLoc(), adjDest, IsNotInitialization);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `upcast` instruction.
|
|
/// Original: y = upcast x
|
|
/// Adjoint: adj[x] += adj[y]
|
|
/// (assuming adj[x] and adj[y] have the same type)
|
|
void visitUpcastInst(UpcastInst *ui) {
|
|
auto *bb = ui->getParent();
|
|
assert(ui->getOperand()->getType().isObject());
|
|
assert(getRemappedTangentType(ui->getOperand()->getType()) ==
|
|
getRemappedTangentType(ui->getType()) &&
|
|
"Operand/result must have the same `TangentVector` type");
|
|
switch (getTangentValueCategory(ui)) {
|
|
case SILValueCategory::Object: {
|
|
auto adj = getAdjointValue(bb, ui);
|
|
addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc());
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjDest = getAdjointBuffer(bb, ui);
|
|
addToAdjointBuffer(bb, ui->getOperand(), adjDest, ui->getLoc());
|
|
builder.emitZeroIntoBuffer(ui->getLoc(), adjDest, IsNotInitialization);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle `unchecked_take_enum_data_addr` instruction.
|
|
/// Currently, only `Optional`-typed operands are supported.
|
|
/// Original: y = unchecked_take_enum_data_addr x : $*Enum, #Enum.Case
|
|
/// Adjoint: adj[x] += $Enum.TangentVector(adj[y])
|
|
void
|
|
visitUncheckedTakeEnumDataAddrInst(UncheckedTakeEnumDataAddrInst *utedai) {
|
|
auto *bb = utedai->getParent();
|
|
auto adjDest = getAdjointBuffer(bb, utedai);
|
|
auto enumTy = utedai->getOperand()->getType();
|
|
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
|
|
// Only `Optional`-typed operands are supported for now. Diagnose all other
|
|
// enum operand types.
|
|
if (enumTy.getASTType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Unhandled instruction in PullbackCloner: " << *utedai);
|
|
getContext().emitNondifferentiabilityError(
|
|
utedai, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
accumulateAdjointForOptionalBuffer(bb, utedai->getOperand(), adjDest);
|
|
builder.emitZeroIntoBuffer(utedai->getLoc(), adjDest, IsNotInitialization);
|
|
}
|
|
|
|
#define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst);
|
|
#undef NOT_DIFFERENTIABLE
|
|
|
|
#define NO_ADJOINT(INST) \
|
|
void visit##INST##Inst(INST##Inst *inst) {}
|
|
// Terminators.
|
|
NO_ADJOINT(Return)
|
|
NO_ADJOINT(Branch)
|
|
NO_ADJOINT(CondBranch)
|
|
NO_ADJOINT(Yield)
|
|
|
|
// Address projections.
|
|
NO_ADJOINT(StructElementAddr)
|
|
NO_ADJOINT(TupleElementAddr)
|
|
|
|
// Array literal initialization address projections.
|
|
NO_ADJOINT(PointerToAddress)
|
|
NO_ADJOINT(IndexAddr)
|
|
|
|
// Memory allocation/access.
|
|
NO_ADJOINT(AllocStack)
|
|
NO_ADJOINT(DeallocStack)
|
|
NO_ADJOINT(EndAccess)
|
|
|
|
// Debugging/reference counting instructions.
|
|
NO_ADJOINT(DebugValue)
|
|
NO_ADJOINT(RetainValue)
|
|
NO_ADJOINT(RetainValueAddr)
|
|
NO_ADJOINT(ReleaseValue)
|
|
NO_ADJOINT(ReleaseValueAddr)
|
|
NO_ADJOINT(StrongRetain)
|
|
NO_ADJOINT(StrongRelease)
|
|
NO_ADJOINT(UnownedRetain)
|
|
NO_ADJOINT(UnownedRelease)
|
|
NO_ADJOINT(StrongRetainUnowned)
|
|
NO_ADJOINT(DestroyValue)
|
|
NO_ADJOINT(DestroyAddr)
|
|
|
|
// Value ownership.
|
|
NO_ADJOINT(EndBorrow)
|
|
#undef NO_ADJOINT
|
|
};
|
|
|
|
PullbackCloner::Implementation::Implementation(VJPCloner &vjpCloner)
|
|
: vjpCloner(vjpCloner), scopeCloner(getPullback()),
|
|
builder(getPullback(), getContext()),
|
|
localAllocBuilder(getPullback(), getContext()) {
|
|
// Get dominance and post-order info for the original function.
|
|
auto &passManager = getContext().getPassManager();
|
|
auto *domAnalysis = passManager.getAnalysis<DominanceAnalysis>();
|
|
auto *postDomAnalysis = passManager.getAnalysis<PostDominanceAnalysis>();
|
|
auto *postOrderAnalysis = passManager.getAnalysis<PostOrderAnalysis>();
|
|
auto *original = &vjpCloner.getOriginal();
|
|
domInfo = domAnalysis->get(original);
|
|
postDomInfo = postDomAnalysis->get(original);
|
|
postOrderInfo = postOrderAnalysis->get(original);
|
|
// Initialize `originalExitBlock`.
|
|
auto origExitIt = original->findReturnBB();
|
|
assert(origExitIt != original->end() &&
|
|
"Functions without returns must have been diagnosed");
|
|
originalExitBlock = &*origExitIt;
|
|
localAllocBuilder.setCurrentDebugScope(
|
|
remapScope(originalExitBlock->getTerminator()->getDebugScope()));
|
|
}
|
|
|
|
PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)
|
|
: impl(*new Implementation(vjpCloner)) {}
|
|
|
|
PullbackCloner::~PullbackCloner() { delete &impl; }
|
|
|
|
static SILValue getArrayValue(ApplyInst *ai) {
|
|
SILValue arrayValue;
|
|
for (auto use : ai->getUses()) {
|
|
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
|
|
if (!dti)
|
|
continue;
|
|
DEBUG_ASSERT(!arrayValue && "Array value already found");
|
|
// The first `destructure_tuple` result is the `Array` value.
|
|
arrayValue = dti->getResult(0);
|
|
#ifndef DEBUG_ASSERT_enabled
|
|
break;
|
|
#endif
|
|
}
|
|
ASSERT(arrayValue);
|
|
return arrayValue;
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Entry point
|
|
//--------------------------------------------------------------------------//
|
|
|
|
bool PullbackCloner::run() {
|
|
bool foundError = impl.run();
|
|
#ifndef NDEBUG
|
|
if (!foundError)
|
|
impl.getPullback().verify();
|
|
#endif
|
|
return foundError;
|
|
}
|
|
|
|
bool PullbackCloner::Implementation::run() {
|
|
PrettyStackTraceSILFunction trace("generating pullback for", &getOriginal());
|
|
auto &original = getOriginal();
|
|
auto &pullback = getPullback();
|
|
auto pbLoc = getPullback().getLocation();
|
|
LLVM_DEBUG(getADDebugStream() << "Running PullbackCloner on\n" << original);
|
|
|
|
// Collect original formal results.
|
|
SmallVector<SILValue, 8> origFormalResults;
|
|
collectAllFormalResultsInTypeOrder(original, origFormalResults);
|
|
for (auto resultIndex : getConfig().resultIndices->getIndices()) {
|
|
auto origResult = origFormalResults[resultIndex];
|
|
// If original result is non-varied, it will always have a zero derivative.
|
|
// Skip full pullback generation and simply emit zero derivatives for wrt
|
|
// parameters.
|
|
//
|
|
// NOTE(TF-876): This shortcut is currently necessary for functions
|
|
// returning non-varied result with >1 basic block where some basic blocks
|
|
// have no dominated active values; control flow differentiation does not
|
|
// handle this case. See TF-876 for context.
|
|
if (!getActivityInfo().isVaried(origResult, getConfig().parameterIndices)) {
|
|
emitZeroDerivativesForNonvariedResult(origResult);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Collect dominated active values in original basic blocks.
|
|
// Adjoint values of dominated active values are passed as pullback block
|
|
// arguments.
|
|
DominanceOrder domOrder(original.getEntryBlock(), domInfo);
|
|
// Keep track of visited values.
|
|
SmallPtrSet<SILValue, 8> visited;
|
|
while (auto *bb = domOrder.getNext()) {
|
|
auto &bbActiveValues = activeValues[bb];
|
|
// If the current block has an immediate dominator, append the immediate
|
|
// dominator block's active values to the current block's active values.
|
|
if (auto *domNode = domInfo->getNode(bb)->getIDom()) {
|
|
auto &domBBActiveValues = activeValues[domNode->getBlock()];
|
|
bbActiveValues.append(domBBActiveValues.begin(), domBBActiveValues.end());
|
|
}
|
|
// If `v` is active and has not been visited, records it as an active value
|
|
// in the original basic block.
|
|
// For active values unsupported by differentiation, emits a diagnostic and
|
|
// returns true. Otherwise, returns false.
|
|
auto recordValueIfActive = [&](SILValue v) -> bool {
|
|
// If value is not active, skip.
|
|
if (!getActivityInfo().isActive(v, getConfig()))
|
|
return false;
|
|
// If active value has already been visited, skip.
|
|
if (visited.count(v))
|
|
return false;
|
|
// Mark active value as visited.
|
|
visited.insert(v);
|
|
|
|
// Diagnose unsupported active values.
|
|
auto type = v->getType();
|
|
// Do not emit remaining activity-related diagnostics for semantic member
|
|
// accessors, which have special-case pullback generation.
|
|
if (isSemanticMemberAccessor(&original))
|
|
return false;
|
|
// Diagnose active enum values. Differentiation of enum values requires
|
|
// special adjoint value handling and is not yet supported. Diagnose
|
|
// only the first active enum value to prevent too many diagnostics.
|
|
//
|
|
// Do not diagnose `Optional`-typed values, which will have special-case
|
|
// differentiation support.
|
|
if (type.getEnumOrBoundGenericEnum()) {
|
|
if (!type.getASTType()->isOptional()) {
|
|
getContext().emitNondifferentiabilityError(
|
|
v, getInvoker(), diag::autodiff_enums_unsupported);
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
}
|
|
// Diagnose unsupported stored property projections.
|
|
if (isa<StructExtractInst>(v) || isa<RefElementAddrInst>(v) ||
|
|
isa<StructElementAddrInst>(v)) {
|
|
auto *inst = cast<SingleValueInstruction>(v);
|
|
assert(inst->getNumOperands() == 1);
|
|
auto baseType = remapType(inst->getOperand(0)->getType()).getASTType();
|
|
if (!getTangentStoredProperty(getContext(), inst, baseType,
|
|
getInvoker())) {
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
}
|
|
// Skip address projections.
|
|
// Address projections do not need their own adjoint buffers; they
|
|
// become projections into their adjoint base buffer.
|
|
if (Projection::isAddressProjection(v))
|
|
return false;
|
|
|
|
// Co-routines borrow adjoint buffers for yields
|
|
if (isa_and_nonnull<BeginApplyInst>(v.getDefiningInstruction()))
|
|
return false;
|
|
|
|
// Check that active values are differentiable. Otherwise we may crash
|
|
// later when tangent space is required, but not available.
|
|
if (!getTangentSpace(remapType(type).getASTType())) {
|
|
getContext().emitNondifferentiabilityError(
|
|
v, getInvoker(), diag::autodiff_expression_not_differentiable_note);
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
|
|
// Record active value.
|
|
bbActiveValues.push_back(v);
|
|
return false;
|
|
};
|
|
// Record all active values in the basic block.
|
|
for (auto *arg : bb->getArguments())
|
|
if (recordValueIfActive(arg))
|
|
return true;
|
|
for (auto &inst : *bb) {
|
|
for (auto op : inst.getOperandValues())
|
|
if (recordValueIfActive(op))
|
|
return true;
|
|
for (auto result : inst.getResults())
|
|
if (recordValueIfActive(result))
|
|
return true;
|
|
}
|
|
domOrder.pushChildren(bb);
|
|
}
|
|
|
|
// Create pullback blocks and arguments, visiting original blocks using BFS
|
|
// starting from the original exit block. Unvisited original basic blocks
|
|
// (e.g unreachable blocks) are not relevant for pullback generation and thus
|
|
// ignored.
|
|
// The original blocks in traversal order for pullback generation.
|
|
SmallVector<SILBasicBlock *, 8> originalBlocks;
|
|
// The workqueue used for bookkeeping during the breadth-first traversal.
|
|
BasicBlockWorkqueue workqueue = {originalExitBlock};
|
|
|
|
// Perform BFS from the original exit block.
|
|
{
|
|
while (auto *BB = workqueue.pop()) {
|
|
originalBlocks.push_back(BB);
|
|
|
|
for (auto *nextBB : BB->getPredecessorBlocks()) {
|
|
// If there is no linear map tuple for predecessor BB, then BB is
|
|
// unreachable from function entry. Do not run pullback cloner on it.
|
|
if (getPullbackInfo().getLinearMapTupleType(nextBB))
|
|
workqueue.pushIfNotVisited(nextBB);
|
|
}
|
|
}
|
|
}
|
|
|
|
for (auto *origBB : originalBlocks) {
|
|
auto *pullbackBB = pullback.createBasicBlock();
|
|
pullbackBBMap.insert({origBB, pullbackBB});
|
|
auto pbTupleLoweredType =
|
|
remapType(getPullbackInfo().getLinearMapTupleLoweredType(origBB));
|
|
// If the BB is the original exit, then the pullback block that we just
|
|
// created must be the pullback function's entry. For the pullback entry,
|
|
// create entry arguments and continue to the next block.
|
|
if (origBB == originalExitBlock) {
|
|
assert(pullbackBB->isEntry());
|
|
createEntryArguments(&pullback);
|
|
auto *origTerm = originalExitBlock->getTerminator();
|
|
builder.setCurrentDebugScope(remapScope(origTerm->getDebugScope()));
|
|
builder.setInsertionPoint(pullbackBB);
|
|
// Obtain the context object, if any, and the top-level subcontext, i.e.
|
|
// the main pullback struct.
|
|
if (getPullbackInfo().hasHeapAllocatedContext()) {
|
|
// The last argument is the context object (`Builtin.NativeObject`).
|
|
contextValue = pullbackBB->getArguments().back();
|
|
assert(contextValue->getType() ==
|
|
SILType::getNativeObjectType(getASTContext()));
|
|
// Load the pullback context.
|
|
auto subcontextAddr = emitProjectTopLevelSubcontext(
|
|
builder, pbLoc, contextValue, pbTupleLoweredType);
|
|
SILValue mainPullbackTuple = builder.createLoad(
|
|
pbLoc, subcontextAddr,
|
|
pbTupleLoweredType.isTrivial(getPullback()) ?
|
|
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy);
|
|
auto *dsi = builder.createDestructureTuple(pbLoc, mainPullbackTuple);
|
|
initializePullbackTupleElements(origBB, dsi->getAllResults());
|
|
} else {
|
|
// Obtain and destructure pullback struct elements.
|
|
unsigned numVals = pbTupleLoweredType.getAs<TupleType>()->getNumElements();
|
|
initializePullbackTupleElements(origBB,
|
|
pullbackBB->getArguments().take_back(numVals));
|
|
}
|
|
|
|
continue;
|
|
}
|
|
|
|
// Get all active values in the original block.
|
|
// If the original block has no active values, continue.
|
|
auto &bbActiveValues = activeValues[origBB];
|
|
if (bbActiveValues.empty())
|
|
continue;
|
|
|
|
// Otherwise, if the original block has active values:
|
|
// - For each active buffer in the original block, allocate a new local
|
|
// buffer in the pullback entry. (All adjoint buffers are allocated in
|
|
// the pullback entry and deallocated in the pullback exit.)
|
|
// - For each active value in the original block, add adjoint value
|
|
// arguments to the pullback block.
|
|
for (auto activeValue : bbActiveValues) {
|
|
// Handle the active value based on its value category.
|
|
switch (getTangentValueCategory(activeValue)) {
|
|
case SILValueCategory::Address: {
|
|
// Allocate and zero initialize a new local buffer using
|
|
// `getAdjointBuffer`.
|
|
builder.setCurrentDebugScope(
|
|
remapScope(originalExitBlock->getTerminator()->getDebugScope()));
|
|
builder.setInsertionPoint(pullback.getEntryBlock());
|
|
getAdjointBuffer(origBB, activeValue);
|
|
break;
|
|
}
|
|
case SILValueCategory::Object: {
|
|
// Create and register pullback block argument for the active value.
|
|
auto *pullbackArg = pullbackBB->createPhiArgument(
|
|
getRemappedTangentType(activeValue->getType()),
|
|
OwnershipKind::Owned);
|
|
activeValuePullbackBBArgumentMap[{origBB, activeValue}] = pullbackArg;
|
|
recordTemporary(pullbackArg);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
// Add a pullback tuple argument.
|
|
auto *pbTupleArg = pullbackBB->createPhiArgument(pbTupleLoweredType,
|
|
OwnershipKind::Owned);
|
|
// Destructure the pullback struct to get the elements.
|
|
builder.setCurrentDebugScope(
|
|
remapScope(origBB->getTerminator()->getDebugScope()));
|
|
builder.setInsertionPoint(pullbackBB);
|
|
auto *dsi = builder.createDestructureTuple(pbLoc, pbTupleArg);
|
|
initializePullbackTupleElements(origBB, dsi->getResults());
|
|
|
|
// - Create pullback trampoline blocks for each successor block of the
|
|
// original block. Pullback trampoline blocks only have a pullback
|
|
// struct argument. They branch from a pullback successor block to the
|
|
// pullback original block, passing adjoint values of active values.
|
|
for (auto *succBB : origBB->getSuccessorBlocks()) {
|
|
// Skip generating pullback block for original unreachable blocks.
|
|
if (!workqueue.isVisited(succBB))
|
|
continue;
|
|
auto *pullbackTrampolineBB = pullback.createBasicBlockBefore(pullbackBB);
|
|
pullbackTrampolineBBMap.insert({{origBB, succBB}, pullbackTrampolineBB});
|
|
// Get the enum element type (i.e. the pullback struct type). The enum
|
|
// element type may be boxed if the enum is indirect.
|
|
auto enumLoweredTy =
|
|
getPullbackInfo().getBranchingTraceEnumLoweredType(succBB);
|
|
auto *enumEltDecl =
|
|
getPullbackInfo().lookUpBranchingTraceEnumElement(origBB, succBB);
|
|
auto enumEltType = remapType(enumLoweredTy.getEnumElementType(
|
|
enumEltDecl, getModule(), TypeExpansionContext::minimal()));
|
|
pullbackTrampolineBB->createPhiArgument(enumEltType,
|
|
OwnershipKind::Owned);
|
|
}
|
|
}
|
|
|
|
auto *pullbackEntry = pullback.getEntryBlock();
|
|
auto pbTupleLoweredType =
|
|
remapType(getPullbackInfo().getLinearMapTupleLoweredType(originalExitBlock));
|
|
unsigned numVals = (getPullbackInfo().hasHeapAllocatedContext() ?
|
|
1 : pbTupleLoweredType.getAs<TupleType>()->getNumElements());
|
|
(void)numVals;
|
|
|
|
// The pullback function has type:
|
|
// `(seed0, seed1, ..., (exit_pb_tuple_el0, ..., )|context_obj) -> (d_arg0, ..., d_argn)`.
|
|
auto conv = getOriginal().getConventions();
|
|
auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults();
|
|
assert(getConfig().resultIndices->getNumIndices() - conv.getNumYields() == pbParamArgs.size() - numVals &&
|
|
pbParamArgs.size() >= 1);
|
|
// Assign adjoints for original result.
|
|
builder.setCurrentDebugScope(
|
|
remapScope(originalExitBlock->getTerminator()->getDebugScope()));
|
|
builder.setInsertionPoint(pullbackEntry,
|
|
getNextFunctionLocalAllocationInsertionPoint());
|
|
unsigned seedIndex = 0;
|
|
unsigned firstSemanticParamResultIdx = conv.getResults().size();
|
|
unsigned firstYieldResultIndex = firstSemanticParamResultIdx +
|
|
conv.getNumAutoDiffSemanticResultParameters();
|
|
for (auto resultIndex : getConfig().resultIndices->getIndices()) {
|
|
// Yields seed buffers are only to be touched in yield BB and required
|
|
// special handling
|
|
if (resultIndex >= firstYieldResultIndex)
|
|
continue;
|
|
|
|
auto origResult = origFormalResults[resultIndex];
|
|
auto *seed = pbParamArgs[seedIndex];
|
|
if (seed->getType().isAddress()) {
|
|
// If the seed argument is an `inout` parameter, assign it directly as
|
|
// the adjoint buffer of the original result.
|
|
auto seedParamInfo =
|
|
pullback.getLoweredFunctionType()->getParameters()[seedIndex];
|
|
|
|
if (seedParamInfo.isIndirectInOut()) {
|
|
setAdjointBuffer(originalExitBlock, origResult, seed);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Assigned seed buffer " << *seed
|
|
<< " as the adjoint of original indirect result "
|
|
<< origResult);
|
|
}
|
|
// Otherwise, assign a copy of the seed argument as the adjoint buffer of
|
|
// the original result.
|
|
else {
|
|
auto *seedBufCopy =
|
|
createFunctionLocalAllocation(seed->getType(), pbLoc);
|
|
builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake,
|
|
IsInitialization);
|
|
setAdjointBuffer(originalExitBlock, origResult, seedBufCopy);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Assigned seed buffer " << *seedBufCopy
|
|
<< " as the adjoint of original indirect result "
|
|
<< origResult);
|
|
}
|
|
} else {
|
|
addAdjointValue(originalExitBlock, origResult, makeConcreteAdjointValue(seed),
|
|
pbLoc);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Assigned seed " << *seed
|
|
<< " as the adjoint of original result " << origResult);
|
|
}
|
|
++seedIndex;
|
|
}
|
|
|
|
// If the original function is an accessor with special-case pullback
|
|
// generation logic, do special-case generation.
|
|
bool isSemanticMemberAcc = isSemanticMemberAccessor(&original);
|
|
if (isSemanticMemberAcc) {
|
|
if (runForSemanticMemberAccessor())
|
|
return true;
|
|
}
|
|
// Otherwise, perform standard pullback generation.
|
|
// Visit original blocks in post-order and perform differentiation
|
|
// in corresponding pullback blocks. If errors occurred, back out.
|
|
else {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Begin search for adjoints of loop-local active values\n");
|
|
llvm::DenseMap<const SILLoop *, llvm::DenseSet<SILValue>>
|
|
loopLocalActiveValues;
|
|
for (auto *bb : originalBlocks) {
|
|
const SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb);
|
|
if (loop == nullptr)
|
|
continue;
|
|
SILBasicBlock *loopHeader = loop->getHeader();
|
|
SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Original bb" << bb->getDebugID()
|
|
<< " belongs to a loop, original header bb"
|
|
<< loopHeader->getDebugID() << ", pullback header bb"
|
|
<< pbLoopHeader->getDebugID() << '\n');
|
|
builder.setInsertionPoint(pbLoopHeader);
|
|
auto bbActiveValuesIt = activeValues.find(bb);
|
|
if (bbActiveValuesIt == activeValues.end())
|
|
continue;
|
|
const auto &bbActiveValues = bbActiveValuesIt->second;
|
|
for (SILValue bbActiveValue : bbActiveValues) {
|
|
if (vjpCloner.getLoopInfo()->getLoopFor(
|
|
bbActiveValue->getParentBlock()) != loop) {
|
|
LLVM_DEBUG(
|
|
getADDebugStream()
|
|
<< "The following active value is NOT loop-local, skipping: "
|
|
<< bbActiveValue);
|
|
continue;
|
|
}
|
|
|
|
auto [_, wasInserted] =
|
|
loopLocalActiveValues[loop].insert(bbActiveValue);
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "The following active value is loop-local, ");
|
|
if (!wasInserted) {
|
|
LLVM_DEBUG(llvm::dbgs() << "but it was already processed, skipping: "
|
|
<< bbActiveValue);
|
|
continue;
|
|
}
|
|
|
|
if (getTangentValueCategory(bbActiveValue) ==
|
|
SILValueCategory::Object) {
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< "zeroing its adjoint value in loop header: "
|
|
<< bbActiveValue);
|
|
setAdjointValue(bb, bbActiveValue,
|
|
makeZeroAdjointValue(getRemappedTangentType(
|
|
bbActiveValue->getType())));
|
|
continue;
|
|
}
|
|
|
|
ASSERT(getTangentValueCategory(bbActiveValue) ==
|
|
SILValueCategory::Address);
|
|
|
|
// getAdjointProjection might call materializeAdjointDirect which
|
|
// writes to debug output, emit \n.
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< "checking if it's adjoint is a projection\n");
|
|
|
|
if (!getAdjointProjection(bb, bbActiveValue)) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Adjoint for the following value is NOT a projection, "
|
|
"zeroing its adjoint buffer in loop header: "
|
|
<< bbActiveValue);
|
|
|
|
// All adjoint buffers are allocated in the pullback entry and
|
|
// deallocated in the pullback exit. So, use IsNotInitialization to
|
|
// emit destroy_addr before zeroing the buffer.
|
|
ASSERT(bufferMap.contains({bb, bbActiveValue}));
|
|
builder.emitZeroIntoBuffer(pbLoc, getAdjointBuffer(bb, bbActiveValue),
|
|
IsNotInitialization);
|
|
|
|
continue;
|
|
}
|
|
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Adjoint for the following value is a projection, ");
|
|
|
|
// If Projection::isAddressProjection(v) is true for a value v, it
|
|
// is not added to active values list (see recordValueIfActive).
|
|
//
|
|
// Ensure that only the following value types conforming to
|
|
// getAdjointProjection but not conforming to
|
|
// Projection::isAddressProjection can go here.
|
|
//
|
|
// Instructions conforming to Projection::isAddressProjection and
|
|
// thus never corresponding to an active value do not need any
|
|
// handling, because only active values can have adjoints from
|
|
// previous iterations propagated via BB arguments.
|
|
do {
|
|
// Consider '%X = begin_access [modify] [static] %Y'.
|
|
// 1. If %Y is loop-local, it's adjoint buffer will
|
|
// be zeroed, and we'll have zero adjoint projection to it.
|
|
// 2. Otherwise, we do not need to zero the projection buffer.
|
|
// Thus, we can just skip.
|
|
if (dyn_cast<BeginAccessInst>(bbActiveValue)) {
|
|
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
|
|
break;
|
|
}
|
|
|
|
// Consider the following sequence:
|
|
// %1 = function_ref @allocUninitArray
|
|
// %2 = apply %1<Float>(%0)
|
|
// (%3, %4) = destructure_tuple %2
|
|
// %5 = mark_dependence %4 on %3
|
|
// %6 = pointer_to_address %6 to [strict] $*Float
|
|
// Since %6 is active, %3 (which is an array) must also be active.
|
|
// Thus, adjoint for %3 will be zeroed if needed. Ensure that expected
|
|
// invariants hold and then skip.
|
|
if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress(
|
|
bbActiveValue)) {
|
|
ASSERT(isa<PointerToAddressInst>(bbActiveValue));
|
|
SILValue arrayValue = getArrayValue(ai);
|
|
ASSERT(llvm::find(bbActiveValues, arrayValue) !=
|
|
bbActiveValues.end());
|
|
ASSERT(vjpCloner.getLoopInfo()->getLoopFor(
|
|
arrayValue->getParentBlock()) == loop);
|
|
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
|
|
break;
|
|
}
|
|
|
|
ASSERT(false);
|
|
} while (false);
|
|
}
|
|
}
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "End search for adjoints of loop-local active values\n");
|
|
|
|
for (auto *bb : originalBlocks) {
|
|
visitSILBasicBlock(bb);
|
|
if (errorOccurred)
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// Prepare and emit a `return` in the pullback exit block.
|
|
auto *origEntry = getOriginal().getEntryBlock();
|
|
auto *pbExit = getPullbackBlock(origEntry);
|
|
builder.setCurrentDebugScope(pbExit->back().getDebugScope());
|
|
builder.setInsertionPoint(pbExit);
|
|
|
|
// This vector will contain all the materialized return elements.
|
|
SmallVector<SILValue, 8> retElts;
|
|
// This vector will contain all indirect parameter adjoint buffers.
|
|
SmallVector<SILValue, 4> indParamAdjoints;
|
|
// This vector will identify the locations where initialization is needed.
|
|
SmallBitVector outputsToInitialize;
|
|
|
|
auto origParams = getOriginal().getArgumentsWithoutIndirectResults();
|
|
|
|
// Materializes the return element corresponding to the parameter
|
|
// `parameterIndex` into the `retElts` vector.
|
|
auto addRetElt = [&](unsigned parameterIndex) -> void {
|
|
auto origParam = origParams[parameterIndex];
|
|
switch (getTangentValueCategory(origParam)) {
|
|
case SILValueCategory::Object: {
|
|
auto pbVal = getAdjointValue(origEntry, origParam);
|
|
auto val = materializeAdjointDirect(pbVal, pbLoc);
|
|
auto newVal = builder.emitCopyValueOperation(pbLoc, val);
|
|
retElts.push_back(newVal);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjBuf = getAdjointBuffer(origEntry, origParam);
|
|
indParamAdjoints.push_back(adjBuf);
|
|
outputsToInitialize.push_back(
|
|
!conv.getParameters()[parameterIndex].isIndirectMutating());
|
|
break;
|
|
}
|
|
}
|
|
};
|
|
SmallVector<SILArgument *, 4> pullbackIndirectResults(
|
|
getPullback().getIndirectResults().begin(),
|
|
getPullback().getIndirectResults().end());
|
|
|
|
// Collect differentiation parameter adjoints.
|
|
// Do a first pass to collect non-inout values.
|
|
for (auto i : getConfig().parameterIndices->getIndices()) {
|
|
if (!conv.getParameters()[i].isAutoDiffSemanticResult()) {
|
|
addRetElt(i);
|
|
}
|
|
}
|
|
|
|
// Do a second pass for all inout parameters, however this is only necessary
|
|
// for functions with multiple basic blocks. For functions with a single
|
|
// basic block adjoint accumulation for those parameters is already done by
|
|
// per-instruction visitors.
|
|
if (getOriginal().size() > 1) {
|
|
const auto &pullbackConv = pullback.getConventions();
|
|
SmallVector<SILArgument *, 1> pullbackInOutArgs;
|
|
for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults())) {
|
|
if (pullbackConv.getParameters()[pullbackArg.index()].isAutoDiffSemanticResult())
|
|
pullbackInOutArgs.push_back(pullbackArg.value());
|
|
}
|
|
|
|
unsigned pullbackInoutArgumentIdx = 0;
|
|
for (auto i : getConfig().parameterIndices->getIndices()) {
|
|
// Skip non-inout parameters.
|
|
if (!conv.getParameters()[i].isAutoDiffSemanticResult())
|
|
continue;
|
|
|
|
// For functions with multiple basic blocks, accumulation is needed
|
|
// for `inout` parameters because pullback basic blocks have different
|
|
// adjoint buffers.
|
|
pullbackIndirectResults.push_back(pullbackInOutArgs[pullbackInoutArgumentIdx++]);
|
|
addRetElt(i);
|
|
}
|
|
}
|
|
|
|
// Copy them to adjoint indirect results.
|
|
assert(indParamAdjoints.size() == pullbackIndirectResults.size() &&
|
|
"Indirect parameter adjoint count mismatch");
|
|
unsigned currentIndex = 0;
|
|
for (auto pair : zip(indParamAdjoints, pullbackIndirectResults)) {
|
|
auto source = std::get<0>(pair);
|
|
auto *dest = std::get<1>(pair);
|
|
if (outputsToInitialize[currentIndex]) {
|
|
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
|
|
} else {
|
|
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsNotInitialization);
|
|
}
|
|
currentIndex++;
|
|
// Prevent source buffer from being deallocated, since the underlying
|
|
// value is moved.
|
|
destroyedLocalAllocations.insert(source);
|
|
}
|
|
|
|
// Emit cleanups for all local values.
|
|
cleanUpTemporariesForBlock(pbExit, pbLoc);
|
|
// Deallocate local allocations.
|
|
for (auto alloc : functionLocalAllocations) {
|
|
// Assert that local allocations have at least one use.
|
|
// Buffers should not be allocated needlessly.
|
|
assert(!alloc->use_empty());
|
|
if (!destroyedLocalAllocations.count(alloc)) {
|
|
builder.emitDestroyAddrAndFold(pbLoc, alloc);
|
|
destroyedLocalAllocations.insert(alloc);
|
|
}
|
|
builder.createDeallocStack(pbLoc, alloc);
|
|
}
|
|
builder.createReturn(pbLoc, joinElements(retElts, builder, pbLoc));
|
|
|
|
#ifndef NDEBUG
|
|
bool leakFound = false;
|
|
// Ensure all temporaries have been cleaned up.
|
|
for (auto &bb : pullback) {
|
|
for (auto temp : blockTemporaries[&bb]) {
|
|
if (blockTemporaries[&bb].count(temp)) {
|
|
leakFound = true;
|
|
getADDebugStream() << "Found leaked temporary:\n" << temp;
|
|
}
|
|
}
|
|
}
|
|
// Ensure all enum adjoint copeis have been cleaned up
|
|
for (const auto &enumData : enumDataAdjCopies) {
|
|
leakFound = true;
|
|
getADDebugStream() << "Found leaked temporary:\n" << enumData.second;
|
|
}
|
|
|
|
// Ensure all local allocations have been cleaned up.
|
|
for (auto localAlloc : functionLocalAllocations) {
|
|
if (!destroyedLocalAllocations.count(localAlloc)) {
|
|
leakFound = true;
|
|
getADDebugStream() << "Found leaked local buffer:\n" << localAlloc;
|
|
}
|
|
}
|
|
assert(!leakFound && "Leaks found!");
|
|
#endif
|
|
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Generated " << (isSemanticMemberAcc ? "semantic member accessor" : "normal")
|
|
<< " pullback for " << original.getName() << ":\n"
|
|
<< pullback);
|
|
return errorOccurred;
|
|
}
|
|
|
|
void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult(
|
|
SILValue origNonvariedResult) {
|
|
auto &pullback = getPullback();
|
|
auto pbLoc = getPullback().getLocation();
|
|
/*
|
|
// TODO(TF-788): Re-enable non-varied result warning.
|
|
// Emit fixit if original non-varied result has a valid source location.
|
|
auto startLoc = origNonvariedResult.getLoc().getStartSourceLoc();
|
|
auto endLoc = origNonvariedResult.getLoc().getEndSourceLoc();
|
|
if (startLoc.isValid() && endLoc.isValid()) {
|
|
getContext().diagnose(startLoc, diag::autodiff_nonvaried_result_fixit)
|
|
.fixItInsert(startLoc, "withoutDerivative(at:")
|
|
.fixItInsertAfter(endLoc, ")");
|
|
}
|
|
*/
|
|
LLVM_DEBUG(getADDebugStream() << getOriginal().getName()
|
|
<< " has non-varied result, returning zero"
|
|
" for all pullback results\n");
|
|
auto *pullbackEntry = pullback.createBasicBlock();
|
|
createEntryArguments(&pullback);
|
|
builder.setCurrentDebugScope(
|
|
remapScope(originalExitBlock->getTerminator()->getDebugScope()));
|
|
builder.setInsertionPoint(pullbackEntry);
|
|
// Destroy all owned arguments.
|
|
for (auto *arg : pullbackEntry->getArguments())
|
|
if (arg->getOwnershipKind() == OwnershipKind::Owned)
|
|
builder.emitDestroyOperation(pbLoc, arg);
|
|
// Return zero for each result.
|
|
SmallVector<SILValue, 4> directResults;
|
|
auto indirectResultIt = pullback.getIndirectResults().begin();
|
|
for (auto resultInfo : pullback.getLoweredFunctionType()->getResults()) {
|
|
auto resultType =
|
|
pullback.mapTypeIntoContext(resultInfo.getInterfaceType())
|
|
->getCanonicalType();
|
|
if (resultInfo.isFormalDirect())
|
|
directResults.push_back(builder.emitZero(pbLoc, resultType));
|
|
else
|
|
builder.emitZeroIntoBuffer(pbLoc, *indirectResultIt++, IsInitialization);
|
|
}
|
|
builder.createReturn(pbLoc, joinElements(directResults, builder, pbLoc));
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Generated pullback for " << getOriginal().getName() << ":\n"
|
|
<< pullback);
|
|
}
|
|
|
|
AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint(
|
|
SILBasicBlock *bb, SILValue wrappedAdjoint, SILType optionalTy) {
|
|
auto pbLoc = getPullback().getLocation();
|
|
// `Optional<T>`
|
|
optionalTy = remapType(optionalTy);
|
|
assert(optionalTy.getASTType()->isOptional());
|
|
// `T`
|
|
auto wrappedType = optionalTy.getOptionalObjectType();
|
|
// `T.TangentVector`
|
|
auto wrappedTanType = remapType(wrappedAdjoint->getType());
|
|
// `Optional<T.TangentVector>`
|
|
auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType);
|
|
// `Optional<T>.TangentVector`
|
|
auto optionalTanTy = getRemappedTangentType(optionalTy);
|
|
// Look up the `Optional<T>.TangentVector.init` declaration.
|
|
ConstructorDecl *constructorDecl =
|
|
getASTContext().getOptionalTanInitDecl(optionalTanTy.getASTType());
|
|
|
|
// Allocate a local buffer for the `Optional` adjoint value.
|
|
auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy);
|
|
// Find `Optional<T.TangentVector>.some` EnumElementDecl.
|
|
auto someEltDecl = builder.getASTContext().getOptionalSomeDecl();
|
|
|
|
// Initialize an `Optional<T.TangentVector>` buffer from `wrappedAdjoint` as
|
|
// the input for `Optional<T>.TangentVector.init`.
|
|
auto *optArgBuf = builder.createAllocStack(pbLoc, optionalOfWrappedTanType);
|
|
if (optionalOfWrappedTanType.isObject()) {
|
|
// %enum = enum $Optional<T.TangentVector>, #Optional.some!enumelt,
|
|
// %wrappedAdjoint : $T
|
|
auto *enumInst = builder.createEnum(pbLoc, wrappedAdjoint, someEltDecl,
|
|
optionalOfWrappedTanType);
|
|
// store %enum to %optArgBuf
|
|
builder.emitStoreValueOperation(pbLoc, enumInst, optArgBuf,
|
|
StoreOwnershipQualifier::Init);
|
|
} else {
|
|
// %enumAddr = init_enum_data_addr %optArgBuf $Optional<T.TangentVector>,
|
|
// #Optional.some!enumelt
|
|
auto *enumAddr = builder.createInitEnumDataAddr(
|
|
pbLoc, optArgBuf, someEltDecl, wrappedTanType.getAddressType());
|
|
// copy_addr %wrappedAdjoint to [init] %enumAddr
|
|
builder.createCopyAddr(pbLoc, wrappedAdjoint, enumAddr, IsNotTake,
|
|
IsInitialization);
|
|
// inject_enum_addr %optArgBuf : $*Optional<T.TangentVector>,
|
|
// #Optional.some!enumelt
|
|
builder.createInjectEnumAddr(pbLoc, optArgBuf, someEltDecl);
|
|
}
|
|
|
|
// Apply `Optional<T>.TangentVector.init`.
|
|
SILOptFunctionBuilder fb(getContext().getTransform());
|
|
// %init_fn = function_ref @Optional<T>.TangentVector.init
|
|
auto *initFn = fb.getOrCreateFunction(pbLoc, SILDeclRef(constructorDecl),
|
|
NotForDefinition);
|
|
auto *initFnRef = builder.createFunctionRef(pbLoc, initFn);
|
|
auto *diffProto =
|
|
builder.getASTContext().getProtocol(KnownProtocolKind::Differentiable);
|
|
auto diffConf =
|
|
lookupConformance(wrappedType.getASTType(), diffProto);
|
|
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
|
|
auto subMap = SubstitutionMap::get(
|
|
initFn->getLoweredFunctionType()->getSubstGenericSignature(),
|
|
ArrayRef<Type>(wrappedType.getASTType()), {diffConf});
|
|
// %metatype = metatype $Optional<T>.TangentVector.Type
|
|
auto metatypeType = CanMetatypeType::get(optionalTanTy.getASTType(),
|
|
MetatypeRepresentation::Thin);
|
|
auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType);
|
|
auto metatype = builder.createMetatype(pbLoc, metatypeSILType);
|
|
// apply %init_fn(%optTanAdjBuf, %optArgBuf, %metatype)
|
|
builder.createApply(pbLoc, initFnRef, subMap,
|
|
{optTanAdjBuf, optArgBuf, metatype});
|
|
builder.createDeallocStack(pbLoc, optArgBuf);
|
|
return optTanAdjBuf;
|
|
}
|
|
|
|
// Accumulate adjoint for the incoming `Optional` buffer.
|
|
void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer(
|
|
SILBasicBlock *bb, SILValue optionalBuffer, SILValue wrappedAdjoint) {
|
|
assert(getTangentValueCategory(optionalBuffer) == SILValueCategory::Address);
|
|
auto pbLoc = getPullback().getLocation();
|
|
|
|
// Allocate and initialize Optional<Wrapped>.TangentVector from
|
|
// Wrapped.TangentVector
|
|
AllocStackInst *optTanAdjBuf =
|
|
createOptionalAdjoint(bb, wrappedAdjoint, optionalBuffer->getType());
|
|
|
|
// Accumulate into optionalBuffer
|
|
addToAdjointBuffer(bb, optionalBuffer, optTanAdjBuf, pbLoc);
|
|
builder.emitDestroyAddr(pbLoc, optTanAdjBuf);
|
|
builder.createDeallocStack(pbLoc, optTanAdjBuf);
|
|
}
|
|
|
|
// Accumulate adjoint for the incoming `Optional` value.
|
|
void PullbackCloner::Implementation::accumulateAdjointValueForOptional(
|
|
SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) {
|
|
assert(getTangentValueCategory(optionalValue) == SILValueCategory::Object);
|
|
auto pbLoc = getPullback().getLocation();
|
|
|
|
// Allocate and initialize Optional<Wrapped>.TangentVector from
|
|
// Wrapped.TangentVector
|
|
AllocStackInst *optTanAdjBuf =
|
|
createOptionalAdjoint(bb, wrappedAdjoint, optionalValue->getType());
|
|
|
|
auto optTanAdjVal = builder.emitLoadValueOperation(
|
|
pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take);
|
|
|
|
recordTemporary(optTanAdjVal);
|
|
builder.createDeallocStack(pbLoc, optTanAdjBuf);
|
|
|
|
addAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal), pbLoc);
|
|
}
|
|
|
|
SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor(
|
|
SILBasicBlock *origBB, SILBasicBlock *origPredBB,
|
|
SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) {
|
|
// Get the pullback block and optional pullback trampoline block of the
|
|
// predecessor block.
|
|
auto *pullbackBB = getPullbackBlock(origPredBB);
|
|
auto *pullbackTrampolineBB = getPullbackTrampolineBlock(origPredBB, origBB);
|
|
// If the predecessor block does not have a corresponding pullback
|
|
// trampoline block, then the pullback successor is the pullback block.
|
|
if (!pullbackTrampolineBB)
|
|
return pullbackBB;
|
|
|
|
// Otherwise, the pullback successor is the pullback trampoline block,
|
|
// which branches to the pullback block and propagates adjoint values of
|
|
// active values.
|
|
assert(pullbackTrampolineBB->getNumArguments() == 1);
|
|
auto loc = origBB->getParent()->getLocation();
|
|
SmallVector<SILValue, 8> trampolineArguments;
|
|
|
|
// Propagate adjoint values/buffers of active values/buffers to
|
|
// predecessor blocks.
|
|
auto &predBBActiveValues = activeValues[origPredBB];
|
|
llvm::SmallSet<std::pair<SILValue, SILValue>, 32> propagatedAdjoints;
|
|
for (auto activeValue : predBBActiveValues) {
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Propagating adjoint of active value " << activeValue
|
|
<< "from bb" << origBB->getDebugID()
|
|
<< " to predecessors' (bb" << origPredBB->getDebugID()
|
|
<< ") pullback blocks\n");
|
|
switch (getTangentValueCategory(activeValue)) {
|
|
case SILValueCategory::Object: {
|
|
auto activeValueAdj = getAdjointValue(origBB, activeValue);
|
|
auto concreteActiveValueAdj =
|
|
materializeAdjointDirect(activeValueAdj, loc);
|
|
|
|
if (!pullbackTrampolineBlockMap.count(concreteActiveValueAdj)) {
|
|
concreteActiveValueAdj =
|
|
builder.emitCopyValueOperation(loc, concreteActiveValueAdj);
|
|
setAdjointValue(origBB, activeValue,
|
|
makeConcreteAdjointValue(concreteActiveValueAdj));
|
|
}
|
|
auto insertion = pullbackTrampolineBlockMap.try_emplace(
|
|
concreteActiveValueAdj, TrampolineBlockSet());
|
|
auto &blockSet = insertion.first->getSecond();
|
|
blockSet.insert(pullbackTrampolineBB);
|
|
trampolineArguments.push_back(concreteActiveValueAdj);
|
|
|
|
// If the pullback block does not yet have a registered adjoint
|
|
// value for the active value, set the adjoint value to the
|
|
// forwarded adjoint value argument.
|
|
// TODO: Hoist this logic out of loop over predecessor blocks to
|
|
// remove the `hasAdjointValue` check.
|
|
if (!hasAdjointValue(origPredBB, activeValue)) {
|
|
auto *pullbackBBArg =
|
|
getActiveValuePullbackBlockArgument(origPredBB, activeValue);
|
|
auto forwardedArgAdj = makeConcreteAdjointValue(pullbackBBArg);
|
|
setAdjointValue(origPredBB, activeValue, forwardedArgAdj);
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
// Propagate adjoint buffers using `copy_addr`.
|
|
auto adjBuf = getAdjointBuffer(origBB, activeValue);
|
|
auto predAdjBuf = getAdjointBuffer(origPredBB, activeValue);
|
|
if (propagatedAdjoints.insert({adjBuf, predAdjBuf}).second)
|
|
builder.createCopyAddr(loc, adjBuf, predAdjBuf, IsNotTake,
|
|
IsNotInitialization);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Propagate pullback struct argument.
|
|
TangentBuilder pullbackTrampolineBBBuilder(
|
|
pullbackTrampolineBB, getContext());
|
|
pullbackTrampolineBBBuilder.setCurrentDebugScope(
|
|
remapScope(origPredBB->getTerminator()->getDebugScope()));
|
|
|
|
auto *pullbackTrampolineBBArg = pullbackTrampolineBB->getArguments().front();
|
|
if (vjpCloner.getLoopInfo()->getLoopFor(origPredBB)) {
|
|
assert(pullbackTrampolineBBArg->getType() ==
|
|
SILType::getRawPointerType(getASTContext()));
|
|
auto pbTupleType =
|
|
remapType(getPullbackInfo().getLinearMapTupleLoweredType(origPredBB));
|
|
auto predPbTupleAddr = pullbackTrampolineBBBuilder.createPointerToAddress(
|
|
loc, pullbackTrampolineBBArg, pbTupleType.getAddressType(),
|
|
/*isStrict*/ true);
|
|
auto predPbStructVal = pullbackTrampolineBBBuilder.createLoad(
|
|
loc, predPbTupleAddr,
|
|
pbTupleType.isTrivial(getPullback()) ?
|
|
LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy);
|
|
trampolineArguments.push_back(predPbStructVal);
|
|
} else {
|
|
trampolineArguments.push_back(pullbackTrampolineBBArg);
|
|
}
|
|
// Branch from pullback trampoline block to pullback block.
|
|
pullbackTrampolineBBBuilder.createBranch(loc, pullbackBB,
|
|
trampolineArguments);
|
|
return pullbackTrampolineBB;
|
|
}
|
|
|
|
void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
|
|
auto pbLoc = getPullback().getLocation();
|
|
// Get the corresponding pullback basic block.
|
|
auto *pbBB = getPullbackBlock(bb);
|
|
builder.setInsertionPoint(pbBB);
|
|
|
|
LLVM_DEBUG({
|
|
auto &s = getADDebugStream()
|
|
<< "Original bb" + std::to_string(bb->getDebugID())
|
|
<< ": To differentiate or not to differentiate?\n";
|
|
for (auto &inst : llvm::reverse(*bb)) {
|
|
s << (getPullbackInfo().shouldDifferentiateInstruction(&inst) ? "[x] "
|
|
: "[ ] ")
|
|
<< inst;
|
|
}
|
|
});
|
|
|
|
// Visit each instruction in reverse order.
|
|
for (auto &inst : llvm::reverse(*bb)) {
|
|
if (!getPullbackInfo().shouldDifferentiateInstruction(&inst))
|
|
continue;
|
|
// Differentiate instruction.
|
|
builder.setCurrentDebugScope(remapScope(inst.getDebugScope()));
|
|
visit(&inst);
|
|
if (errorOccurred)
|
|
return;
|
|
}
|
|
|
|
// Emit a branching terminator for the block.
|
|
// If the original block is the original entry, then the pullback block is
|
|
// the pullback exit. This is handled specially in
|
|
// `PullbackCloner::Implementation::run()`, so we leave the block
|
|
// non-terminated.
|
|
if (bb->isEntry())
|
|
return;
|
|
|
|
// If the original block is a resume yield destination, then we need to yield
|
|
// the adjoint buffer and do everything else in the resume destination. Unwind
|
|
// destination is unreachable as the co-routine can never be aborted.
|
|
if (auto *predBB = bb->getSinglePredecessorBlock()) {
|
|
if (auto *yield = dyn_cast<YieldInst>(predBB->getTerminator())) {
|
|
auto *resumeBB = pbBB->split(builder.getInsertionPoint());
|
|
auto *unwindBB = getPullback().createBasicBlock();
|
|
|
|
SmallVector<SILValue, 1> adjYields;
|
|
for (auto yieldedVal : yield->getYieldedValues())
|
|
adjYields.push_back(getAdjointBuffer(bb, yieldedVal));
|
|
|
|
builder.createYield(yield->getLoc(), adjYields, resumeBB, unwindBB);
|
|
builder.setInsertionPoint(unwindBB);
|
|
builder.createUnreachable(SILLocation::invalid());
|
|
|
|
pbBB = resumeBB;
|
|
builder.setInsertionPoint(pbBB);
|
|
}
|
|
}
|
|
|
|
// Otherwise, add a `switch_enum` terminator for non-exit
|
|
// pullback blocks.
|
|
// 1. Get the pullback struct pullback block argument.
|
|
// 2. Extract the predecessor enum value from the pullback struct value.
|
|
auto *predEnum = getPullbackInfo().getBranchingTraceDecl(bb);
|
|
(void)predEnum;
|
|
auto predEnumVal = getPullbackPredTupleElement(bb);
|
|
|
|
// Propagate adjoint values from active basic block arguments to
|
|
// incoming values (predecessor terminator operands).
|
|
for (auto *bbArg : bb->getArguments()) {
|
|
if (!getActivityInfo().isActive(bbArg, getConfig()))
|
|
continue;
|
|
LLVM_DEBUG(getADDebugStream() << "Propagating adjoint value for active bb"
|
|
<< bb->getDebugID() << " argument: "
|
|
<< *bbArg);
|
|
|
|
// Get predecessor terminator operands.
|
|
SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues;
|
|
if (bbArg->getSingleTerminatorOperands(incomingValues)) {
|
|
// Returns true if the given terminator instruction is a `switch_enum` on
|
|
// an `Optional`-typed value. `switch_enum` instructions require
|
|
// special-case adjoint value propagation for the operand.
|
|
auto isSwitchEnumInstOnOptional =
|
|
[&ctx = getASTContext()](TermInst *termInst) {
|
|
if (!termInst)
|
|
return false;
|
|
if (auto *sei = dyn_cast<SwitchEnumInst>(termInst)) {
|
|
auto operandTy = sei->getOperand()->getType();
|
|
return operandTy.getASTType()->isOptional();
|
|
}
|
|
return false;
|
|
};
|
|
|
|
// Check the tangent value category of the active basic block argument.
|
|
switch (getTangentValueCategory(bbArg)) {
|
|
// If argument has a loadable tangent value category: materialize adjoint
|
|
// value of the argument, create a copy, and set the copy as the adjoint
|
|
// value of incoming values.
|
|
case SILValueCategory::Object: {
|
|
auto bbArgAdj = getAdjointValue(bb, bbArg);
|
|
auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc);
|
|
auto concreteBBArgAdjCopy =
|
|
builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj);
|
|
for (auto pair : incomingValues) {
|
|
auto *predBB = std::get<0>(pair);
|
|
auto incomingValue = std::get<1>(pair);
|
|
// Handle `switch_enum` on `Optional`.
|
|
auto termInst = bbArg->getSingleTerminator();
|
|
if (isSwitchEnumInstOnOptional(termInst)) {
|
|
accumulateAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy);
|
|
} else {
|
|
blockTemporaries[getPullbackBlock(predBB)].insert(
|
|
concreteBBArgAdjCopy);
|
|
addAdjointValue(predBB, incomingValue,
|
|
makeConcreteAdjointValue(concreteBBArgAdjCopy), pbLoc);
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
// If argument has an address tangent value category: materialize adjoint
|
|
// value of the argument, create a copy, and set the copy as the adjoint
|
|
// value of incoming values.
|
|
case SILValueCategory::Address: {
|
|
auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg);
|
|
for (auto pair : incomingValues) {
|
|
auto incomingValue = std::get<1>(pair);
|
|
// Handle `switch_enum` on `Optional`.
|
|
auto termInst = bbArg->getSingleTerminator();
|
|
if (isSwitchEnumInstOnOptional(termInst))
|
|
accumulateAdjointForOptionalBuffer(bb, incomingValue, bbArgAdjBuf);
|
|
else
|
|
addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc);
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
} else {
|
|
LLVM_DEBUG(getADDebugStream() <<
|
|
"do not know how to handle this incoming bb argument");
|
|
if (auto term = bbArg->getSingleTerminator()) {
|
|
getContext().emitNondifferentiabilityError(term, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
} else {
|
|
// This will be a bit confusing, but still better than nothing.
|
|
getContext().emitNondifferentiabilityError(bbArg, getInvoker(),
|
|
diag::autodiff_expression_not_differentiable_note);
|
|
}
|
|
|
|
errorOccurred = true;
|
|
return;
|
|
}
|
|
}
|
|
|
|
// 3. Build the pullback successor cases for the `switch_enum`
|
|
// instruction. The pullback successors correspond to the predecessors
|
|
// of the current original block.
|
|
SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4>
|
|
pullbackSuccessorCases;
|
|
// A map from active values' adjoint values to the trampoline blocks that
|
|
// are using them.
|
|
SmallDenseMap<SILValue, TrampolineBlockSet> pullbackTrampolineBlockMap;
|
|
SmallDenseMap<SILBasicBlock *, SILBasicBlock *> origPredpullbackSuccBBMap;
|
|
for (auto *predBB : bb->getPredecessorBlocks()) {
|
|
// If there is no linear map tuple for predecessor BB, then BB is
|
|
// unreachable from function entry. There is no branch tracing enum for it
|
|
// as well, so we should not create any branching to it in the pullback.
|
|
if (!getPullbackInfo().getLinearMapTupleType(predBB))
|
|
continue;
|
|
auto *pullbackSuccBB =
|
|
buildPullbackSuccessor(bb, predBB, pullbackTrampolineBlockMap);
|
|
origPredpullbackSuccBBMap[predBB] = pullbackSuccBB;
|
|
auto *enumEltDecl =
|
|
getPullbackInfo().lookUpBranchingTraceEnumElement(predBB, bb);
|
|
pullbackSuccessorCases.emplace_back(enumEltDecl, pullbackSuccBB);
|
|
}
|
|
// Values are trampolined by only a subset of pullback successor blocks.
|
|
// Other successors blocks should destroy the value.
|
|
for (auto pair : pullbackTrampolineBlockMap) {
|
|
auto value = pair.getFirst();
|
|
// The set of trampoline BBs that are users of `value`.
|
|
auto &userTrampolineBBSet = pair.getSecond();
|
|
// For each pullback successor block that does not trampoline the value,
|
|
// release the value.
|
|
for (auto origPredPbSuccPair : origPredpullbackSuccBBMap) {
|
|
auto *origPred = origPredPbSuccPair.getFirst();
|
|
auto *pbSucc = origPredPbSuccPair.getSecond();
|
|
if (userTrampolineBBSet.count(pbSucc))
|
|
continue;
|
|
TangentBuilder pullbackSuccBuilder(pbSucc->begin(), getContext());
|
|
pullbackSuccBuilder.setCurrentDebugScope(
|
|
remapScope(origPred->getTerminator()->getDebugScope()));
|
|
pullbackSuccBuilder.emitDestroyValueOperation(pbLoc, value);
|
|
}
|
|
}
|
|
// Emit cleanups for all block-local temporaries.
|
|
cleanUpTemporariesForBlock(pbBB, pbLoc);
|
|
// Branch to pullback successor blocks.
|
|
assert(pullbackSuccessorCases.size() == predEnum->getNumElements());
|
|
builder.createSwitchEnum(pbLoc, predEnumVal, /*DefaultBB*/ nullptr,
|
|
pullbackSuccessorCases, std::nullopt,
|
|
ProfileCounter(), OwnershipKind::Owned);
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Member accessor pullback generation
|
|
//--------------------------------------------------------------------------//
|
|
|
|
bool PullbackCloner::Implementation::runForSemanticMemberAccessor() {
|
|
auto &original = getOriginal();
|
|
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
|
|
switch (accessor->getAccessorKind()) {
|
|
case AccessorKind::Get:
|
|
case AccessorKind::DistributedGet:
|
|
return runForSemanticMemberGetter();
|
|
case AccessorKind::Set:
|
|
return runForSemanticMemberSetter();
|
|
case AccessorKind::Modify:
|
|
return runForSemanticMemberModify();
|
|
default:
|
|
llvm_unreachable("Unsupported accessor kind; inconsistent with "
|
|
"`isSemanticMemberAccessor`?");
|
|
}
|
|
}
|
|
|
|
bool PullbackCloner::Implementation::runForSemanticMemberGetter() {
|
|
auto &original = getOriginal();
|
|
auto &pullback = getPullback();
|
|
auto pbLoc = getPullback().getLocation();
|
|
|
|
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
|
|
assert(accessor->getAccessorKind() == AccessorKind::Get);
|
|
|
|
auto *origEntry = original.getEntryBlock();
|
|
auto *pbEntry = pullback.getEntryBlock();
|
|
builder.setCurrentDebugScope(
|
|
remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
|
|
builder.setInsertionPoint(pbEntry);
|
|
|
|
// Get getter argument and result values.
|
|
// Getter type: $(Self) -> Result
|
|
// Pullback type: $(Result') -> Self'
|
|
assert(original.getLoweredFunctionType()->getNumParameters() == 1);
|
|
assert(pullback.getLoweredFunctionType()->getNumParameters() == 1);
|
|
assert(pullback.getLoweredFunctionType()->getNumResults() == 1);
|
|
SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();
|
|
|
|
SmallVector<SILValue, 8> origFormalResults;
|
|
collectAllFormalResultsInTypeOrder(original, origFormalResults);
|
|
assert(getConfig().resultIndices->getNumIndices() == 1 &&
|
|
"Getter should have one semantic result");
|
|
auto origResult = origFormalResults[*getConfig().resultIndices->begin()];
|
|
|
|
auto tangentVectorSILTy = pullback.getConventions().getResults().front()
|
|
.getSILStorageType(getModule(),
|
|
pullback.getLoweredFunctionType(),
|
|
TypeExpansionContext::minimal());
|
|
auto tangentVectorTy = tangentVectorSILTy.getASTType();
|
|
auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct();
|
|
|
|
// Look up the corresponding field in the tangent space.
|
|
auto *origField = cast<VarDecl>(accessor->getStorage());
|
|
auto baseType = remapType(origSelf->getType()).getASTType();
|
|
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
|
|
pbLoc, getInvoker());
|
|
if (!tanField) {
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
|
|
// Switch based on the base tangent struct's value category.
|
|
switch (getTangentValueCategory(origSelf)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjResult = getAdjointValue(origEntry, origResult);
|
|
switch (adjResult.getKind()) {
|
|
case AdjointValueKind::Zero:
|
|
addAdjointValue(origEntry, origSelf,
|
|
makeZeroAdjointValue(tangentVectorSILTy), pbLoc);
|
|
break;
|
|
case AdjointValueKind::Concrete:
|
|
case AdjointValueKind::Aggregate: {
|
|
SmallVector<AdjointValue, 8> eltVals;
|
|
for (auto *field : tangentVectorDecl->getStoredProperties()) {
|
|
if (field == tanField) {
|
|
eltVals.push_back(adjResult);
|
|
} else {
|
|
auto substMap = tangentVectorTy->getMemberSubstitutionMap(field);
|
|
auto fieldTy = field->getInterfaceType().subst(substMap);
|
|
auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
|
|
assert(fieldSILTy.isObject());
|
|
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
|
|
}
|
|
}
|
|
addAdjointValue(origEntry, origSelf,
|
|
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
|
|
pbLoc);
|
|
|
|
break;
|
|
}
|
|
case AdjointValueKind::AddElement:
|
|
llvm_unreachable("Adjoint of an aggregate type's field cannot be of kind "
|
|
"`AddElement`");
|
|
}
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
assert(pullback.getIndirectResults().size() == 1);
|
|
auto pbIndRes = pullback.getIndirectResults().front();
|
|
auto *adjSelf = createFunctionLocalAllocation(
|
|
pbIndRes->getType().getObjectType(), pbLoc);
|
|
setAdjointBuffer(origEntry, origSelf, adjSelf);
|
|
for (auto *field : tangentVectorDecl->getStoredProperties()) {
|
|
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, field);
|
|
// Non-tangent fields get a zero.
|
|
if (field != tanField) {
|
|
builder.emitZeroIntoBuffer(pbLoc, adjSelfElt, IsInitialization);
|
|
continue;
|
|
}
|
|
// Switch based on the property's value category.
|
|
switch (getTangentValueCategory(origResult)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjResult = getAdjointValue(origEntry, origResult);
|
|
auto adjResultValue = materializeAdjointDirect(adjResult, pbLoc);
|
|
auto adjResultValueCopy =
|
|
builder.emitCopyValueOperation(pbLoc, adjResultValue);
|
|
builder.emitStoreValueOperation(pbLoc, adjResultValueCopy, adjSelfElt,
|
|
StoreOwnershipQualifier::Init);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
auto adjResult = getAdjointBuffer(origEntry, origResult);
|
|
builder.createCopyAddr(pbLoc, adjResult, adjSelfElt, IsTake,
|
|
IsInitialization);
|
|
destroyedLocalAllocations.insert(adjResult);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
|
|
auto &original = getOriginal();
|
|
auto &pullback = getPullback();
|
|
auto pbLoc = getPullback().getLocation();
|
|
|
|
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
|
|
assert(accessor->getAccessorKind() == AccessorKind::Set);
|
|
|
|
auto *origEntry = original.getEntryBlock();
|
|
auto *pbEntry = pullback.getEntryBlock();
|
|
builder.setCurrentDebugScope(
|
|
remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
|
|
builder.setInsertionPoint(pbEntry);
|
|
|
|
// Get setter argument values.
|
|
// Setter type: $(inout Self, Argument) -> ()
|
|
// Pullback type (wrt self): $(inout Self') -> ()
|
|
// Pullback type (wrt both): $(inout Self') -> Argument'
|
|
assert(original.getLoweredFunctionType()->getNumParameters() == 2);
|
|
assert(pullback.getLoweredFunctionType()->getNumParameters() == 1);
|
|
assert(pullback.getLoweredFunctionType()->getNumResults() == 0 ||
|
|
pullback.getLoweredFunctionType()->getNumResults() == 1);
|
|
|
|
SILValue origArg = original.getArgumentsWithoutIndirectResults()[0];
|
|
SILValue origSelf = original.getArgumentsWithoutIndirectResults()[1];
|
|
|
|
// Look up the corresponding field in the tangent space.
|
|
auto *origField = cast<VarDecl>(accessor->getStorage());
|
|
auto baseType = remapType(origSelf->getType()).getASTType();
|
|
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
|
|
pbLoc, getInvoker());
|
|
if (!tanField) {
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
|
|
auto adjSelf = getAdjointBuffer(origEntry, origSelf);
|
|
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
|
|
// Switch based on the property's value category.
|
|
switch (getTangentValueCategory(origArg)) {
|
|
case SILValueCategory::Object: {
|
|
auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt,
|
|
LoadOwnershipQualifier::Take);
|
|
setAdjointValue(origEntry, origArg, makeConcreteAdjointValue(adjArg));
|
|
blockTemporaries[pbEntry].insert(adjArg);
|
|
break;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
addToAdjointBuffer(origEntry, origArg, adjSelfElt, pbLoc);
|
|
builder.emitDestroyOperation(pbLoc, adjSelfElt);
|
|
break;
|
|
}
|
|
}
|
|
builder.emitZeroIntoBuffer(pbLoc, adjSelfElt, IsInitialization);
|
|
|
|
return false;
|
|
}
|
|
|
|
bool PullbackCloner::Implementation::runForSemanticMemberModify() {
|
|
auto &original = getOriginal();
|
|
auto &pullback = getPullback();
|
|
auto pbLoc = getPullback().getLocation();
|
|
|
|
auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
|
|
assert(accessor->getAccessorKind() == AccessorKind::Modify);
|
|
|
|
auto *origEntry = original.getEntryBlock();
|
|
// We assume that the accessor has a simple 3-BB structure with yield in the entry BB
|
|
// plus resume and unwind BBs
|
|
auto *yi = cast<YieldInst>(origEntry->getTerminator());
|
|
auto *origResumeBB = yi->getResumeBB();
|
|
|
|
auto *pbEntry = pullback.getEntryBlock();
|
|
builder.setCurrentDebugScope(
|
|
remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
|
|
builder.setInsertionPoint(pbEntry);
|
|
|
|
// Get _modify accessor argument values.
|
|
// Accessor type : $(inout Self) -> @yields @inout Argument
|
|
// Pullback type : $(inout Self', linear map tuple) -> @yields @inout Argument'
|
|
// Normally pullbacks for semantic member accessors are single BB and
|
|
// therefore have empty linear map tuple, however, coroutines have a branching
|
|
// control flow due to possible coroutine abort, so we need to accommodate for
|
|
// this. We keep branch tracing enums in order not to special case in many
|
|
// other places. As there is no way to return to coroutine via abort exit, we
|
|
// essentially "linearize" a coroutine.
|
|
auto loweredFnTy = original.getLoweredFunctionType();
|
|
auto pullbackLoweredFnTy = pullback.getLoweredFunctionType();
|
|
|
|
assert(loweredFnTy->getNumParameters() == 1 &&
|
|
loweredFnTy->getNumYields() == 1);
|
|
assert(pullbackLoweredFnTy->getNumParameters() == 2);
|
|
assert(pullbackLoweredFnTy->getNumYields() == 1);
|
|
|
|
SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();
|
|
|
|
SmallVector<SILValue, 8> origFormalResults;
|
|
collectAllFormalResultsInTypeOrder(original, origFormalResults);
|
|
|
|
assert(getConfig().resultIndices->getNumIndices() == 2 &&
|
|
"Modify accessor should have two semantic results");
|
|
|
|
auto origYield = origFormalResults[*std::next(getConfig().resultIndices->begin())];
|
|
|
|
// Look up the corresponding field in the tangent space.
|
|
auto *origField = cast<VarDecl>(accessor->getStorage());
|
|
auto baseType = remapType(origSelf->getType()).getASTType();
|
|
auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
|
|
pbLoc, getInvoker());
|
|
if (!tanField) {
|
|
errorOccurred = true;
|
|
return true;
|
|
}
|
|
|
|
auto adjSelf = getAdjointBuffer(origResumeBB, origSelf);
|
|
auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
|
|
// Modify accessors have inout yields and therefore should yield addresses.
|
|
assert(getTangentValueCategory(origYield) == SILValueCategory::Address &&
|
|
"Modify accessors should yield indirect");
|
|
|
|
// Yield the adjoint buffer and do everything else in the resume
|
|
// destination. Unwind destination is unreachable as the coroutine can never
|
|
// be aborted.
|
|
auto *unwindBB = getPullback().createBasicBlock();
|
|
auto *resumeBB = getPullbackBlock(origEntry);
|
|
builder.createYield(yi->getLoc(), {adjSelfElt}, resumeBB, unwindBB);
|
|
builder.setInsertionPoint(unwindBB);
|
|
builder.createUnreachable(SILLocation::invalid());
|
|
|
|
builder.setInsertionPoint(resumeBB);
|
|
addToAdjointBuffer(origEntry, origSelf, adjSelf, pbLoc);
|
|
|
|
return false;
|
|
}
|
|
|
|
//--------------------------------------------------------------------------//
|
|
// Adjoint buffer mapping
|
|
//--------------------------------------------------------------------------//
|
|
|
|
SILValue PullbackCloner::Implementation::getAdjointProjection(
|
|
SILBasicBlock *origBB, SILValue originalProjection) {
|
|
// Handle `struct_element_addr`.
|
|
// Adjoint projection: a `struct_element_addr` into the base adjoint buffer.
|
|
if (auto *seai = dyn_cast<StructElementAddrInst>(originalProjection)) {
|
|
assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
|
|
"`@noDerivative` struct projections should never be active");
|
|
auto adjSource = getAdjointBuffer(origBB, seai->getOperand());
|
|
auto structType = remapType(seai->getOperand()->getType()).getASTType();
|
|
auto *tanField =
|
|
getTangentStoredProperty(getContext(), seai, structType, getInvoker());
|
|
assert(tanField && "Invalid projections should have been diagnosed");
|
|
return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField);
|
|
}
|
|
// Handle `tuple_element_addr`.
|
|
// Adjoint projection: a `tuple_element_addr` into the base adjoint buffer.
|
|
if (auto *teai = dyn_cast<TupleElementAddrInst>(originalProjection)) {
|
|
auto source = teai->getOperand();
|
|
auto adjSource = getAdjointBuffer(origBB, source);
|
|
if (!adjSource->getType().is<TupleType>())
|
|
return adjSource;
|
|
auto origTupleTy = remapType(source->getType()).castTo<TupleType>();
|
|
unsigned adjIndex = 0;
|
|
for (unsigned i : range(teai->getFieldIndex())) {
|
|
if (getTangentSpace(
|
|
origTupleTy->getElement(i).getType()->getCanonicalType()))
|
|
++adjIndex;
|
|
}
|
|
return builder.createTupleElementAddr(teai->getLoc(), adjSource, adjIndex);
|
|
}
|
|
// Handle `ref_element_addr`.
|
|
// Adjoint projection: a local allocation initialized with the corresponding
|
|
// field value from the class's base adjoint value.
|
|
if (auto *reai = dyn_cast<RefElementAddrInst>(originalProjection)) {
|
|
assert(!reai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
|
|
"`@noDerivative` class projections should never be active");
|
|
auto loc = reai->getLoc();
|
|
// Get the class operand, stripping `begin_borrow`.
|
|
auto classOperand = stripBorrow(reai->getOperand());
|
|
auto classType = remapType(reai->getOperand()->getType()).getASTType();
|
|
auto *tanField =
|
|
getTangentStoredProperty(getContext(), reai->getField(), classType,
|
|
reai->getLoc(), getInvoker());
|
|
assert(tanField && "Invalid projections should have been diagnosed");
|
|
// Create a local allocation for the element adjoint buffer.
|
|
auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
|
|
auto eltTanSILType =
|
|
remapType(SILType::getPrimitiveAddressType(eltTanType));
|
|
auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc);
|
|
// Check the class operand's `TangentVector` value category.
|
|
switch (getTangentValueCategory(classOperand)) {
|
|
case SILValueCategory::Object: {
|
|
// Get the class operand's adjoint value. Currently, it must be a
|
|
// `TangentVector` struct.
|
|
auto adjClass =
|
|
materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc);
|
|
builder.emitScopedBorrowOperation(
|
|
loc, adjClass, [&](SILValue borrowedAdjClass) {
|
|
// Initialize the element adjoint buffer with the base adjoint
|
|
// value.
|
|
auto *adjElt =
|
|
builder.createStructExtract(loc, borrowedAdjClass, tanField);
|
|
auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt);
|
|
builder.emitStoreValueOperation(loc, adjEltCopy, eltAdjBuffer,
|
|
StoreOwnershipQualifier::Init);
|
|
});
|
|
return eltAdjBuffer;
|
|
}
|
|
case SILValueCategory::Address: {
|
|
// Get the class operand's adjoint buffer. Currently, it must be a
|
|
// `TangentVector` struct.
|
|
auto adjClass = getAdjointBuffer(origBB, classOperand);
|
|
// Initialize the element adjoint buffer with the base adjoint buffer.
|
|
auto *adjElt = builder.createStructElementAddr(loc, adjClass, tanField);
|
|
builder.createCopyAddr(loc, adjElt, eltAdjBuffer, IsNotTake,
|
|
IsInitialization);
|
|
return eltAdjBuffer;
|
|
}
|
|
}
|
|
}
|
|
// Handle `begin_access`.
|
|
// Adjoint projection: the base adjoint buffer itself.
|
|
if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) {
|
|
auto adjBase = getAdjointBuffer(origBB, bai->getOperand());
|
|
if (errorOccurred)
|
|
return (bufferMap[{origBB, originalProjection}] = SILValue());
|
|
// Return the base buffer's adjoint buffer.
|
|
return adjBase;
|
|
}
|
|
// Handle `array.uninitialized_intrinsic` application element addresses.
|
|
// Adjoint projection: a local allocation initialized by applying
|
|
// `Array.TangentVector.subscript` to the base array's adjoint value.
|
|
auto *ai =
|
|
getAllocateUninitializedArrayIntrinsicElementAddress(originalProjection);
|
|
auto *definingInst = dyn_cast_or_null<SingleValueInstruction>(
|
|
originalProjection->getDefiningInstruction());
|
|
bool isAllocateUninitializedArrayIntrinsicElementAddress =
|
|
ai && definingInst &&
|
|
(isa<PointerToAddressInst>(definingInst) ||
|
|
isa<IndexAddrInst>(definingInst));
|
|
if (isAllocateUninitializedArrayIntrinsicElementAddress) {
|
|
// Get the array element index of the result address.
|
|
int eltIndex = 0;
|
|
if (auto *iai = dyn_cast<IndexAddrInst>(definingInst)) {
|
|
auto *ili = cast<IntegerLiteralInst>(iai->getIndex());
|
|
eltIndex = ili->getValue().getLimitedValue();
|
|
}
|
|
// Get the array adjoint value.
|
|
SILValue arrayValue = getArrayValue(ai);
|
|
SILValue arrayAdjoint = materializeAdjointDirect(
|
|
getAdjointValue(origBB, arrayValue), definingInst->getLoc());
|
|
// Apply `Array.TangentVector.subscript` to get array element adjoint value.
|
|
auto *eltAdjBuffer =
|
|
getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc());
|
|
return eltAdjBuffer;
|
|
}
|
|
return SILValue();
|
|
}
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Adjoint value accumulation
|
|
//----------------------------------------------------------------------------//
|
|
|
|
AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
|
|
AdjointValue lhs, AdjointValue rhs, SILLocation loc) {
|
|
LLVM_DEBUG(getADDebugStream() << "Accumulating adjoint directly.\nLHS: "
|
|
<< lhs << "\nRHS: " << rhs << '\n');
|
|
switch (lhs.getKind()) {
|
|
// x
|
|
case AdjointValueKind::Concrete: {
|
|
auto lhsVal = lhs.getConcreteValue();
|
|
switch (rhs.getKind()) {
|
|
// x + y
|
|
case AdjointValueKind::Concrete: {
|
|
auto rhsVal = rhs.getConcreteValue();
|
|
auto sum = recordTemporary(builder.emitAdd(loc, lhsVal, rhsVal));
|
|
return makeConcreteAdjointValue(sum);
|
|
}
|
|
// x + 0 => x
|
|
case AdjointValueKind::Zero:
|
|
return lhs;
|
|
// x + (y, z) => (x.0 + y, x.1 + z)
|
|
case AdjointValueKind::Aggregate: {
|
|
SmallVector<AdjointValue, 8> newElements;
|
|
auto lhsTy = lhsVal->getType().getASTType();
|
|
auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal);
|
|
if (lhsTy->is<TupleType>()) {
|
|
auto elts = builder.createDestructureTuple(loc, lhsValCopy);
|
|
llvm::for_each(elts->getResults(),
|
|
[this](SILValue result) { recordTemporary(result); });
|
|
for (auto i : indices(elts->getResults())) {
|
|
auto rhsElt = rhs.getAggregateElement(i);
|
|
newElements.push_back(accumulateAdjointsDirect(
|
|
makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
|
|
}
|
|
} else if (lhsTy->getStructOrBoundGenericStruct()) {
|
|
auto elts =
|
|
builder.createDestructureStruct(lhsVal.getLoc(), lhsValCopy);
|
|
llvm::for_each(elts->getResults(),
|
|
[this](SILValue result) { recordTemporary(result); });
|
|
for (unsigned i : indices(elts->getResults())) {
|
|
auto rhsElt = rhs.getAggregateElement(i);
|
|
newElements.push_back(accumulateAdjointsDirect(
|
|
makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
|
|
}
|
|
} else {
|
|
llvm_unreachable("Not an aggregate type");
|
|
}
|
|
return makeAggregateAdjointValue(lhsVal->getType(), newElements);
|
|
}
|
|
// x + (baseAdjoint, index, eltToAdd) => (x+baseAdjoint, index, eltToAdd)
|
|
case AdjointValueKind::AddElement: {
|
|
auto *addElementValue = rhs.getAddElementValue();
|
|
auto baseAdjoint = addElementValue->baseAdjoint;
|
|
auto eltToAdd = addElementValue->eltToAdd;
|
|
|
|
auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
|
|
return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
|
|
addElementValue->fieldLocator);
|
|
}
|
|
}
|
|
}
|
|
// 0
|
|
case AdjointValueKind::Zero:
|
|
// 0 + x => x
|
|
return rhs;
|
|
// (x, y)
|
|
case AdjointValueKind::Aggregate: {
|
|
switch (rhs.getKind()) {
|
|
// (x, y) + z => (z.0 + x, z.1 + y)
|
|
case AdjointValueKind::Concrete:
|
|
return accumulateAdjointsDirect(rhs, lhs, loc);
|
|
// x + 0 => x
|
|
case AdjointValueKind::Zero:
|
|
return lhs;
|
|
// (x, y) + (z, w) => (x + z, y + w)
|
|
case AdjointValueKind::Aggregate: {
|
|
SmallVector<AdjointValue, 8> newElements;
|
|
for (auto i : range(lhs.getNumAggregateElements()))
|
|
newElements.push_back(accumulateAdjointsDirect(
|
|
lhs.getAggregateElement(i), rhs.getAggregateElement(i), loc));
|
|
return makeAggregateAdjointValue(lhs.getType(), newElements);
|
|
}
|
|
// (x.0, ..., x.n) + (baseAdjoint, index, eltToAdd) => (x + baseAdjoint,
|
|
// index, eltToAdd)
|
|
case AdjointValueKind::AddElement: {
|
|
auto *addElementValue = rhs.getAddElementValue();
|
|
auto baseAdjoint = addElementValue->baseAdjoint;
|
|
auto eltToAdd = addElementValue->eltToAdd;
|
|
auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
|
|
|
|
return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
|
|
addElementValue->fieldLocator);
|
|
}
|
|
}
|
|
}
|
|
// (baseAdjoint, index, eltToAdd)
|
|
case AdjointValueKind::AddElement: {
|
|
switch (rhs.getKind()) {
|
|
case AdjointValueKind::Zero:
|
|
return lhs;
|
|
// (baseAdjoint, index, eltToAdd) + x => (x + baseAdjoint, index, eltToAdd)
|
|
case AdjointValueKind::Concrete:
|
|
// (baseAdjoint, index, eltToAdd) + (x.0, ..., x.n) => (x + baseAdjoint,
|
|
// index, eltToAdd)
|
|
case AdjointValueKind::Aggregate:
|
|
return accumulateAdjointsDirect(rhs, lhs, loc);
|
|
// (baseAdjoint1, index1, eltToAdd1) + (baseAdjoint2, index2, eltToAdd2)
|
|
// => ((baseAdjoint1 + baseAdjoint2, index1, eltToAdd1), index2, eltToAdd2)
|
|
case AdjointValueKind::AddElement: {
|
|
auto *addElementValueLhs = lhs.getAddElementValue();
|
|
auto baseAdjointLhs = addElementValueLhs->baseAdjoint;
|
|
auto eltToAddLhs = addElementValueLhs->eltToAdd;
|
|
|
|
auto *addElementValueRhs = rhs.getAddElementValue();
|
|
auto baseAdjointRhs = addElementValueRhs->baseAdjoint;
|
|
auto eltToAddRhs = addElementValueRhs->eltToAdd;
|
|
|
|
auto sumOfBaseAdjoints =
|
|
accumulateAdjointsDirect(baseAdjointLhs, baseAdjointRhs, loc);
|
|
auto newBaseAdjoint = makeAddElementAdjointValue(
|
|
sumOfBaseAdjoints, eltToAddLhs, addElementValueLhs->fieldLocator);
|
|
|
|
return makeAddElementAdjointValue(newBaseAdjoint, eltToAddRhs,
|
|
addElementValueRhs->fieldLocator);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
|
|
}
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Array literal initialization differentiation
|
|
//----------------------------------------------------------------------------//
|
|
|
|
void PullbackCloner::Implementation::
|
|
accumulateArrayLiteralElementAddressAdjoints(SILBasicBlock *origBB,
|
|
SILValue originalValue,
|
|
AdjointValue arrayAdjointValue,
|
|
SILLocation loc) {
|
|
// Return if the original value is not the `Array` result of an
|
|
// `array.uninitialized_intrinsic` application.
|
|
auto *dti = dyn_cast_or_null<DestructureTupleInst>(
|
|
originalValue->getDefiningInstruction());
|
|
if (!dti)
|
|
return;
|
|
if (!ArraySemanticsCall(dti->getOperand(),
|
|
semantics::ARRAY_UNINITIALIZED_INTRINSIC))
|
|
return;
|
|
if (originalValue != dti->getResult(0))
|
|
return;
|
|
// Accumulate the array's adjoint value into the adjoint buffers of its
|
|
// element addresses: `pointer_to_address` and (optionally) `index_addr`
|
|
// instructions.
|
|
// The input code looks like as follows:
|
|
// %17 = integer_literal $Builtin.Word, 1
|
|
// function_ref _allocateUninitializedArray<A>(_:)
|
|
// %18 = function_ref @$ss27_allocateUninitializedArrayySayxG_BptBwlF : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
|
|
// %19 = apply %18<Float>(%17) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
|
|
// (%20, %21) = destructure_tuple %19
|
|
// %22 = mark_dependence %21 on %20
|
|
// %23 = pointer_to_address %22 to [strict] $*Float
|
|
// store %0 to [trivial] %23
|
|
// function_ref _finalizeUninitializedArray<A>(_:)
|
|
// %25 = function_ref @$ss27_finalizeUninitializedArrayySayxGABnlF : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0>
|
|
// %26 = apply %25<Float>(%20) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> // user: %27
|
|
// Note that %20 and %21 are in some sense "aliases" for each other. Here our `originalValue` is %20 in the code above.
|
|
// We need to trace from %21 down to %23 and propagate (decomposed) adjoint of originalValue to adjoint of %23.
|
|
// Then the generic adjoint propagation code would do its job to propagate %23' to %0'.
|
|
// If we're initializing multiple values we're having additional `index_addr` instructions, but
|
|
// the handling is similar.
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Accumulating adjoint value for array literal into element "
|
|
"address adjoint buffers"
|
|
<< originalValue);
|
|
auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc);
|
|
builder.setCurrentDebugScope(remapScope(dti->getDebugScope()));
|
|
for (auto use : dti->getResult(1)->getUses()) {
|
|
auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser());
|
|
assert(mdi && "Expected mark_dependence user");
|
|
auto *ptai =
|
|
dyn_cast_or_null<PointerToAddressInst>(getSingleNonDebugUser(mdi));
|
|
assert(ptai && "Expected pointer_to_address user");
|
|
auto adjBuf = getAdjointBuffer(origBB, ptai);
|
|
auto *eltAdjBuf = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc);
|
|
builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
|
|
for (auto use : ptai->getUses()) {
|
|
if (auto *iai = dyn_cast<IndexAddrInst>(use->getUser())) {
|
|
auto *ili = cast<IntegerLiteralInst>(iai->getIndex());
|
|
auto eltIndex = ili->getValue().getLimitedValue();
|
|
auto adjBuf = getAdjointBuffer(origBB, iai);
|
|
auto *eltAdjBuf =
|
|
getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, loc);
|
|
builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
AllocStackInst *PullbackCloner::Implementation::getArrayAdjointElementBuffer(
|
|
SILValue arrayAdjoint, int eltIndex, SILLocation loc) {
|
|
auto &ctx = builder.getASTContext();
|
|
auto arrayTanType = cast<StructType>(arrayAdjoint->getType().getASTType());
|
|
auto arrayType = arrayTanType->getParent()->castTo<BoundGenericStructType>();
|
|
auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType();
|
|
auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType));
|
|
// Get `function_ref` and generic signature of
|
|
// `Array.TangentVector.subscript.getter`.
|
|
auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct();
|
|
auto subscriptLookup =
|
|
arrayTanStructDecl->lookupDirect(DeclBaseName::createSubscript());
|
|
SubscriptDecl *subscriptDecl = nullptr;
|
|
for (auto *candidate : subscriptLookup) {
|
|
auto candidateModule = candidate->getModuleContext();
|
|
if (candidateModule->getName() == ctx.Id_Differentiation ||
|
|
candidateModule->isStdlibModule()) {
|
|
assert(!subscriptDecl && "Multiple `Array.TangentVector.subscript`s");
|
|
subscriptDecl = cast<SubscriptDecl>(candidate);
|
|
#ifdef NDEBUG
|
|
break;
|
|
#endif
|
|
}
|
|
}
|
|
assert(subscriptDecl && "No `Array.TangentVector.subscript`");
|
|
auto *subscriptGetterDecl =
|
|
subscriptDecl->getOpaqueAccessor(AccessorKind::Get);
|
|
assert(subscriptGetterDecl && "No `Array.TangentVector.subscript` getter");
|
|
SILOptFunctionBuilder fb(getContext().getTransform());
|
|
auto *subscriptGetterFn = fb.getOrCreateFunction(
|
|
loc, SILDeclRef(subscriptGetterDecl), NotForDefinition);
|
|
// %subscript_fn = function_ref @Array.TangentVector<T>.subscript.getter
|
|
auto *subscriptFnRef = builder.createFunctionRef(loc, subscriptGetterFn);
|
|
auto subscriptFnGenSig =
|
|
subscriptGetterFn->getLoweredFunctionType()->getSubstGenericSignature();
|
|
// Apply `Array.TangentVector.subscript.getter` to get array element adjoint
|
|
// buffer.
|
|
// %index_literal = integer_literal $Builtin.IntXX, <index>
|
|
auto builtinIntType =
|
|
SILType::getPrimitiveObjectType(ctx.getIntDecl()
|
|
->getStoredProperties()
|
|
.front()
|
|
->getInterfaceType()
|
|
->getCanonicalType());
|
|
auto *eltIndexLiteral =
|
|
builder.createIntegerLiteral(loc, builtinIntType, eltIndex);
|
|
auto intType = SILType::getPrimitiveObjectType(
|
|
ctx.getIntType()->getCanonicalType());
|
|
// %index_int = struct $Int (%index_literal)
|
|
auto *eltIndexInt = builder.createStruct(loc, intType, {eltIndexLiteral});
|
|
auto *diffProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
|
|
auto diffConf = lookupConformance(eltTanType, diffProto);
|
|
assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
|
|
auto *addArithProto = ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
|
|
auto addArithConf = lookupConformance(eltTanType, addArithProto);
|
|
assert(!addArithConf.isInvalid() &&
|
|
"Missing conformance to `AdditiveArithmetic`");
|
|
auto subMap = SubstitutionMap::get(subscriptFnGenSig, {eltTanType},
|
|
{addArithConf, diffConf});
|
|
// %elt_adj = alloc_stack $T.TangentVector
|
|
// Create and register a local allocation.
|
|
auto *eltAdjBuffer = createFunctionLocalAllocation(
|
|
eltTanSILType, loc, /*zeroInitialize*/ true);
|
|
// Immediately destroy the emitted zero value.
|
|
// NOTE: It is not efficient to emit a zero value then immediately destroy
|
|
// it. However, it was the easiest way to to avoid "lifetime mismatch in
|
|
// predecessors" memory lifetime verification errors for control flow
|
|
// differentiation.
|
|
// Perhaps we can avoid emitting a zero value if local allocations are created
|
|
// per pullback bb instead of all in the pullback entry: TF-1075.
|
|
builder.emitDestroyOperation(loc, eltAdjBuffer);
|
|
// apply %subscript_fn<T.TangentVector>(%elt_adj, %index_int, %array_adj)
|
|
builder.createApply(loc, subscriptFnRef, subMap,
|
|
{eltAdjBuffer, eltIndexInt, arrayAdjoint});
|
|
return eltAdjBuffer;
|
|
}
|
|
|
|
} // end namespace autodiff
|
|
} // end namespace swift
|