//===--- PullbackCloner.cpp - Pullback function generation ---*- C++ -*----===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file defines a helper class for generating pullback functions for
// automatic differentiation.
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "differentiation"

#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"

#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
#include "swift/SILOptimizer/Differentiation/ADContext.h"
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
#include "swift/SILOptimizer/Differentiation/Thunk.h"
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"

#include "swift/AST/ConformanceLookup.h"
#include "swift/AST/Expr.h"
#include "swift/AST/PropertyWrappers.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/Basic/Assertions.h"
#include "swift/Basic/STLExtras.h"
#include "swift/SIL/ApplySite.h"
#include "swift/SIL/InstructionUtils.h"
#include "swift/SIL/Projection.h"
#include "swift/SIL/TypeSubstCloner.h"
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"

namespace swift {

class SILDifferentiabilityWitness;
class SILBasicBlock;
class SILFunction;
class SILInstruction;

namespace autodiff {

class ADContext;
class VJPCloner;

/// The implementation class for `PullbackCloner`.
///
/// The implementation class is a `SILInstructionVisitor`. Effectively, it acts
/// as a `SILCloner` that visits basic blocks in post-order and that visits
/// instructions per basic block in reverse order. This visitation order is
/// necessary for generating pullback functions, whose control flow graph is
/// roughly a transposed version of the original function's control flow graph.
class PullbackCloner::Implementation final
    : public SILInstructionVisitor<Implementation> {
public:
  explicit Implementation(VJPCloner &vjpCloner);

private:
  /// The parent VJP cloner.
  VJPCloner &vjpCloner;

  /// Dominance info for the original function.
  DominanceInfo *domInfo = nullptr;

  /// Post-dominance info for the original function.
  PostDominanceInfo *postDomInfo = nullptr;

  /// Post-order info for the original function.
  PostOrderFunctionInfo *postOrderInfo = nullptr;

  /// Mapping from original basic blocks to corresponding pullback basic
  /// blocks. Pullback basic blocks always have the predecessor as the single
  /// argument.
  llvm::DenseMap<SILBasicBlock *, SILBasicBlock *> pullbackBBMap;

  /// Mapping from original basic blocks and original values to corresponding
  /// adjoint values.
  llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, AdjointValue> valueMap;

  /// Mapping from original basic blocks and original values to corresponding
  /// adjoint buffers.
  llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap;

  /// Mapping from original basic blocks to the pullback tuple elements
  /// destructured from the linear map basic block argument. At the beginning
  /// of each pullback basic block, the block's pullback tuple is destructured
  /// into the individual elements stored here.
  llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>>
      pullbackTupleElements;

  /// Mapping from original basic blocks and successor basic blocks to
  /// corresponding pullback trampoline basic blocks. Trampoline basic blocks
  /// take additional arguments in addition to the predecessor enum argument.
  llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, SILBasicBlock *>
      pullbackTrampolineBBMap;

  /// Mapping from original basic blocks to dominated active values.
  llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues;

  /// Mapping from original basic blocks and original active values to
  /// corresponding pullback block arguments.
  llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *>
      activeValuePullbackBBArgumentMap;

  /// Mapping from original basic blocks to local temporary values to be
  /// cleaned up. This is populated when pullback emission is run on one basic
  /// block and cleaned before processing another basic block.
  llvm::DenseMap<SILBasicBlock *, llvm::SmallSetVector<SILValue, 64>>
      blockTemporaries;

  /// The scope cloner.
  ScopeCloner scopeCloner;

  /// The main builder.
  TangentBuilder builder;

  /// An auxiliary local allocation builder.
  TangentBuilder localAllocBuilder;

  /// The original function's exit block.
  SILBasicBlock *originalExitBlock = nullptr;

  /// Stack buffers allocated for storing local adjoint values.
  SmallVector<SILValue, 64> functionLocalAllocations;

  /// Copies created to deal with destructive enum operations
  /// (`unchecked_take_enum_addr`).
  llvm::SmallDenseMap<SILValue, SILValue> enumDataAdjCopies;

  /// A set used to remember local allocations that were destroyed.
  llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;

  /// The seed arguments of the pullback function.
  SmallVector<SILArgument *, 4> seeds;

  /// The `AutoDiffLinearMapContext` object, if any.
  SILValue contextValue = nullptr;

  llvm::BumpPtrAllocator allocator;

  bool errorOccurred = false;

  ADContext &getContext() const { return vjpCloner.getContext(); }
  SILModule &getModule() const { return getContext().getModule(); }
  ASTContext &getASTContext() const { return getPullback().getASTContext(); }
  SILFunction &getOriginal() const { return vjpCloner.getOriginal(); }
  SILDifferentiabilityWitness *getWitness() const {
    return vjpCloner.getWitness();
  }
  DifferentiationInvoker getInvoker() const { return vjpCloner.getInvoker(); }
  LinearMapInfo &getPullbackInfo() const { return vjpCloner.getPullbackInfo(); }
  const AutoDiffConfig &getConfig() const { return vjpCloner.getConfig(); }
  const DifferentiableActivityInfo &getActivityInfo() const {
    return vjpCloner.getActivityInfo();
  }

  //--------------------------------------------------------------------------//
  // Pullback struct mapping
  //--------------------------------------------------------------------------//

  void initializePullbackTupleElements(SILBasicBlock *origBB,
                                       SILInstructionResultArray values) {
    auto *pbTupleType = getPullbackInfo().getLinearMapTupleType(origBB);
    assert(pbTupleType->getNumElements() == values.size() &&
           "The number of pullback tuple fields must equal the number of "
           "pullback tuple element values");
    auto res = pullbackTupleElements.insert(
        {origBB, {values.begin(), values.end()}});
    (void)res;
    assert(res.second && "A pullback tuple element already exists!");
  }

  void initializePullbackTupleElements(SILBasicBlock *origBB,
                                       const llvm::ArrayRef<SILValue> &values) {
    auto *pbTupleType = getPullbackInfo().getLinearMapTupleType(origBB);
    assert(pbTupleType->getNumElements() == values.size() &&
           "The number of pullback tuple fields must equal the number of "
           "pullback tuple element values");
    auto res = pullbackTupleElements.insert(
        {origBB, {values.begin(), values.end()}});
(void)res; assert(res.second && "A pullback struct element already exists!"); } /// Returns the pullback tuple element value corresponding to the given /// original block and apply inst. SILValue getPullbackTupleElement(FullApplySite fai) { unsigned idx = getPullbackInfo().lookUpLinearMapIndex(fai); assert((idx > 0 || (idx == 0 && fai.getParent()->isEntry())) && "impossible linear map index"); auto values = pullbackTupleElements.lookup(fai.getParent()); assert(idx < values.size() && "pullback tuple element for this apply does not exist!"); return values[idx]; } /// Returns the pullback tuple element value corresponding to the predecessor /// for the given original block. SILValue getPullbackPredTupleElement(SILBasicBlock *origBB) { assert(!origBB->isEntry() && "no predecessors for entry block"); auto values = pullbackTupleElements.lookup(origBB); assert(values.size() && "pullback tuple cannot be empty"); return values[0]; } //--------------------------------------------------------------------------// // Type transformer //--------------------------------------------------------------------------// /// Get the type lowering for the given AST type. /// /// Explicitly use minimal type expansion context: in general, differentiation /// happens on function types, so it cannot know if the original function is /// resilient or not. const Lowering::TypeLowering &getTypeLowering(Type type) { auto pbGenSig = getPullback().getLoweredFunctionType()->getSubstGenericSignature(); Lowering::AbstractionPattern pattern(pbGenSig, type->getReducedType(pbGenSig)); return getContext().getTypeConverter().getTypeLowering( pattern, type, TypeExpansionContext::minimal()); } /// Remap any archetypes into the current function's context. SILType remapType(SILType ty) { if (ty.hasArchetype()) ty = ty.mapTypeOutOfContext(); auto remappedType = ty.getASTType()->getReducedType( getPullback().getLoweredFunctionType()->getSubstGenericSignature()); auto remappedSILType = SILType::getPrimitiveType(remappedType, ty.getCategory()); // FIXME: Sometimes getPullback() doesn't have a generic environment, in which // case callers are apparently happy to receive an interface type. if (getPullback().getGenericEnvironment()) return getPullback().mapTypeIntoContext(remappedSILType); return remappedSILType; } std::optional getTangentSpace(CanType type) { // Use witness generic signature to remap types. type = getWitness()->getDerivativeGenericSignature().getReducedType( type); return type->getAutoDiffTangentSpace( LookUpConformanceInModule()); } /// Returns the tangent value category of the given value. SILValueCategory getTangentValueCategory(SILValue v) { // Tangent value category table: // // Let $L be a loadable type and $*A be an address-only type. // // Original type | Tangent type loadable? | Tangent value category and type // --------------|------------------------|-------------------------------- // $L | loadable | object, $L' (no mismatch) // $*A | loadable | address, $*L' (create a buffer) // $L | address-only | address, $*A' (no alternative) // $*A | address-only | address, $*A' (no alternative) // TODO(https://github.com/apple/swift/issues/55523): Make "tangent value category" depend solely on whether the tangent type is loadable or address-only. // // For loadable tangent types, using symbolic adjoint values instead of // concrete adjoint buffers is more efficient. // Quick check: if the value has an address type, the tangent value category // is currently always "address". 
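    // For illustration (hypothetical types): a loadable `$Float` object keeps
    // the "object" category with tangent type `$Float`, whereas a value of an
    // address-only generic type `$*T` (or an object whose tangent type is
    // address-only) gets the "address" category and an adjoint buffer of type
    // `$*T.TangentVector`.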
if (v->getType().isAddress()) return SILValueCategory::Address; // If the value has an object type and the tangent type is not address-only, // then the tangent value category is "object". auto tanSpace = getTangentSpace(remapType(v->getType()).getASTType()); auto tanASTType = tanSpace->getCanonicalType(); if (v->getType().isObject() && getTypeLowering(tanASTType).isLoadable()) return SILValueCategory::Object; // Otherwise, the tangent value category is "address". return SILValueCategory::Address; } /// Assuming the given type conforms to `Differentiable` after remapping, /// returns the associated tangent space type. SILType getRemappedTangentType(SILType type) { return SILType::getPrimitiveType( getTangentSpace(remapType(type).getASTType())->getCanonicalType(), type.getCategory()); } /// Substitutes all replacement types of the given substitution map using the /// pullback function's substitution map. SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) { return substMap.subst(getPullback().getForwardingSubstitutionMap()); } //--------------------------------------------------------------------------// // Temporary value management //--------------------------------------------------------------------------// /// Record a temporary value for cleanup before its block's terminator. SILValue recordTemporary(SILValue value) { assert(value->getType().isObject()); assert(value->getFunction() == &getPullback()); auto inserted = blockTemporaries[value->getParentBlock()].insert(value); (void)inserted; LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value); assert(inserted && "Temporary already recorded?"); return value; } /// Clean up all temporary values for the given pullback block. void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc) { assert(bb->getParent() == &getPullback()); LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for pullback bb" << bb->getDebugID() << '\n'); for (auto temp : blockTemporaries[bb]) builder.emitDestroyValueOperation(loc, temp); blockTemporaries[bb].clear(); } //--------------------------------------------------------------------------// // Adjoint value factory methods //--------------------------------------------------------------------------// AdjointValue makeZeroAdjointValue(SILType type) { return AdjointValue::createZero(allocator, remapType(type)); } AdjointValue makeConcreteAdjointValue(SILValue value) { return AdjointValue::createConcrete(allocator, value); } AdjointValue makeAggregateAdjointValue(SILType type, ArrayRef elements) { return AdjointValue::createAggregate(allocator, remapType(type), elements); } AdjointValue makeAddElementAdjointValue(AdjointValue baseAdjoint, AdjointValue eltToAdd, FieldLocator fieldLocator) { auto *addElementValue = new AddElementValue(baseAdjoint, eltToAdd, fieldLocator); return AdjointValue::createAddElement(allocator, baseAdjoint.getType(), addElementValue); } //--------------------------------------------------------------------------// // Adjoint value materialization //--------------------------------------------------------------------------// /// Materializes an adjoint value. The type of the given adjoint value must be /// loadable. 
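  /// For example (illustrative): a symbolic `Zero` adjoint of type `$Float`
  /// is materialized by emitting a zero value of the tangent type, and a
  /// symbolic `Aggregate` adjoint of type `$(Float, Float)` is materialized
  /// by materializing each element and forming a `tuple` instruction from the
  /// results.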
SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc) { assert(val.getType().isObject()); LLVM_DEBUG(getADDebugStream() << "Materializing adjoint for " << val << '\n'); SILValue result; switch (val.getKind()) { case AdjointValueKind::Zero: result = recordTemporary(builder.emitZero(loc, val.getSwiftType())); break; case AdjointValueKind::Aggregate: { SmallVector elements; for (auto i : range(val.getNumAggregateElements())) { auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc); elements.push_back(builder.emitCopyValueOperation(loc, eltVal)); } if (val.getType().is()) result = recordTemporary( builder.createTuple(loc, val.getType(), elements)); else result = recordTemporary( builder.createStruct(loc, val.getType(), elements)); break; } case AdjointValueKind::Concrete: result = val.getConcreteValue(); break; case AdjointValueKind::AddElement: { auto adjointSILType = val.getAddElementValue()->baseAdjoint.getType(); auto *baseAdjAlloc = builder.createAllocStack(loc, adjointSILType); materializeAdjointIndirect(val, baseAdjAlloc, loc); auto baseAdjConcrete = recordTemporary(builder.emitLoadValueOperation( loc, baseAdjAlloc, LoadOwnershipQualifier::Take)); builder.createDeallocStack(loc, baseAdjAlloc); result = baseAdjConcrete; break; } } if (auto debugInfo = val.getDebugInfo()) builder.createDebugValue( debugInfo->first.getLocation(), result, debugInfo->second); return result; } /// Materializes an adjoint value indirectly to a SIL buffer. void materializeAdjointIndirect(AdjointValue val, SILValue destAddress, SILLocation loc) { assert(destAddress->getType().isAddress()); switch (val.getKind()) { /// If adjoint value is a symbolic zero, emit a call to /// `AdditiveArithmetic.zero`. case AdjointValueKind::Zero: builder.emitZeroIntoBuffer(loc, destAddress, IsInitialization); break; /// If adjoint value is a symbolic aggregate (tuple or struct), recursively /// materialize the symbolic tuple or struct, filling the /// buffer. case AdjointValueKind::Aggregate: { if (auto *tupTy = val.getSwiftType()->getAs()) { for (auto idx : range(val.getNumAggregateElements())) { auto eltTy = SILType::getPrimitiveAddressType( tupTy->getElementType(idx)->getCanonicalType()); auto *eltBuf = builder.createTupleElementAddr(loc, destAddress, idx, eltTy); materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc); } } else if (auto *structDecl = val.getSwiftType()->getStructOrBoundGenericStruct()) { auto fieldIt = structDecl->getStoredProperties().begin(); for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end(); ++fieldIt, ++i) { auto eltBuf = builder.createStructElementAddr(loc, destAddress, *fieldIt); materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc); } } else { llvm_unreachable("Not an aggregate type"); } break; } /// If adjoint value is concrete, it is already materialized. Store it in /// the destination address. case AdjointValueKind::Concrete: { auto concreteVal = val.getConcreteValue(); auto copyOfConcreteVal = builder.emitCopyValueOperation(loc, concreteVal); builder.emitStoreValueOperation(loc, copyOfConcreteVal, destAddress, StoreOwnershipQualifier::Init); break; } case AdjointValueKind::AddElement: { auto baseAdjoint = val; auto baseAdjointType = baseAdjoint.getType(); // Current adjoint may be made up of layers of `AddElement` adjoints. // We can iteratively gather the list of elements to add instead of making // recursive calls to `materializeAdjointIndirect`. 
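      // Illustrative example (hypothetical adjoint): for
      //   AddElement(AddElement(base, e0, #f0), e1, #f1)
      // the do-while loop below collects [e1, e0], `base` is materialized
      // into `destAddress`, and each collected element is then accumulated
      // in place into the corresponding element address.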
SmallVector addEltAdjValues; do { auto addElementValue = baseAdjoint.getAddElementValue(); addEltAdjValues.push_back(addElementValue); baseAdjoint = addElementValue->baseAdjoint; assert(baseAdjointType == baseAdjoint.getType()); } while (baseAdjoint.getKind() == AdjointValueKind::AddElement); materializeAdjointIndirect(baseAdjoint, destAddress, loc); for (auto *addElementValue : addEltAdjValues) { auto eltToAdd = addElementValue->eltToAdd; SILValue baseAdjEltAddr; if (baseAdjoint.getType().is()) { baseAdjEltAddr = builder.createTupleElementAddr( loc, destAddress, addElementValue->getFieldIndex()); } else { baseAdjEltAddr = builder.createStructElementAddr( loc, destAddress, addElementValue->getFieldDecl()); } auto eltToAddMaterialized = materializeAdjointDirect(eltToAdd, loc); // Copy `eltToAddMaterialized` so we have a value with owned ownership // semantics, required for using `eltToAddMaterialized` in a `store` // instruction. auto eltToAddMaterializedCopy = builder.emitCopyValueOperation(loc, eltToAddMaterialized); auto *eltToAddAlloc = builder.createAllocStack(loc, eltToAdd.getType()); builder.emitStoreValueOperation(loc, eltToAddMaterializedCopy, eltToAddAlloc, StoreOwnershipQualifier::Init); builder.emitInPlaceAdd(loc, baseAdjEltAddr, eltToAddAlloc); builder.createDestroyAddr(loc, eltToAddAlloc); builder.createDeallocStack(loc, eltToAddAlloc); } break; } } } //--------------------------------------------------------------------------// // Adjoint value mapping //--------------------------------------------------------------------------// /// Returns true if the given value in the original function has a /// corresponding adjoint value. bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const { assert(origBB->getParent() == &getOriginal()); assert(originalValue->getType().isObject()); return valueMap.count({origBB, originalValue}); } /// Initializes the adjoint value for the original value. Asserts that the /// original value does not already have an adjoint value. void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue, AdjointValue adjointValue) { LLVM_DEBUG(getADDebugStream() << "Setting adjoint value for " << originalValue); assert(origBB->getParent() == &getOriginal()); assert(originalValue->getType().isObject()); assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); assert(adjointValue.getType().isObject()); assert(originalValue->getFunction() == &getOriginal()); // The adjoint value must be in the tangent space. assert(adjointValue.getType() == getRemappedTangentType(originalValue->getType())); // Try to assign a debug variable. if (auto debugInfo = findDebugLocationAndVariable(originalValue)) { LLVM_DEBUG({ auto &s = getADDebugStream(); s << "Found debug variable: \"" << debugInfo->second.Name << "\"\nLocation: "; debugInfo->first.getLocation().print(s, getASTContext().SourceMgr); s << '\n'; }); adjointValue.setDebugInfo(*debugInfo); } else { LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n"); } // Insert into dictionary. auto insertion = valueMap.try_emplace({origBB, originalValue}, adjointValue); LLVM_DEBUG(getADDebugStream() << "The new adjoint value, replacing the existing one, is: " << insertion.first->getSecond() << '\n'); if (!insertion.second) insertion.first->getSecond() = adjointValue; } /// Returns the adjoint value for a value in the original function. /// /// This method first tries to find an existing entry in the adjoint value /// mapping. If no entry exists, creates a zero adjoint value. 
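  /// For example (illustrative usage):
  ///
  ///   auto adj = getAdjointValue(bb, origValue); // symbolic zero if unset
  ///
  /// The returned adjoint is symbolic and is only materialized on demand.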
AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue) { assert(origBB->getParent() == &getOriginal()); assert(originalValue->getType().isObject()); assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); assert(originalValue->getFunction() == &getOriginal()); auto insertion = valueMap.try_emplace( {origBB, originalValue}, makeZeroAdjointValue(getRemappedTangentType(originalValue->getType()))); auto it = insertion.first; return it->getSecond(); } /// Adds `newAdjointValue` to the adjoint value for `originalValue` and sets /// the sum as the new adjoint value. void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue, AdjointValue newAdjointValue, SILLocation loc) { assert(origBB->getParent() == &getOriginal()); assert(originalValue->getType().isObject()); assert(newAdjointValue.getType().isObject()); assert(originalValue->getFunction() == &getOriginal()); LLVM_DEBUG(getADDebugStream() << "Adding adjoint value for " << originalValue); // The adjoint value must be in the tangent space. assert(newAdjointValue.getType() == getRemappedTangentType(originalValue->getType())); // Try to assign a debug variable. if (auto debugInfo = findDebugLocationAndVariable(originalValue)) { LLVM_DEBUG({ auto &s = getADDebugStream(); s << "Found debug variable: \"" << debugInfo->second.Name << "\"\nLocation: "; debugInfo->first.getLocation().print(s, getASTContext().SourceMgr); s << '\n'; }); newAdjointValue.setDebugInfo(*debugInfo); } else { LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n"); } auto insertion = valueMap.try_emplace({origBB, originalValue}, newAdjointValue); auto inserted = insertion.second; if (inserted) return; // If adjoint already exists, accumulate the adjoint onto the existing // adjoint. auto it = insertion.first; auto existingValue = it->getSecond(); valueMap.erase(it); auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue, loc); // If the original value is the `Array` result of an // `array.uninitialized_intrinsic` application, accumulate adjoint buffers // for the array element addresses. accumulateArrayLiteralElementAddressAdjoints(origBB, originalValue, adjVal, loc); setAdjointValue(origBB, originalValue, adjVal); } /// Get the pullback block argument corresponding to the given original block /// and active value. SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB, SILValue activeValue) { assert(getTangentValueCategory(activeValue) == SILValueCategory::Object); assert(origBB->getParent() == &getOriginal()); auto pullbackBBArg = activeValuePullbackBBArgumentMap[{origBB, activeValue}]; assert(pullbackBBArg); assert(pullbackBBArg->getParent() == getPullbackBlock(origBB)); return pullbackBBArg; } //--------------------------------------------------------------------------// // Adjoint value accumulation //--------------------------------------------------------------------------// /// Given two adjoint values, accumulates them and returns their sum. AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs, SILLocation loc); //--------------------------------------------------------------------------// // Adjoint buffer mapping //--------------------------------------------------------------------------// /// If the given original value is an address projection, returns a /// corresponding adjoint projection to be used as its adjoint buffer. /// /// Helper function for `getAdjointBuffer`. 
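  /// For example (illustrative SIL): if the original value is
  ///   %elt = struct_element_addr %base, #S.x
  /// the adjoint projection is a `struct_element_addr` into `adj[%base]` at
  /// the stored property of `S.TangentVector` corresponding to `#S.x`, so
  /// that accumulation into `adj[%elt]` lands directly in `adj[%base]`.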
SILValue getAdjointProjection(SILBasicBlock *origBB, SILValue originalValue); /// Returns the adjoint buffer for the original value. /// /// This method first tries to find an existing entry in the adjoint buffer /// mapping. If no entry exists, creates a zero adjoint buffer. SILValue getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue) { assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); assert(originalValue->getFunction() == &getOriginal()); auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue()); if (!insertion.second) // not inserted return insertion.first->getSecond(); // If the original buffer is a projection, return a corresponding projection // into the adjoint buffer. if (auto adjProj = getAdjointProjection(origBB, originalValue)) return (bufferMap[{origBB, originalValue}] = adjProj); LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for " << originalValue << "in bb" << origBB->getDebugID() << '\n'); auto bufType = getRemappedTangentType(originalValue->getType()); // Set insertion point for local allocation builder: before the last local // allocation, or at the start of the pullback function's entry if no local // allocations exist yet. auto debugInfo = findDebugLocationAndVariable(originalValue); SILLocation loc = debugInfo ? debugInfo->first.getLocation() : RegularLocation::getAutoGeneratedLocation(); llvm::SmallString<32> adjName; auto *newBuf = createFunctionLocalAllocation( bufType, loc, /*zeroInitialize*/ true, swift::transform(debugInfo, [&](AdjointValue::DebugInfo di) { llvm::raw_svector_ostream adjNameStream(adjName); SILDebugVariable &dv = di.second; dv.ArgNo = 0; adjNameStream << "derivative of '" << dv.Name << "'"; if (SILDebugLocation origBBLoc = origBB->front().getDebugLocation()) { adjNameStream << " in scope at "; origBBLoc.getLocation().print(adjNameStream, getASTContext().SourceMgr); } adjNameStream << " (scope #" << origBB->getDebugID() << ")"; dv.Name = adjName; // We have no meaningful debug location, and the type is different. dv.Scope = nullptr; dv.Loc = {}; dv.Type = {}; dv.DIExpr = {}; return dv; })); return (insertion.first->getSecond() = newBuf); } /// Initializes the adjoint buffer for the original value. Asserts that the /// original value does not already have an adjoint buffer. void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, SILValue adjointBuffer) { assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); auto insertion = bufferMap.try_emplace({origBB, originalValue}, adjointBuffer); assert(insertion.second && "Adjoint buffer already exists"); (void)insertion; } /// Accumulates `rhsAddress` into the adjoint buffer corresponding to the /// original value. 
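  /// In the notation used elsewhere in this file:
  ///   adj[originalValue] += *rhsAddress
  /// where the addition is emitted as an in-place add on the tangent buffer.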
void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, SILValue rhsAddress, SILLocation loc) { assert(getTangentValueCategory(originalValue) == SILValueCategory::Address && rhsAddress->getType().isAddress()); assert(originalValue->getFunction() == &getOriginal()); assert(rhsAddress->getFunction() == &getPullback()); auto adjointBuffer = getAdjointBuffer(origBB, originalValue); LLVM_DEBUG(getADDebugStream() << "Adding" << rhsAddress << "to adjoint (" << adjointBuffer << ") of " << originalValue << "in bb" << origBB->getDebugID() << '\n'); builder.emitInPlaceAdd(loc, adjointBuffer, rhsAddress); } /// Returns a next insertion point for creating a local allocation: either /// before the previous local allocation, or at the start of the pullback /// entry if no local allocations exist. /// /// Helper for `createFunctionLocalAllocation`. SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint() { // If there are no local allocations, insert at the pullback entry start. if (functionLocalAllocations.empty()) return getPullback().getEntryBlock()->begin(); // Otherwise, insert before the last local allocation. Inserting before // rather than after ensures that allocation and zero initialization // instructions are grouped together. auto lastLocalAlloc = functionLocalAllocations.back(); return lastLocalAlloc->getDefiningInstruction()->getIterator(); } /// Creates and returns a local allocation with the given type. /// /// Local allocations are created uninitialized in the pullback entry and /// deallocated in the pullback exit. All local allocations not in /// `destroyedLocalAllocations` are also destroyed in the pullback exit. /// /// Helper for `getAdjointBuffer`. AllocStackInst *createFunctionLocalAllocation( SILType type, SILLocation loc, bool zeroInitialize = false, std::optional varInfo = std::nullopt) { // Set insertion point for local allocation builder: before the last local // allocation, or at the start of the pullback function's entry if no local // allocations exist yet. localAllocBuilder.setInsertionPoint( getPullback().getEntryBlock(), getNextFunctionLocalAllocationInsertionPoint()); // Create and return local allocation. auto *alloc = localAllocBuilder.createAllocStack(loc, type, varInfo); functionLocalAllocations.push_back(alloc); // Zero-initialize if requested. if (zeroInitialize) localAllocBuilder.emitZeroIntoBuffer(loc, alloc, IsInitialization); return alloc; } //--------------------------------------------------------------------------// // Optional differentiation //--------------------------------------------------------------------------// /// Given a `wrappedAdjoint` value of type `T.TangentVector` and `Optional` /// type, creates an `Optional.TangentVector` buffer from it. /// /// `wrappedAdjoint` may be an object or address value, both cases are /// handled. AllocStackInst *createOptionalAdjoint(SILBasicBlock *bb, SILValue wrappedAdjoint, SILType optionalTy); /// Accumulate adjoint of `wrappedAdjoint` into optionalBuffer. void accumulateAdjointForOptionalBuffer(SILBasicBlock *bb, SILValue optionalBuffer, SILValue wrappedAdjoint); /// Accumulate adjoint of `wrappedAdjoint` into optionalValue. 
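  /// For example (conceptually, at the Swift level): for `let y: Float? = x`,
  /// `adj[y]` has type `Optional<Float>.TangentVector`, so the wrapped
  /// `Float` tangent must first be wrapped into an
  /// `Optional<Float>.TangentVector` buffer before it can be accumulated into
  /// the optional's adjoint.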
void accumulateAdjointValueForOptional(SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint); //--------------------------------------------------------------------------// // Array literal initialization differentiation //--------------------------------------------------------------------------// /// Given the adjoint value of an array initialized from an /// `array.uninitialized_intrinsic` application and an array element index, /// returns an `alloc_stack` containing the adjoint value of the array element /// at the given index by applying `Array.TangentVector.subscript`. AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint, int eltIndex, SILLocation loc); /// Given the adjoint value of an array initialized from an /// `array.uninitialized_intrinsic` application, accumulates the adjoint /// value's elements into the adjoint buffers of its element addresses. void accumulateArrayLiteralElementAddressAdjoints( SILBasicBlock *origBB, SILValue originalValue, AdjointValue arrayAdjointValue, SILLocation loc); //--------------------------------------------------------------------------// // CFG mapping //--------------------------------------------------------------------------// SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) { return pullbackBBMap.lookup(originalBlock); } SILBasicBlock *getPullbackTrampolineBlock(SILBasicBlock *originalBlock, SILBasicBlock *successorBlock) { return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock}); } //--------------------------------------------------------------------------// // Debug info //--------------------------------------------------------------------------// const SILDebugScope *remapScope(const SILDebugScope *DS) { return scopeCloner.getOrCreateClonedScope(DS); } //--------------------------------------------------------------------------// // Debugging utilities //--------------------------------------------------------------------------// void printAdjointValueMapping() { // Group original/adjoint values by basic block. llvm::DenseMap> tmp; for (auto pair : valueMap) { auto origPair = pair.first; auto *origBB = origPair.first; auto origValue = origPair.second; auto adjValue = pair.second; tmp[origBB].insert({origValue, adjValue}); } // Print original/adjoint values per basic block. auto &s = getADDebugStream() << "Adjoint value mapping:\n"; for (auto &origBB : getOriginal()) { if (!pullbackBBMap.count(&origBB)) continue; auto bbValueMap = tmp[&origBB]; s << "bb" << origBB.getDebugID(); s << " (size " << bbValueMap.size() << "):\n"; for (auto valuePair : bbValueMap) { auto origValue = valuePair.first; auto adjValue = valuePair.second; s << "ORIG: " << origValue; s << "ADJ: " << adjValue << '\n'; } s << '\n'; } } void printAdjointBufferMapping() { // Group original/adjoint buffers by basic block. llvm::DenseMap> tmp; for (auto pair : bufferMap) { auto origPair = pair.first; auto *origBB = origPair.first; auto origBuf = origPair.second; auto adjBuf = pair.second; tmp[origBB][origBuf] = adjBuf; } // Print original/adjoint buffers per basic block. 
auto &s = getADDebugStream() << "Adjoint buffer mapping:\n"; for (auto &origBB : getOriginal()) { if (!pullbackBBMap.count(&origBB)) continue; auto bbBufferMap = tmp[&origBB]; s << "bb" << origBB.getDebugID(); s << " (size " << bbBufferMap.size() << "):\n"; for (auto valuePair : bbBufferMap) { auto origBuf = valuePair.first; auto adjBuf = valuePair.second; s << "ORIG: " << origBuf; s << "ADJ: " << adjBuf << '\n'; } s << '\n'; } } public: //--------------------------------------------------------------------------// // Entry point //--------------------------------------------------------------------------// /// Performs pullback generation on the empty pullback function. Returns true /// if any error occurs. bool run(); /// Performs pullback generation on the empty pullback function, given that /// the original function is a "semantic member accessor". /// /// "Semantic member accessors" are attached to member properties that have a /// corresponding tangent stored property in the parent `TangentVector` type. /// These accessors have special-case pullback generation based on their /// semantic behavior. /// /// Returns true if any error occurs. bool runForSemanticMemberAccessor(); bool runForSemanticMemberGetter(); bool runForSemanticMemberSetter(); bool runForSemanticMemberModify(); /// If original result is non-varied, it will always have a zero derivative. /// Skip full pullback generation and simply emit zero derivatives for wrt /// parameters. void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult); /// Public helper so that our users can get the underlying newly created /// function. SILFunction &getPullback() const { return vjpCloner.getPullback(); } using TrampolineBlockSet = SmallPtrSet; /// Determines the pullback successor block for a given original block and one /// of its predecessors. When a trampoline block is necessary, emits code into /// the trampoline block to trampoline the original block's active value's /// adjoint values. /// /// Populates `pullbackTrampolineBlockMap`, which maps active values' adjoint /// values to the pullback successor blocks in which they are used. This /// allows us to release those values in pullback successor blocks that do not /// use them. SILBasicBlock * buildPullbackSuccessor(SILBasicBlock *origBB, SILBasicBlock *origPredBB, llvm::SmallDenseMap &pullbackTrampolineBlockMap); /// Emits pullback code in the corresponding pullback block. void visitSILBasicBlock(SILBasicBlock *bb); void visit(SILInstruction *inst) { if (errorOccurred) return; LLVM_DEBUG(getADDebugStream() << "PullbackCloner visited:\n[ORIG]" << *inst); #ifndef NDEBUG auto beforeInsertion = std::prev(builder.getInsertionPoint()); #endif SILInstructionVisitor::visit(inst); LLVM_DEBUG({ auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback (pb bb" << builder.getInsertionBB()->getDebugID() << "):\n"; auto afterInsertion = builder.getInsertionPoint(); for (auto it = ++beforeInsertion; it != afterInsertion; ++it) s << *it; }); } /// Fallback instruction visitor for unhandled instructions. /// Emit a general non-differentiability diagnostic. void visitSILInstruction(SILInstruction *inst) { LLVM_DEBUG(getADDebugStream() << "Unhandled instruction in PullbackCloner: " << *inst); getContext().emitNondifferentiabilityError( inst, getInvoker(), diag::autodiff_expression_not_differentiable_note); errorOccurred = true; } /// Handle `apply` instruction. /// Original: (y0, y1, ...) = apply @fn (x0, x1, ...) /// Adjoint: (adj[x0], adj[x1], ...) 
+= apply @fn_pullback (adj[y0], ...) void visitApplyInst(ApplyInst *ai) { assert(getPullbackInfo().shouldDifferentiateApplySite(ai)); // Skip `array.uninitialized_intrinsic` applications, which have special // `store` and `copy_addr` support. if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) return; auto loc = ai->getLoc(); auto *bb = ai->getParent(); // Handle `array.finalize_intrinsic` applications. // `array.finalize_intrinsic` semantically behaves like an identity // function. if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) { assert(ai->getNumArguments() == 1 && "Expected intrinsic to have one operand"); // Accumulate result's adjoint into argument's adjoint. auto adjResult = getAdjointValue(bb, ai); auto origArg = ai->getArgumentsWithoutIndirectResults().front(); addAdjointValue(bb, origArg, adjResult, loc); return; } buildPullbackCall(ai); } void buildPullbackCall(FullApplySite fai) { auto loc = fai->getLoc(); auto *bb = fai->getParent(); // Replace a call to a function with a call to its pullback. auto &nestedApplyInfo = getContext().getNestedApplyInfo(); auto applyInfoLookup = nestedApplyInfo.find(fai); // If no `NestedApplyInfo` was found, then this task doesn't need to be // differentiated. if (applyInfoLookup == nestedApplyInfo.end()) { // Must not be active. // TODO: Do we need to check token result for begin_apply? SILValue result = fai.getResult(); assert(!result || !getActivityInfo().isActive(result, getConfig())); return; } auto &applyInfo = applyInfoLookup->getSecond(); // Get the original result of the `apply` instruction. const auto &conv = fai.getSubstCalleeConv(); SmallVector origDirectResults; forEachApplyDirectResult(fai, [&](SILValue directResult) { origDirectResults.push_back(directResult); }); SmallVector origAllResults; collectAllActualResultsInTypeOrder(fai, origDirectResults, origAllResults); // Append semantic result arguments after original results. for (auto paramIdx : applyInfo.config.parameterIndices->getIndices()) { unsigned argIdx = fai.getNumIndirectSILResults() + paramIdx; auto paramInfo = conv.getParamInfoForSILArg(argIdx); if (!paramInfo.isAutoDiffSemanticResult()) continue; origAllResults.push_back( fai.getArgumentsWithoutIndirectResults()[paramIdx]); } // Get callee pullback arguments. SmallVector args; // Handle callee pullback indirect results. // Create local allocations for these and destroy them after the call. auto pullback = getPullbackTupleElement(fai); auto pullbackType = remapType(pullback->getType()).castTo(); auto actualPullbackType = applyInfo.originalPullbackType ? *applyInfo.originalPullbackType : pullbackType; actualPullbackType = actualPullbackType->getUnsubstitutedType(getModule()); SmallVector pullbackIndirectResults; for (auto indRes : actualPullbackType->getIndirectFormalResults()) { auto *alloc = builder.createAllocStack( loc, remapType(indRes.getSILStorageInterfaceType())); pullbackIndirectResults.push_back(alloc); args.push_back(alloc); } // Collect callee pullback formal arguments. unsigned firstSemanticParamResultIdx = conv.getResults().size(); unsigned firstYieldResultIndex = firstSemanticParamResultIdx + conv.getNumAutoDiffSemanticResultParameters(); for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) { if (resultIndex >= firstYieldResultIndex) continue; assert(resultIndex < origAllResults.size()); auto origResult = origAllResults[resultIndex]; // Get the seed (i.e. adjoint value of the original result). 
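      // For example, for an original call `%y = apply %fn(%x)` with active
      // result `%y`, the seed passed to the callee pullback is `adj[%y]`:
      // materialized as an object if `%y`'s tangent value category is
      // "object", or taken as the adjoint buffer if it is "address".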
SILValue seed; switch (getTangentValueCategory(origResult)) { case SILValueCategory::Object: seed = materializeAdjointDirect(getAdjointValue(bb, origResult), loc); break; case SILValueCategory::Address: seed = getAdjointBuffer(bb, origResult); break; } args.push_back(seed); } // If callee pullback was reabstracted in VJP, reabstract callee pullback. if (applyInfo.originalPullbackType) { auto toType = *applyInfo.originalPullbackType; SILOptFunctionBuilder fb(getContext().getTransform()); if (toType->isCoroutine()) pullback = reabstractCoroutine( builder, fb, loc, pullback, toType, [this](SubstitutionMap subs) -> SubstitutionMap { return this->remapSubstitutionMap(subs); }); else pullback = reabstractFunction( builder, fb, loc, pullback, toType, [this](SubstitutionMap subs) -> SubstitutionMap { return this->remapSubstitutionMap(subs); }); } // Call the callee pullback. FullApplySite pullbackCall; SmallVector dirResults; if (actualPullbackType->isCoroutine()) { pullbackCall = builder.createBeginApply(loc, pullback, SubstitutionMap(), args); // Record pullback and begin_apply token: the pullback will be consumed // after end_apply. applyInfo.pullback = pullback; applyInfo.beginApplyToken = cast(pullbackCall)->getTokenResult(); } else { pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), args); builder.emitDestroyValueOperation(loc, pullback); // Extract all results from `pullbackCall`. extractAllElements(cast(pullbackCall), builder, dirResults); } // Get all results in type-defined order. SmallVector allResults; collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults); LLVM_DEBUG({ auto &s = getADDebugStream(); s << "All results of the nested pullback call:\n"; llvm::for_each(allResults, [&](SILValue v) { s << v; }); }); // Accumulate adjoints for original differentiation parameters. auto allResultsIt = allResults.begin(); for (unsigned i : applyInfo.config.parameterIndices->getIndices()) { unsigned argIdx = fai.getNumIndirectSILResults() + i; auto origArg = fai.getArgument(argIdx); // Skip adjoint accumulation for semantic results arguments. auto paramInfo = fai.getSubstCalleeConv().getParamInfoForSILArg(argIdx); if (paramInfo.isAutoDiffSemanticResult()) continue; auto tan = *allResultsIt++; if (tan->getType().isAddress()) { addToAdjointBuffer(bb, origArg, tan, loc); } else { if (origArg->getType().isAddress()) { auto *tmpBuf = builder.createAllocStack(loc, tan->getType()); builder.emitStoreValueOperation(loc, tan, tmpBuf, StoreOwnershipQualifier::Init); addToAdjointBuffer(bb, origArg, tmpBuf, loc); builder.emitDestroyAddrAndFold(loc, tmpBuf); builder.createDeallocStack(loc, tmpBuf); } else { recordTemporary(tan); addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc); } } } // Propagate adjoints for yields if (actualPullbackType->isCoroutine()) { auto originalYields = cast(fai)->getYieldedValues(); auto pullbackYields = cast(pullbackCall)->getYieldedValues(); assert(originalYields.size() == pullbackYields.size()); for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) { if (resultIndex < firstYieldResultIndex) continue; auto yieldResultIndex = resultIndex - firstYieldResultIndex; setAdjointBuffer(bb, originalYields[yieldResultIndex], pullbackYields[yieldResultIndex]); } } // Destroy unused pullback direct results. 
Needed for pullback results from // VJPs extracted from `@differentiable` function callees, where the // `@differentiable` function's differentiation parameter indices are a // superset of the active `apply` parameter indices. while (allResultsIt != allResults.end()) { auto unusedPullbackDirectResult = *allResultsIt++; if (unusedPullbackDirectResult->getType().isAddress()) continue; builder.emitDestroyValueOperation(loc, unusedPullbackDirectResult); } // Destroy and deallocate pullback indirect results. for (auto *alloc : llvm::reverse(pullbackIndirectResults)) { builder.emitDestroyAddrAndFold(loc, alloc); builder.createDeallocStack(loc, alloc); } } void visitAbortApplyInst(AbortApplyInst *aai) { BeginApplyInst *bai = aai->getBeginApply(); assert(getPullbackInfo().shouldDifferentiateApplySite(bai)); // abort_apply differentiation is not yet supported. getContext().emitNondifferentiabilityError( bai, getInvoker(), diag::autodiff_coroutines_not_supported); errorOccurred = true; } void visitEndApplyInst(EndApplyInst *eai) { BeginApplyInst *bai = eai->getBeginApply(); assert(getPullbackInfo().shouldDifferentiateApplySite(bai)); // Replace a call to a function with a call to its pullback. auto &nestedApplyInfo = getContext().getNestedApplyInfo(); auto applyInfoLookup = nestedApplyInfo.find(bai); // If no `NestedApplyInfo` was found, then this task doesn't need to be // differentiated. if (applyInfoLookup == nestedApplyInfo.end()) { // Must not be active. assert(!getActivityInfo().isActive(bai->getTokenResult(), getConfig())); assert(!getActivityInfo().isActive(eai, getConfig())); return; } buildPullbackCall(bai); } void visitBeginApplyInst(BeginApplyInst *bai) { assert(getPullbackInfo().shouldDifferentiateApplySite(bai)); auto &nestedApplyInfo = getContext().getNestedApplyInfo(); auto applyInfoLookup = nestedApplyInfo.find(bai); // If no `NestedApplyInfo` was found, then this task doesn't need to be // differentiated. if (applyInfoLookup == nestedApplyInfo.end()) { // Must not be active. assert(!getActivityInfo().isActive(bai->getTokenResult(), getConfig())); return; } auto applyInfo = applyInfoLookup->getSecond(); auto loc = bai->getLoc(); builder.createEndApply(loc, applyInfo.beginApplyToken, SILType::getEmptyTupleType(getASTContext())); builder.emitDestroyValueOperation(loc, applyInfo.pullback); } /// Handle `struct` instruction. /// Original: y = struct (x0, x1, x2, ...) /// Adjoint: adj[x0] += struct_extract adj[y], #x0 /// adj[x1] += struct_extract adj[y], #x1 /// adj[x2] += struct_extract adj[y], #x2 /// ... void visitStructInst(StructInst *si) { auto *bb = si->getParent(); auto loc = si->getLoc(); auto *structDecl = si->getStructDecl(); switch (getTangentValueCategory(si)) { case SILValueCategory::Object: { auto av = getAdjointValue(bb, si); switch (av.getKind()) { case AdjointValueKind::Zero: { for (auto *field : structDecl->getStoredProperties()) { auto fv = si->getFieldValue(field); addAdjointValue( bb, fv, makeZeroAdjointValue(getRemappedTangentType(fv->getType())), loc); } break; } case AdjointValueKind::Concrete: { auto adjStruct = materializeAdjointDirect(std::move(av), loc); auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct); // Find the struct `TangentVector` type. 
auto structTy = remapType(si->getType()).getASTType(); #ifndef NDEBUG auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); assert(!getTypeLowering(tangentVectorTy).isAddressOnly()); assert(tangentVectorTy->getStructOrBoundGenericStruct()); #endif // Accumulate adjoints for the fields of the `struct` operand. unsigned fieldIndex = 0; for (auto it = structDecl->getStoredProperties().begin(); it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) { VarDecl *field = *it; if (field->getAttrs().hasAttribute()) continue; // Find the corresponding field in the tangent space. auto *tanField = getTangentStoredProperty( getContext(), field, structTy, loc, getInvoker()); if (!tanField) { errorOccurred = true; return; } auto tanElt = dti->getResult(fieldIndex); addAdjointValue(bb, si->getFieldValue(field), makeConcreteAdjointValue(tanElt), si->getLoc()); } break; } case AdjointValueKind::Aggregate: { // Note: All user-called initializations go through the calls to the // initializer, and synthesized initializers only have one level of // struct formation which will not result into any aggregate adjoint // values. llvm_unreachable( "Aggregate adjoint values should not occur for `struct` " "instructions"); } case AdjointValueKind::AddElement: { llvm_unreachable( "Adjoint of `StructInst` cannot be of kind `AddElement`"); } } break; } case SILValueCategory::Address: { auto adjBuf = getAdjointBuffer(bb, si); // Find the struct `TangentVector` type. auto structTy = remapType(si->getType()).getASTType(); // Accumulate adjoints for the fields of the `struct` operand. unsigned fieldIndex = 0; for (auto it = structDecl->getStoredProperties().begin(); it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) { VarDecl *field = *it; if (field->getAttrs().hasAttribute()) continue; // Find the corresponding field in the tangent space. auto *tanField = getTangentStoredProperty(getContext(), field, structTy, loc, getInvoker()); if (!tanField) { errorOccurred = true; return; } auto *adjFieldBuf = builder.createStructElementAddr(loc, adjBuf, tanField); auto fieldValue = si->getFieldValue(field); switch (getTangentValueCategory(fieldValue)) { case SILValueCategory::Object: { auto adjField = builder.emitLoadValueOperation( loc, adjFieldBuf, LoadOwnershipQualifier::Copy); recordTemporary(adjField); addAdjointValue(bb, fieldValue, makeConcreteAdjointValue(adjField), loc); break; } case SILValueCategory::Address: { addToAdjointBuffer(bb, fieldValue, adjFieldBuf, loc); break; } } } } break; } } /// Handle `struct_extract` instruction. /// Original: y = struct_extract x, #field /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) /// ^~~~~~~ /// field in tangent space corresponding to #field void visitStructExtractInst(StructExtractInst *sei) { auto *bb = sei->getParent(); auto loc = getValidLocation(sei); // Find the corresponding field in the tangent space. auto structTy = remapType(sei->getOperand()->getType()).getASTType(); auto *tanField = getTangentStoredProperty(getContext(), sei, structTy, getInvoker()); assert(tanField && "Invalid projections should have been diagnosed"); // Check the `struct_extract` operand's value tangent category. 
switch (getTangentValueCategory(sei->getOperand())) { case SILValueCategory::Object: { auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy); auto eltAdj = getAdjointValue(bb, sei); switch (eltAdj.getKind()) { case AdjointValueKind::Zero: { addAdjointValue(bb, sei->getOperand(), makeZeroAdjointValue(tangentVectorSILTy), loc); break; } case AdjointValueKind::Aggregate: case AdjointValueKind::Concrete: case AdjointValueKind::AddElement: { auto baseAdj = makeZeroAdjointValue(tangentVectorSILTy); addAdjointValue(bb, sei->getOperand(), makeAddElementAdjointValue(baseAdj, eltAdj, tanField), loc); break; } } break; } case SILValueCategory::Address: { auto adjBase = getAdjointBuffer(bb, sei->getOperand()); auto *adjBaseElt = builder.createStructElementAddr(loc, adjBase, tanField); // Check the `struct_extract`'s value tangent category. switch (getTangentValueCategory(sei)) { case SILValueCategory::Object: { auto adjElt = getAdjointValue(bb, sei); auto concreteAdjElt = materializeAdjointDirect(adjElt, loc); auto concreteAdjEltCopy = builder.emitCopyValueOperation(loc, concreteAdjElt); auto *alloc = builder.createAllocStack(loc, adjElt.getType()); builder.emitStoreValueOperation(loc, concreteAdjEltCopy, alloc, StoreOwnershipQualifier::Init); builder.emitInPlaceAdd(loc, adjBaseElt, alloc); builder.createDestroyAddr(loc, alloc); builder.createDeallocStack(loc, alloc); break; } case SILValueCategory::Address: { auto adjElt = getAdjointBuffer(bb, sei); builder.emitInPlaceAdd(loc, adjBaseElt, adjElt); break; } } break; } } } /// Handle `ref_element_addr` instruction. /// Original: y = ref_element_addr x, /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) /// ^~~~~~~ /// field in tangent space corresponding to #field void visitRefElementAddrInst(RefElementAddrInst *reai) { auto *bb = reai->getParent(); auto loc = reai->getLoc(); auto adjBuf = getAdjointBuffer(bb, reai); auto classOperand = reai->getOperand(); auto classType = remapType(reai->getOperand()->getType()).getASTType(); auto *tanField = getTangentStoredProperty(getContext(), reai, classType, getInvoker()); assert(tanField && "Invalid projections should have been diagnosed"); switch (getTangentValueCategory(classOperand)) { case SILValueCategory::Object: { auto classTy = remapType(classOperand->getType()).getASTType(); auto tangentVectorTy = getTangentSpace(classTy)->getCanonicalType(); auto tangentVectorSILTy = SILType::getPrimitiveObjectType(tangentVectorTy); auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); // Accumulate adjoint for the `ref_element_addr` operand. 
SmallVector eltVals; for (auto *field : tangentVectorDecl->getStoredProperties()) { if (field == tanField) { auto adjElt = builder.emitLoadValueOperation( reai->getLoc(), adjBuf, LoadOwnershipQualifier::Copy); eltVals.push_back(makeConcreteAdjointValue(adjElt)); recordTemporary(adjElt); } else { auto substMap = tangentVectorTy->getMemberSubstitutionMap( field); auto fieldTy = field->getInterfaceType().subst(substMap); auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType(); assert(fieldSILTy.isObject()); eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); } } addAdjointValue(bb, classOperand, makeAggregateAdjointValue(tangentVectorSILTy, eltVals), loc); break; } case SILValueCategory::Address: { auto adjBufClass = getAdjointBuffer(bb, classOperand); auto adjBufElt = builder.createStructElementAddr(loc, adjBufClass, tanField); builder.emitInPlaceAdd(loc, adjBufElt, adjBuf); break; } } } /// Handle `tuple` instruction. /// Original: y = tuple (x0, x1, x2, ...) /// Adjoint: (adj[x0], adj[x1], adj[x2], ...) += destructure_tuple adj[y] /// ^~~ /// excluding non-differentiable elements void visitTupleInst(TupleInst *ti) { auto *bb = ti->getParent(); auto loc = ti->getLoc(); switch (getTangentValueCategory(ti)) { case SILValueCategory::Object: { auto av = getAdjointValue(bb, ti); switch (av.getKind()) { case AdjointValueKind::Zero: for (auto elt : ti->getElements()) { if (!getTangentSpace(elt->getType().getASTType())) continue; addAdjointValue( bb, elt, makeZeroAdjointValue(getRemappedTangentType(elt->getType())), loc); } break; case AdjointValueKind::Concrete: { auto adjVal = av.getConcreteValue(); auto adjValCopy = builder.emitCopyValueOperation(loc, adjVal); SmallVector adjElts; if (!adjVal->getType().getAs()) { recordTemporary(adjValCopy); adjElts.push_back(adjValCopy); } else { auto *dti = builder.createDestructureTuple(loc, adjValCopy); for (auto adjElt : dti->getResults()) recordTemporary(adjElt); adjElts.append(dti->getResults().begin(), dti->getResults().end()); } // Accumulate adjoints for `tuple` operands, skipping the // non-`Differentiable` ones. unsigned adjIndex = 0; for (auto i : range(ti->getNumOperands())) { if (!getTangentSpace(ti->getOperand(i)->getType().getASTType())) continue; auto adjElt = adjElts[adjIndex++]; addAdjointValue(bb, ti->getOperand(i), makeConcreteAdjointValue(adjElt), loc); } break; } case AdjointValueKind::Aggregate: { unsigned adjIndex = 0; for (auto i : range(ti->getElements().size())) { if (!getTangentSpace(ti->getElement(i)->getType().getASTType())) continue; addAdjointValue(bb, ti->getElement(i), av.getAggregateElement(adjIndex++), loc); } break; } case AdjointValueKind::AddElement: { llvm_unreachable( "Adjoint of `TupleInst` cannot be of kind `AddElement`"); } } break; } case SILValueCategory::Address: { auto adjBuf = getAdjointBuffer(bb, ti); // Accumulate adjoints for `tuple` operands, skipping the // non-`Differentiable` ones. unsigned adjIndex = 0; for (auto i : range(ti->getNumOperands())) { if (!getTangentSpace(ti->getOperand(i)->getType().getASTType())) continue; auto adjBufElt = builder.createTupleElementAddr(loc, adjBuf, adjIndex++); auto adjElt = getAdjointBuffer(bb, ti->getOperand(i)); builder.emitInPlaceAdd(loc, adjElt, adjBufElt); } break; } } } /// Handle `tuple_extract` instruction. 
/// Original: y = tuple_extract x, /// Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0) /// ^~~~~~ /// n'-th element, where n' is tuple tangent space /// index corresponding to n void visitTupleExtractInst(TupleExtractInst *tei) { auto *bb = tei->getParent(); auto loc = tei->getLoc(); auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType()); auto eltAdj = getAdjointValue(bb, tei); switch (eltAdj.getKind()) { case AdjointValueKind::Zero: { addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy), loc); break; } case AdjointValueKind::Aggregate: case AdjointValueKind::Concrete: case AdjointValueKind::AddElement: { auto tupleTy = tei->getTupleType(); auto tupleTanTupleTy = tupleTanTy.getAs(); if (!tupleTanTupleTy) { addAdjointValue(bb, tei->getOperand(), eltAdj, loc); break; } unsigned elements = 0; for (unsigned i : range(tupleTy->getNumElements())) { if (!getTangentSpace( tupleTy->getElement(i).getType()->getCanonicalType())) continue; elements++; } if (elements == 1) { addAdjointValue(bb, tei->getOperand(), eltAdj, loc); } else { auto baseAdj = makeZeroAdjointValue(tupleTanTy); addAdjointValue( bb, tei->getOperand(), makeAddElementAdjointValue(baseAdj, eltAdj, tei->getFieldIndex()), loc); } break; } } } /// Handle `destructure_tuple` instruction. /// Original: (y0, ..., yn) = destructure_tuple x /// Adjoint: adj[x].0 += adj[y0] /// ... /// adj[x].n += adj[yn] void visitDestructureTupleInst(DestructureTupleInst *dti) { auto *bb = dti->getParent(); auto loc = dti->getLoc(); auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType()); // Check the `destructure_tuple` operand's value tangent category. switch (getTangentValueCategory(dti->getOperand())) { case SILValueCategory::Object: { SmallVector adjValues; for (auto origElt : dti->getResults()) { // Skip non-`Differentiable` tuple elements. if (!getTangentSpace(remapType(origElt->getType()).getASTType())) continue; adjValues.push_back(getAdjointValue(bb, origElt)); } // Handle tuple tangent type. // Add adjoints for every tuple element that has a tangent space. if (tupleTanTy.is()) { assert(adjValues.size() > 1); addAdjointValue(bb, dti->getOperand(), makeAggregateAdjointValue(tupleTanTy, adjValues), loc); } // Handle non-tuple tangent type. // Add adjoint for the single tuple element that has a tangent space. else { assert(adjValues.size() == 1); addAdjointValue(bb, dti->getOperand(), adjValues.front(), loc); } break; } case SILValueCategory::Address: { auto adjBuf = getAdjointBuffer(bb, dti->getOperand()); unsigned adjIndex = 0; for (auto origElt : dti->getResults()) { // Skip non-`Differentiable` tuple elements. if (!getTangentSpace(remapType(origElt->getType()).getASTType())) continue; // Handle tuple tangent type. // Add adjoints for every tuple element that has a tangent space. if (tupleTanTy.is()) { auto adjEltBuf = getAdjointBuffer(bb, origElt); auto adjBufElt = builder.createTupleElementAddr(loc, adjBuf, adjIndex); builder.emitInPlaceAdd(loc, adjBufElt, adjEltBuf); } // Handle non-tuple tangent type. // Add adjoint for the single tuple element that has a tangent space. 
else { auto adjEltBuf = getAdjointBuffer(bb, origElt); addToAdjointBuffer(bb, dti->getOperand(), adjEltBuf, loc); } ++adjIndex; } break; } } } /// Handle `load` or `load_borrow` instruction /// Original: y = load/load_borrow x /// Adjoint: adj[x] += adj[y] void visitLoadOperation(SingleValueInstruction *inst) { assert(isa(inst) || isa(inst)); auto *bb = inst->getParent(); auto loc = inst->getLoc(); switch (getTangentValueCategory(inst)) { case SILValueCategory::Object: { auto adjVal = materializeAdjointDirect(getAdjointValue(bb, inst), loc); // Allocate a local buffer and store the adjoint value. This buffer will // be used for accumulation into the adjoint buffer. auto adjBuf = builder.createAllocStack(loc, adjVal->getType(), {}, DoesNotHaveDynamicLifetime, IsNotLexical, IsNotFromVarDecl, DoesNotUseMoveableValueDebugInfo, /* skipVarDeclAssert = */ true); auto copy = builder.emitCopyValueOperation(loc, adjVal); builder.emitStoreValueOperation(loc, copy, adjBuf, StoreOwnershipQualifier::Init); // Accumulate the adjoint value in the local buffer into the adjoint // buffer. addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc); builder.emitDestroyAddr(loc, adjBuf); builder.createDeallocStack(loc, adjBuf); break; } case SILValueCategory::Address: { auto adjBuf = getAdjointBuffer(bb, inst); addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc); break; } } } void visitLoadInst(LoadInst *li) { visitLoadOperation(li); } void visitLoadBorrowInst(LoadBorrowInst *lbi) { visitLoadOperation(lbi); } /// Handle `store` or `store_borrow` instruction. /// Original: store/store_borrow x to y /// Adjoint: adj[x] += load adj[y]; adj[y] = 0 void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc, SILValue origDest) { auto adjBuf = getAdjointBuffer(bb, origDest); switch (getTangentValueCategory(origSrc)) { case SILValueCategory::Object: { auto adjVal = builder.emitLoadValueOperation( loc, adjBuf, LoadOwnershipQualifier::Take); recordTemporary(adjVal); addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc); builder.emitZeroIntoBuffer(loc, adjBuf, IsInitialization); break; } case SILValueCategory::Address: { addToAdjointBuffer(bb, origSrc, adjBuf, loc); builder.emitZeroIntoBuffer(loc, adjBuf, IsNotInitialization); break; } } } void visitStoreInst(StoreInst *si) { visitStoreOperation(si->getParent(), si->getLoc(), si->getSrc(), si->getDest()); } void visitStoreBorrowInst(StoreBorrowInst *sbi) { visitStoreOperation(sbi->getParent(), sbi->getLoc(), sbi->getSrc(), sbi); } /// Handle `copy_addr` instruction. /// Original: copy_addr x to y /// Adjoint: adj[x] += adj[y]; adj[y] = 0 void visitCopyAddrInst(CopyAddrInst *cai) { auto *bb = cai->getParent(); auto adjDest = getAdjointBuffer(bb, cai->getDest()); addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc()); builder.emitZeroIntoBuffer(cai->getLoc(), adjDest, IsNotInitialization); } /// Handle any ownership instruction that deals with values: copy_value, /// move_value, begin_borrow. 
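  /// These instructions forward their operand without changing its value, so
  /// differentiation treats them as the identity: the result's adjoint is
  /// simply accumulated into the operand's adjoint. For `move_value`, the
  /// result's adjoint is additionally reset to zero, since the original value
  /// is consumed (see `visitMoveValueInst` below).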
/// Original: y = copy_value x /// Adjoint: adj[x] += adj[y] void visitValueOwnershipInst(SingleValueInstruction *svi, bool needZeroResAdj = false) { assert(svi->getNumOperands() == 1); auto *bb = svi->getParent(); switch (getTangentValueCategory(svi)) { case SILValueCategory::Object: { auto adj = getAdjointValue(bb, svi); addAdjointValue(bb, svi->getOperand(0), adj, svi->getLoc()); if (needZeroResAdj) { assert(svi->getNumResults() == 1); SILValue val = svi->getResult(0); setAdjointValue( bb, val, makeZeroAdjointValue(getRemappedTangentType(val->getType()))); } break; } case SILValueCategory::Address: { auto adjDest = getAdjointBuffer(bb, svi); addToAdjointBuffer(bb, svi->getOperand(0), adjDest, svi->getLoc()); builder.emitZeroIntoBuffer(svi->getLoc(), adjDest, IsNotInitialization); break; } } } /// Handle `copy_value` instruction. /// Original: y = copy_value x /// Adjoint: adj[x] += adj[y] void visitCopyValueInst(CopyValueInst *cvi) { visitValueOwnershipInst(cvi); } /// Handle `begin_borrow` instruction. /// Original: y = begin_borrow x /// Adjoint: adj[x] += adj[y] void visitBeginBorrowInst(BeginBorrowInst *bbi) { visitValueOwnershipInst(bbi); } /// Handle `move_value` instruction. /// Original: y = move_value x /// Adjoint: adj[x] += adj[y]; adj[y] = 0 void visitMoveValueInst(MoveValueInst *mvi) { switch (getTangentValueCategory(mvi)) { case SILValueCategory::Address: LLVM_DEBUG(getADDebugStream() << "AutoDiff does not support move_value with " "SILValueCategory::Address"); getContext().emitNondifferentiabilityError( mvi, getInvoker(), diag::autodiff_expression_not_differentiable_note); errorOccurred = true; return; case SILValueCategory::Object: visitValueOwnershipInst(mvi, /*needZeroResAdj=*/true); } } void visitEndInitLetRefInst(EndInitLetRefInst *eir) { visitValueOwnershipInst(eir); } /// Handle `begin_access` instruction. /// Original: y = begin_access x /// Adjoint: nothing void visitBeginAccessInst(BeginAccessInst *bai) { // Check for non-differentiable writes. if (bai->getAccessKind() == SILAccessKind::Modify) { if (isa(bai->getSource())) { getContext().emitNondifferentiabilityError( bai, getInvoker(), diag::autodiff_cannot_differentiate_writes_to_global_variables); errorOccurred = true; return; } if (isa(bai->getSource())) { getContext().emitNondifferentiabilityError( bai, getInvoker(), diag::autodiff_cannot_differentiate_writes_to_mutable_captures); errorOccurred = true; return; } } } /// Handle `unconditional_checked_cast_addr` instruction. /// Original: y = unconditional_checked_cast_addr x /// Adjoint: adj[x] += unconditional_checked_cast_addr adj[y] void visitUnconditionalCheckedCastAddrInst( UnconditionalCheckedCastAddrInst *uccai) { auto *bb = uccai->getParent(); auto adjDest = getAdjointBuffer(bb, uccai->getDest()); auto adjSrc = getAdjointBuffer(bb, uccai->getSrc()); auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType()); builder.createUnconditionalCheckedCastAddr( uccai->getLoc(), uccai->getCheckedCastOptions(), adjDest, adjDest->getType().getASTType(), castBuf, adjSrc->getType().getASTType()); addToAdjointBuffer(bb, uccai->getSrc(), castBuf, uccai->getLoc()); builder.emitDestroyAddrAndFold(uccai->getLoc(), castBuf); builder.createDeallocStack(uccai->getLoc(), castBuf); builder.emitZeroIntoBuffer(uccai->getLoc(), adjDest, IsInitialization); } /// Handle `enum` instruction. 
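  /// Only `Optional<T>`-typed operands are supported. `adj[y]` has type
  /// `Optional<T>.TangentVector`, a struct wrapping `Optional<T.TangentVector>`;
  /// the wrapped tangent is extracted from its `.some` payload and accumulated
  /// into `adj[x]`.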
/// Original: y = enum $Enum, #Enum.some!enumelt, x /// Adjoint: adj[x] += adj[y] void visitEnumInst(EnumInst *ei) { SILBasicBlock *bb = ei->getParent(); SILLocation loc = ei->getLoc(); auto *optionalEnumDecl = getASTContext().getOptionalDecl(); // Only `Optional`-typed operands are supported for now. Diagnose all other // enum operand types. if (ei->getType().getEnumOrBoundGenericEnum() != optionalEnumDecl) { LLVM_DEBUG(getADDebugStream() << "Unsupported enum type in PullbackCloner: " << *ei); getContext().emitNondifferentiabilityError( ei, getInvoker(), diag::autodiff_expression_not_differentiable_note); errorOccurred = true; return; } auto adjOpt = getAdjointValue(bb, ei); auto adjStruct = materializeAdjointDirect(adjOpt, loc); VarDecl *adjOptVar = getASTContext().getOptionalTanValueDecl(adjStruct->getType().getASTType()); auto *adjVal = builder.createStructExtract(loc, adjStruct, adjOptVar); EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl(); auto *adjData = builder.createUncheckedEnumData(loc, adjVal, someElemDecl); addAdjointValue(bb, ei->getOperand(), makeConcreteAdjointValue(adjData), loc); } /// Handle a sequence of `init_enum_data_addr` and `inject_enum_addr` /// instructions. /// /// Original: x = init_enum_data_addr y : $*Enum, #Enum.Case /// inject_enum_addr y /// /// Adjoint: adj[x] += unchecked_take_enum_data_addr adj[y] void visitInjectEnumAddrInst(InjectEnumAddrInst *inject) { SILBasicBlock *bb = inject->getParent(); SILValue origEnum = inject->getOperand(); // Only `Optional`-typed operands are supported for now. Diagnose all other // enum operand types. auto *optionalEnumDecl = getASTContext().getOptionalDecl(); if (origEnum->getType().getEnumOrBoundGenericEnum() != optionalEnumDecl) { LLVM_DEBUG(getADDebugStream() << "Unsupported enum type in PullbackCloner: " << *inject); getContext().emitNondifferentiabilityError( inject, getInvoker(), diag::autodiff_expression_not_differentiable_note); errorOccurred = true; return; } // No associated value => no adjoint to propagate if (!inject->getElement()->hasAssociatedValues()) return; InitEnumDataAddrInst *origData = nullptr; for (auto use : origEnum->getUses()) { if (auto *init = dyn_cast(use->getUser())) { // We need a more complicated analysis when init_enum_data_addr and // inject_enum_addr are in different blocks, or there is more than one // such instruction. Bail out for now. if (origData || init->getParent() != bb) { LLVM_DEBUG(getADDebugStream() << "Could not find a matching init_enum_data_addr for: " << *inject); getContext().emitNondifferentiabilityError( inject, getInvoker(), diag::autodiff_expression_not_differentiable_note); errorOccurred = true; return; } origData = init; } } SILValue adjDest = getAdjointBuffer(bb, origEnum); VarDecl *adjOptVar = getASTContext().getOptionalTanValueDecl(adjDest->getType().getASTType()); SILLocation loc = origData->getLoc(); StructElementAddrInst *adjOpt = builder.createStructElementAddr(loc, adjDest, adjOptVar); // unchecked_take_enum_data_addr is destructive, so copy // Optional to a new alloca. AllocStackInst *adjOptCopy = createFunctionLocalAllocation(adjOpt->getType(), loc); builder.createCopyAddr(loc, adjOpt, adjOptCopy, IsNotTake, IsInitialization); // The Optional copy is invalidated, do not attempt to destroy it at the end // of the pullback. The value returned from unchecked_take_enum_data_addr is // destroyed in visitInitEnumDataAddrInst. 
auto [_, inserted] = enumDataAdjCopies.try_emplace(origData, adjOptCopy); assert(inserted && "expected single buffer"); EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl(); UncheckedTakeEnumDataAddrInst *adjData = builder.createUncheckedTakeEnumDataAddr(loc, adjOptCopy, someElemDecl); addToAdjointBuffer(bb, origData, adjData, loc); } /// Handle `init_enum_data_addr` instruction. /// Destroy the value returned from `unchecked_take_enum_data_addr`. void visitInitEnumDataAddrInst(InitEnumDataAddrInst *init) { SILValue adjOptCopy = enumDataAdjCopies.at(init); builder.emitDestroyAddr(init->getLoc(), adjOptCopy); destroyedLocalAllocations.insert(adjOptCopy); enumDataAdjCopies.erase(init); } /// Handle `unchecked_ref_cast` instruction. /// Original: y = unchecked_ref_cast x /// Adjoint: adj[x] += adj[y] /// (assuming adj[x] and adj[y] have the same type) void visitUncheckedRefCastInst(UncheckedRefCastInst *urci) { auto *bb = urci->getParent(); assert(urci->getOperand()->getType().isObject()); assert(getRemappedTangentType(urci->getOperand()->getType()) == getRemappedTangentType(urci->getType()) && "Operand/result must have the same `TangentVector` type"); switch (getTangentValueCategory(urci)) { case SILValueCategory::Object: { auto adj = getAdjointValue(bb, urci); addAdjointValue(bb, urci->getOperand(), adj, urci->getLoc()); break; } case SILValueCategory::Address: { auto adjDest = getAdjointBuffer(bb, urci); addToAdjointBuffer(bb, urci->getOperand(), adjDest, urci->getLoc()); builder.emitZeroIntoBuffer(urci->getLoc(), adjDest, IsNotInitialization); break; } } } /// Handle `upcast` instruction. /// Original: y = upcast x /// Adjoint: adj[x] += adj[y] /// (assuming adj[x] and adj[y] have the same type) void visitUpcastInst(UpcastInst *ui) { auto *bb = ui->getParent(); assert(ui->getOperand()->getType().isObject()); assert(getRemappedTangentType(ui->getOperand()->getType()) == getRemappedTangentType(ui->getType()) && "Operand/result must have the same `TangentVector` type"); switch (getTangentValueCategory(ui)) { case SILValueCategory::Object: { auto adj = getAdjointValue(bb, ui); addAdjointValue(bb, ui->getOperand(), adj, ui->getLoc()); break; } case SILValueCategory::Address: { auto adjDest = getAdjointBuffer(bb, ui); addToAdjointBuffer(bb, ui->getOperand(), adjDest, ui->getLoc()); builder.emitZeroIntoBuffer(ui->getLoc(), adjDest, IsNotInitialization); break; } } } /// Handle `unchecked_take_enum_data_addr` instruction. /// Currently, only `Optional`-typed operands are supported. /// Original: y = unchecked_take_enum_data_addr x : $*Enum, #Enum.Case /// Adjoint: adj[x] += $Enum.TangentVector(adj[y]) void visitUncheckedTakeEnumDataAddrInst(UncheckedTakeEnumDataAddrInst *utedai) { auto *bb = utedai->getParent(); auto adjDest = getAdjointBuffer(bb, utedai); auto enumTy = utedai->getOperand()->getType(); auto *optionalEnumDecl = getASTContext().getOptionalDecl(); // Only `Optional`-typed operands are supported for now. Diagnose all other // enum operand types. 
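    // For supported `Optional` operands, the element's adjoint buffer is
    // wrapped back into an `Optional.TangentVector` buffer and accumulated
    // into the operand's adjoint buffer (see `accumulateAdjointForOptionalBuffer`
    // below), after which the element's adjoint buffer is zeroed.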
if (enumTy.getASTType().getEnumOrBoundGenericEnum() != optionalEnumDecl) { LLVM_DEBUG(getADDebugStream() << "Unhandled instruction in PullbackCloner: " << *utedai); getContext().emitNondifferentiabilityError( utedai, getInvoker(), diag::autodiff_expression_not_differentiable_note); errorOccurred = true; return; } accumulateAdjointForOptionalBuffer(bb, utedai->getOperand(), adjDest); builder.emitZeroIntoBuffer(utedai->getLoc(), adjDest, IsNotInitialization); } #define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst); #undef NOT_DIFFERENTIABLE #define NO_ADJOINT(INST) \ void visit##INST##Inst(INST##Inst *inst) {} // Terminators. NO_ADJOINT(Return) NO_ADJOINT(Branch) NO_ADJOINT(CondBranch) NO_ADJOINT(Yield) // Address projections. NO_ADJOINT(StructElementAddr) NO_ADJOINT(TupleElementAddr) // Array literal initialization address projections. NO_ADJOINT(PointerToAddress) NO_ADJOINT(IndexAddr) // Memory allocation/access. NO_ADJOINT(AllocStack) NO_ADJOINT(DeallocStack) NO_ADJOINT(EndAccess) // Debugging/reference counting instructions. NO_ADJOINT(DebugValue) NO_ADJOINT(RetainValue) NO_ADJOINT(RetainValueAddr) NO_ADJOINT(ReleaseValue) NO_ADJOINT(ReleaseValueAddr) NO_ADJOINT(StrongRetain) NO_ADJOINT(StrongRelease) NO_ADJOINT(UnownedRetain) NO_ADJOINT(UnownedRelease) NO_ADJOINT(StrongRetainUnowned) NO_ADJOINT(DestroyValue) NO_ADJOINT(DestroyAddr) // Value ownership. NO_ADJOINT(EndBorrow) #undef NO_ADJOINT }; PullbackCloner::Implementation::Implementation(VJPCloner &vjpCloner) : vjpCloner(vjpCloner), scopeCloner(getPullback()), builder(getPullback(), getContext()), localAllocBuilder(getPullback(), getContext()) { // Get dominance and post-order info for the original function. auto &passManager = getContext().getPassManager(); auto *domAnalysis = passManager.getAnalysis(); auto *postDomAnalysis = passManager.getAnalysis(); auto *postOrderAnalysis = passManager.getAnalysis(); auto *original = &vjpCloner.getOriginal(); domInfo = domAnalysis->get(original); postDomInfo = postDomAnalysis->get(original); postOrderInfo = postOrderAnalysis->get(original); // Initialize `originalExitBlock`. auto origExitIt = original->findReturnBB(); assert(origExitIt != original->end() && "Functions without returns must have been diagnosed"); originalExitBlock = &*origExitIt; localAllocBuilder.setCurrentDebugScope( remapScope(originalExitBlock->getTerminator()->getDebugScope())); } PullbackCloner::PullbackCloner(VJPCloner &vjpCloner) : impl(*new Implementation(vjpCloner)) {} PullbackCloner::~PullbackCloner() { delete &impl; } static SILValue getArrayValue(ApplyInst *ai) { SILValue arrayValue; for (auto use : ai->getUses()) { auto *dti = dyn_cast(use->getUser()); if (!dti) continue; DEBUG_ASSERT(!arrayValue && "Array value already found"); // The first `destructure_tuple` result is the `Array` value. 
arrayValue = dti->getResult(0); #ifndef DEBUG_ASSERT_enabled break; #endif } ASSERT(arrayValue); return arrayValue; } //--------------------------------------------------------------------------// // Entry point //--------------------------------------------------------------------------// bool PullbackCloner::run() { bool foundError = impl.run(); #ifndef NDEBUG if (!foundError) impl.getPullback().verify(); #endif return foundError; } bool PullbackCloner::Implementation::run() { PrettyStackTraceSILFunction trace("generating pullback for", &getOriginal()); auto &original = getOriginal(); auto &pullback = getPullback(); auto pbLoc = getPullback().getLocation(); LLVM_DEBUG(getADDebugStream() << "Running PullbackCloner on\n" << original); // Collect original formal results. SmallVector origFormalResults; collectAllFormalResultsInTypeOrder(original, origFormalResults); for (auto resultIndex : getConfig().resultIndices->getIndices()) { auto origResult = origFormalResults[resultIndex]; // If original result is non-varied, it will always have a zero derivative. // Skip full pullback generation and simply emit zero derivatives for wrt // parameters. // // NOTE(TF-876): This shortcut is currently necessary for functions // returning non-varied result with >1 basic block where some basic blocks // have no dominated active values; control flow differentiation does not // handle this case. See TF-876 for context. if (!getActivityInfo().isVaried(origResult, getConfig().parameterIndices)) { emitZeroDerivativesForNonvariedResult(origResult); return false; } } // Collect dominated active values in original basic blocks. // Adjoint values of dominated active values are passed as pullback block // arguments. DominanceOrder domOrder(original.getEntryBlock(), domInfo); // Keep track of visited values. SmallPtrSet visited; while (auto *bb = domOrder.getNext()) { auto &bbActiveValues = activeValues[bb]; // If the current block has an immediate dominator, append the immediate // dominator block's active values to the current block's active values. if (auto *domNode = domInfo->getNode(bb)->getIDom()) { auto &domBBActiveValues = activeValues[domNode->getBlock()]; bbActiveValues.append(domBBActiveValues.begin(), domBBActiveValues.end()); } // If `v` is active and has not been visited, records it as an active value // in the original basic block. // For active values unsupported by differentiation, emits a diagnostic and // returns true. Otherwise, returns false. auto recordValueIfActive = [&](SILValue v) -> bool { // If value is not active, skip. if (!getActivityInfo().isActive(v, getConfig())) return false; // If active value has already been visited, skip. if (visited.count(v)) return false; // Mark active value as visited. visited.insert(v); // Diagnose unsupported active values. auto type = v->getType(); // Do not emit remaining activity-related diagnostics for semantic member // accessors, which have special-case pullback generation. if (isSemanticMemberAccessor(&original)) return false; // Diagnose active enum values. Differentiation of enum values requires // special adjoint value handling and is not yet supported. Diagnose // only the first active enum value to prevent too many diagnostics. // // Do not diagnose `Optional`-typed values, which will have special-case // differentiation support. 
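      // (`Optional` values are instead handled through `Optional.TangentVector`;
      // see `accumulateAdjointForOptionalBuffer` and
      // `accumulateAdjointValueForOptional` below.)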
if (type.getEnumOrBoundGenericEnum()) { if (!type.getASTType()->isOptional()) { getContext().emitNondifferentiabilityError( v, getInvoker(), diag::autodiff_enums_unsupported); errorOccurred = true; return true; } } // Diagnose unsupported stored property projections. if (isa(v) || isa(v) || isa(v)) { auto *inst = cast(v); assert(inst->getNumOperands() == 1); auto baseType = remapType(inst->getOperand(0)->getType()).getASTType(); if (!getTangentStoredProperty(getContext(), inst, baseType, getInvoker())) { errorOccurred = true; return true; } } // Skip address projections. // Address projections do not need their own adjoint buffers; they // become projections into their adjoint base buffer. if (Projection::isAddressProjection(v)) return false; // Co-routines borrow adjoint buffers for yields if (isa_and_nonnull(v.getDefiningInstruction())) return false; // Check that active values are differentiable. Otherwise we may crash // later when tangent space is required, but not available. if (!getTangentSpace(remapType(type).getASTType())) { getContext().emitNondifferentiabilityError( v, getInvoker(), diag::autodiff_expression_not_differentiable_note); errorOccurred = true; return true; } // Record active value. bbActiveValues.push_back(v); return false; }; // Record all active values in the basic block. for (auto *arg : bb->getArguments()) if (recordValueIfActive(arg)) return true; for (auto &inst : *bb) { for (auto op : inst.getOperandValues()) if (recordValueIfActive(op)) return true; for (auto result : inst.getResults()) if (recordValueIfActive(result)) return true; } domOrder.pushChildren(bb); } // Create pullback blocks and arguments, visiting original blocks using BFS // starting from the original exit block. Unvisited original basic blocks // (e.g unreachable blocks) are not relevant for pullback generation and thus // ignored. // The original blocks in traversal order for pullback generation. SmallVector originalBlocks; // The workqueue used for bookkeeping during the breadth-first traversal. BasicBlockWorkqueue workqueue = {originalExitBlock}; // Perform BFS from the original exit block. { while (auto *BB = workqueue.pop()) { originalBlocks.push_back(BB); for (auto *nextBB : BB->getPredecessorBlocks()) { // If there is no linear map tuple for predecessor BB, then BB is // unreachable from function entry. Do not run pullback cloner on it. if (getPullbackInfo().getLinearMapTupleType(nextBB)) workqueue.pushIfNotVisited(nextBB); } } } for (auto *origBB : originalBlocks) { auto *pullbackBB = pullback.createBasicBlock(); pullbackBBMap.insert({origBB, pullbackBB}); auto pbTupleLoweredType = remapType(getPullbackInfo().getLinearMapTupleLoweredType(origBB)); // If the BB is the original exit, then the pullback block that we just // created must be the pullback function's entry. For the pullback entry, // create entry arguments and continue to the next block. if (origBB == originalExitBlock) { assert(pullbackBB->isEntry()); createEntryArguments(&pullback); auto *origTerm = originalExitBlock->getTerminator(); builder.setCurrentDebugScope(remapScope(origTerm->getDebugScope())); builder.setInsertionPoint(pullbackBB); // Obtain the context object, if any, and the top-level subcontext, i.e. // the main pullback struct. if (getPullbackInfo().hasHeapAllocatedContext()) { // The last argument is the context object (`Builtin.NativeObject`). contextValue = pullbackBB->getArguments().back(); assert(contextValue->getType() == SILType::getNativeObjectType(getASTContext())); // Load the pullback context. 
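        // The heap-allocated context stores the exit block's pullback tuple:
        // project the top-level subcontext address out of the context object
        // and load the tuple from it, copying if the tuple type is non-trivial.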
auto subcontextAddr = emitProjectTopLevelSubcontext( builder, pbLoc, contextValue, pbTupleLoweredType); SILValue mainPullbackTuple = builder.createLoad( pbLoc, subcontextAddr, pbTupleLoweredType.isTrivial(getPullback()) ? LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy); auto *dsi = builder.createDestructureTuple(pbLoc, mainPullbackTuple); initializePullbackTupleElements(origBB, dsi->getAllResults()); } else { // Obtain and destructure pullback struct elements. unsigned numVals = pbTupleLoweredType.getAs()->getNumElements(); initializePullbackTupleElements(origBB, pullbackBB->getArguments().take_back(numVals)); } continue; } // Get all active values in the original block. // If the original block has no active values, continue. auto &bbActiveValues = activeValues[origBB]; if (bbActiveValues.empty()) continue; // Otherwise, if the original block has active values: // - For each active buffer in the original block, allocate a new local // buffer in the pullback entry. (All adjoint buffers are allocated in // the pullback entry and deallocated in the pullback exit.) // - For each active value in the original block, add adjoint value // arguments to the pullback block. for (auto activeValue : bbActiveValues) { // Handle the active value based on its value category. switch (getTangentValueCategory(activeValue)) { case SILValueCategory::Address: { // Allocate and zero initialize a new local buffer using // `getAdjointBuffer`. builder.setCurrentDebugScope( remapScope(originalExitBlock->getTerminator()->getDebugScope())); builder.setInsertionPoint(pullback.getEntryBlock()); getAdjointBuffer(origBB, activeValue); break; } case SILValueCategory::Object: { // Create and register pullback block argument for the active value. auto *pullbackArg = pullbackBB->createPhiArgument( getRemappedTangentType(activeValue->getType()), OwnershipKind::Owned); activeValuePullbackBBArgumentMap[{origBB, activeValue}] = pullbackArg; recordTemporary(pullbackArg); break; } } } // Add a pullback tuple argument. auto *pbTupleArg = pullbackBB->createPhiArgument(pbTupleLoweredType, OwnershipKind::Owned); // Destructure the pullback struct to get the elements. builder.setCurrentDebugScope( remapScope(origBB->getTerminator()->getDebugScope())); builder.setInsertionPoint(pullbackBB); auto *dsi = builder.createDestructureTuple(pbLoc, pbTupleArg); initializePullbackTupleElements(origBB, dsi->getResults()); // - Create pullback trampoline blocks for each successor block of the // original block. Pullback trampoline blocks only have a pullback // struct argument. They branch from a pullback successor block to the // pullback original block, passing adjoint values of active values. for (auto *succBB : origBB->getSuccessorBlocks()) { // Skip generating pullback block for original unreachable blocks. if (!workqueue.isVisited(succBB)) continue; auto *pullbackTrampolineBB = pullback.createBasicBlockBefore(pullbackBB); pullbackTrampolineBBMap.insert({{origBB, succBB}, pullbackTrampolineBB}); // Get the enum element type (i.e. the pullback struct type). The enum // element type may be boxed if the enum is indirect. 
auto enumLoweredTy = getPullbackInfo().getBranchingTraceEnumLoweredType(succBB); auto *enumEltDecl = getPullbackInfo().lookUpBranchingTraceEnumElement(origBB, succBB); auto enumEltType = remapType(enumLoweredTy.getEnumElementType( enumEltDecl, getModule(), TypeExpansionContext::minimal())); pullbackTrampolineBB->createPhiArgument(enumEltType, OwnershipKind::Owned); } } auto *pullbackEntry = pullback.getEntryBlock(); auto pbTupleLoweredType = remapType(getPullbackInfo().getLinearMapTupleLoweredType(originalExitBlock)); unsigned numVals = (getPullbackInfo().hasHeapAllocatedContext() ? 1 : pbTupleLoweredType.getAs()->getNumElements()); (void)numVals; // The pullback function has type: // `(seed0, seed1, ..., (exit_pb_tuple_el0, ..., )|context_obj) -> (d_arg0, ..., d_argn)`. auto conv = getOriginal().getConventions(); auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults(); assert(getConfig().resultIndices->getNumIndices() - conv.getNumYields() == pbParamArgs.size() - numVals && pbParamArgs.size() >= 1); // Assign adjoints for original result. builder.setCurrentDebugScope( remapScope(originalExitBlock->getTerminator()->getDebugScope())); builder.setInsertionPoint(pullbackEntry, getNextFunctionLocalAllocationInsertionPoint()); unsigned seedIndex = 0; unsigned firstSemanticParamResultIdx = conv.getResults().size(); unsigned firstYieldResultIndex = firstSemanticParamResultIdx + conv.getNumAutoDiffSemanticResultParameters(); for (auto resultIndex : getConfig().resultIndices->getIndices()) { // Yields seed buffers are only to be touched in yield BB and required // special handling if (resultIndex >= firstYieldResultIndex) continue; auto origResult = origFormalResults[resultIndex]; auto *seed = pbParamArgs[seedIndex]; if (seed->getType().isAddress()) { // If the seed argument is an `inout` parameter, assign it directly as // the adjoint buffer of the original result. auto seedParamInfo = pullback.getLoweredFunctionType()->getParameters()[seedIndex]; if (seedParamInfo.isIndirectInOut()) { setAdjointBuffer(originalExitBlock, origResult, seed); LLVM_DEBUG(getADDebugStream() << "Assigned seed buffer " << *seed << " as the adjoint of original indirect result " << origResult); } // Otherwise, assign a copy of the seed argument as the adjoint buffer of // the original result. else { auto *seedBufCopy = createFunctionLocalAllocation(seed->getType(), pbLoc); builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake, IsInitialization); setAdjointBuffer(originalExitBlock, origResult, seedBufCopy); LLVM_DEBUG(getADDebugStream() << "Assigned seed buffer " << *seedBufCopy << " as the adjoint of original indirect result " << origResult); } } else { addAdjointValue(originalExitBlock, origResult, makeConcreteAdjointValue(seed), pbLoc); LLVM_DEBUG(getADDebugStream() << "Assigned seed " << *seed << " as the adjoint of original result " << origResult); } ++seedIndex; } // If the original function is an accessor with special-case pullback // generation logic, do special-case generation. bool isSemanticMemberAcc = isSemanticMemberAccessor(&original); if (isSemanticMemberAcc) { if (runForSemanticMemberAccessor()) return true; } // Otherwise, perform standard pullback generation. // Visit original blocks in post-order and perform differentiation // in corresponding pullback blocks. If errors occurred, back out. 
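  // Note: before visiting instructions, adjoints of loop-local active values
  // are explicitly zeroed in the corresponding pullback loop headers (the
  // search below), so that an adjoint accumulated in one iteration of the
  // reversed loop does not carry over into the next iteration.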
else { LLVM_DEBUG(getADDebugStream() << "Begin search for adjoints of loop-local active values\n"); llvm::DenseMap> loopLocalActiveValues; for (auto *bb : originalBlocks) { const SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb); if (loop == nullptr) continue; SILBasicBlock *loopHeader = loop->getHeader(); SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader); LLVM_DEBUG(getADDebugStream() << "Original bb" << bb->getDebugID() << " belongs to a loop, original header bb" << loopHeader->getDebugID() << ", pullback header bb" << pbLoopHeader->getDebugID() << '\n'); builder.setInsertionPoint(pbLoopHeader); auto bbActiveValuesIt = activeValues.find(bb); if (bbActiveValuesIt == activeValues.end()) continue; const auto &bbActiveValues = bbActiveValuesIt->second; for (SILValue bbActiveValue : bbActiveValues) { if (vjpCloner.getLoopInfo()->getLoopFor( bbActiveValue->getParentBlock()) != loop) { LLVM_DEBUG( getADDebugStream() << "The following active value is NOT loop-local, skipping: " << bbActiveValue); continue; } auto [_, wasInserted] = loopLocalActiveValues[loop].insert(bbActiveValue); LLVM_DEBUG(getADDebugStream() << "The following active value is loop-local, "); if (!wasInserted) { LLVM_DEBUG(llvm::dbgs() << "but it was already processed, skipping: " << bbActiveValue); continue; } if (getTangentValueCategory(bbActiveValue) == SILValueCategory::Object) { LLVM_DEBUG(llvm::dbgs() << "zeroing its adjoint value in loop header: " << bbActiveValue); setAdjointValue(bb, bbActiveValue, makeZeroAdjointValue(getRemappedTangentType( bbActiveValue->getType()))); continue; } ASSERT(getTangentValueCategory(bbActiveValue) == SILValueCategory::Address); // getAdjointProjection might call materializeAdjointDirect which // writes to debug output, emit \n. LLVM_DEBUG(llvm::dbgs() << "checking if it's adjoint is a projection\n"); if (!getAdjointProjection(bb, bbActiveValue)) { LLVM_DEBUG(getADDebugStream() << "Adjoint for the following value is NOT a projection, " "zeroing its adjoint buffer in loop header: " << bbActiveValue); // All adjoint buffers are allocated in the pullback entry and // deallocated in the pullback exit. So, use IsNotInitialization to // emit destroy_addr before zeroing the buffer. ASSERT(bufferMap.contains({bb, bbActiveValue})); builder.emitZeroIntoBuffer(pbLoc, getAdjointBuffer(bb, bbActiveValue), IsNotInitialization); continue; } LLVM_DEBUG(getADDebugStream() << "Adjoint for the following value is a projection, "); // If Projection::isAddressProjection(v) is true for a value v, it // is not added to active values list (see recordValueIfActive). // // Ensure that only the following value types conforming to // getAdjointProjection but not conforming to // Projection::isAddressProjection can go here. // // Instructions conforming to Projection::isAddressProjection and // thus never corresponding to an active value do not need any // handling, because only active values can have adjoints from // previous iterations propagated via BB arguments. do { // Consider '%X = begin_access [modify] [static] %Y'. // 1. If %Y is loop-local, it's adjoint buffer will // be zeroed, and we'll have zero adjoint projection to it. // 2. Otherwise, we do not need to zero the projection buffer. // Thus, we can just skip. 
if (dyn_cast(bbActiveValue)) { LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue); break; } // Consider the following sequence: // %1 = function_ref @allocUninitArray // %2 = apply %1(%0) // (%3, %4) = destructure_tuple %2 // %5 = mark_dependence %4 on %3 // %6 = pointer_to_address %6 to [strict] $*Float // Since %6 is active, %3 (which is an array) must also be active. // Thus, adjoint for %3 will be zeroed if needed. Ensure that expected // invariants hold and then skip. if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress( bbActiveValue)) { ASSERT(isa(bbActiveValue)); SILValue arrayValue = getArrayValue(ai); ASSERT(llvm::find(bbActiveValues, arrayValue) != bbActiveValues.end()); ASSERT(vjpCloner.getLoopInfo()->getLoopFor( arrayValue->getParentBlock()) == loop); LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue); break; } ASSERT(false); } while (false); } } LLVM_DEBUG(getADDebugStream() << "End search for adjoints of loop-local active values\n"); for (auto *bb : originalBlocks) { visitSILBasicBlock(bb); if (errorOccurred) return true; } } // Prepare and emit a `return` in the pullback exit block. auto *origEntry = getOriginal().getEntryBlock(); auto *pbExit = getPullbackBlock(origEntry); builder.setCurrentDebugScope(pbExit->back().getDebugScope()); builder.setInsertionPoint(pbExit); // This vector will contain all the materialized return elements. SmallVector retElts; // This vector will contain all indirect parameter adjoint buffers. SmallVector indParamAdjoints; // This vector will identify the locations where initialization is needed. SmallBitVector outputsToInitialize; auto origParams = getOriginal().getArgumentsWithoutIndirectResults(); // Materializes the return element corresponding to the parameter // `parameterIndex` into the `retElts` vector. auto addRetElt = [&](unsigned parameterIndex) -> void { auto origParam = origParams[parameterIndex]; switch (getTangentValueCategory(origParam)) { case SILValueCategory::Object: { auto pbVal = getAdjointValue(origEntry, origParam); auto val = materializeAdjointDirect(pbVal, pbLoc); auto newVal = builder.emitCopyValueOperation(pbLoc, val); retElts.push_back(newVal); break; } case SILValueCategory::Address: { auto adjBuf = getAdjointBuffer(origEntry, origParam); indParamAdjoints.push_back(adjBuf); outputsToInitialize.push_back( !conv.getParameters()[parameterIndex].isIndirectMutating()); break; } } }; SmallVector pullbackIndirectResults( getPullback().getIndirectResults().begin(), getPullback().getIndirectResults().end()); // Collect differentiation parameter adjoints. // Do a first pass to collect non-inout values. for (auto i : getConfig().parameterIndices->getIndices()) { if (!conv.getParameters()[i].isAutoDiffSemanticResult()) { addRetElt(i); } } // Do a second pass for all inout parameters, however this is only necessary // for functions with multiple basic blocks. For functions with a single // basic block adjoint accumulation for those parameters is already done by // per-instruction visitors. if (getOriginal().size() > 1) { const auto &pullbackConv = pullback.getConventions(); SmallVector pullbackInOutArgs; for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults())) { if (pullbackConv.getParameters()[pullbackArg.index()].isAutoDiffSemanticResult()) pullbackInOutArgs.push_back(pullbackArg.value()); } unsigned pullbackInoutArgumentIdx = 0; for (auto i : getConfig().parameterIndices->getIndices()) { // Skip non-inout parameters. 
      if (!conv.getParameters()[i].isAutoDiffSemanticResult())
        continue;
      // For functions with multiple basic blocks, accumulation is needed
      // for `inout` parameters because pullback basic blocks have different
      // adjoint buffers.
      pullbackIndirectResults.push_back(pullbackInOutArgs[pullbackInoutArgumentIdx++]);
      addRetElt(i);
    }
  }

  // Copy them to adjoint indirect results.
  assert(indParamAdjoints.size() == pullbackIndirectResults.size() &&
         "Indirect parameter adjoint count mismatch");
  unsigned currentIndex = 0;
  for (auto pair : zip(indParamAdjoints, pullbackIndirectResults)) {
    auto source = std::get<0>(pair);
    auto *dest = std::get<1>(pair);
    if (outputsToInitialize[currentIndex]) {
      builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
    } else {
      builder.createCopyAddr(pbLoc, source, dest, IsTake, IsNotInitialization);
    }
    currentIndex++;
    // Prevent source buffer from being deallocated, since the underlying
    // value is moved.
    destroyedLocalAllocations.insert(source);
  }

  // Emit cleanups for all local values.
  cleanUpTemporariesForBlock(pbExit, pbLoc);
  // Deallocate local allocations.
  for (auto alloc : functionLocalAllocations) {
    // Assert that local allocations have at least one use.
    // Buffers should not be allocated needlessly.
    assert(!alloc->use_empty());
    if (!destroyedLocalAllocations.count(alloc)) {
      builder.emitDestroyAddrAndFold(pbLoc, alloc);
      destroyedLocalAllocations.insert(alloc);
    }
    builder.createDeallocStack(pbLoc, alloc);
  }
  builder.createReturn(pbLoc, joinElements(retElts, builder, pbLoc));

#ifndef NDEBUG
  bool leakFound = false;
  // Ensure all temporaries have been cleaned up.
  for (auto &bb : pullback) {
    for (auto temp : blockTemporaries[&bb]) {
      if (blockTemporaries[&bb].count(temp)) {
        leakFound = true;
        getADDebugStream() << "Found leaked temporary:\n" << temp;
      }
    }
  }
  // Ensure all enum adjoint copies have been cleaned up.
  for (const auto &enumData : enumDataAdjCopies) {
    leakFound = true;
    getADDebugStream() << "Found leaked temporary:\n" << enumData.second;
  }
  // Ensure all local allocations have been cleaned up.
  for (auto localAlloc : functionLocalAllocations) {
    if (!destroyedLocalAllocations.count(localAlloc)) {
      leakFound = true;
      getADDebugStream() << "Found leaked local buffer:\n" << localAlloc;
    }
  }
  assert(!leakFound && "Leaks found!");
#endif

  LLVM_DEBUG(getADDebugStream()
             << "Generated "
             << (isSemanticMemberAcc ? "semantic member accessor" : "normal")
             << " pullback for " << original.getName() << ":\n"
             << pullback);
  return errorOccurred;
}

void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult(
    SILValue origNonvariedResult) {
  auto &pullback = getPullback();
  auto pbLoc = getPullback().getLocation();
  /*
  // TODO(TF-788): Re-enable non-varied result warning.
  // Emit fixit if original non-varied result has a valid source location.
  auto startLoc = origNonvariedResult.getLoc().getStartSourceLoc();
  auto endLoc = origNonvariedResult.getLoc().getEndSourceLoc();
  if (startLoc.isValid() && endLoc.isValid()) {
    getContext().diagnose(startLoc, diag::autodiff_nonvaried_result_fixit)
        .fixItInsert(startLoc, "withoutDerivative(at:")
        .fixItInsertAfter(endLoc, ")");
  }
  */
  LLVM_DEBUG(getADDebugStream() << getOriginal().getName()
                                << " has non-varied result, returning zero"
                                   " for all pullback results\n");
  auto *pullbackEntry = pullback.createBasicBlock();
  createEntryArguments(&pullback);
  builder.setCurrentDebugScope(
      remapScope(originalExitBlock->getTerminator()->getDebugScope()));
  builder.setInsertionPoint(pullbackEntry);
  // Destroy all owned arguments.
for (auto *arg : pullbackEntry->getArguments()) if (arg->getOwnershipKind() == OwnershipKind::Owned) builder.emitDestroyOperation(pbLoc, arg); // Return zero for each result. SmallVector directResults; auto indirectResultIt = pullback.getIndirectResults().begin(); for (auto resultInfo : pullback.getLoweredFunctionType()->getResults()) { auto resultType = pullback.mapTypeIntoContext(resultInfo.getInterfaceType()) ->getCanonicalType(); if (resultInfo.isFormalDirect()) directResults.push_back(builder.emitZero(pbLoc, resultType)); else builder.emitZeroIntoBuffer(pbLoc, *indirectResultIt++, IsInitialization); } builder.createReturn(pbLoc, joinElements(directResults, builder, pbLoc)); LLVM_DEBUG(getADDebugStream() << "Generated pullback for " << getOriginal().getName() << ":\n" << pullback); } AllocStackInst *PullbackCloner::Implementation::createOptionalAdjoint( SILBasicBlock *bb, SILValue wrappedAdjoint, SILType optionalTy) { auto pbLoc = getPullback().getLocation(); // `Optional` optionalTy = remapType(optionalTy); assert(optionalTy.getASTType()->isOptional()); // `T` auto wrappedType = optionalTy.getOptionalObjectType(); // `T.TangentVector` auto wrappedTanType = remapType(wrappedAdjoint->getType()); // `Optional` auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType); // `Optional.TangentVector` auto optionalTanTy = getRemappedTangentType(optionalTy); // Look up the `Optional.TangentVector.init` declaration. ConstructorDecl *constructorDecl = getASTContext().getOptionalTanInitDecl(optionalTanTy.getASTType()); // Allocate a local buffer for the `Optional` adjoint value. auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy); // Find `Optional.some` EnumElementDecl. auto someEltDecl = builder.getASTContext().getOptionalSomeDecl(); // Initialize an `Optional` buffer from `wrappedAdjoint` as // the input for `Optional.TangentVector.init`. auto *optArgBuf = builder.createAllocStack(pbLoc, optionalOfWrappedTanType); if (optionalOfWrappedTanType.isObject()) { // %enum = enum $Optional, #Optional.some!enumelt, // %wrappedAdjoint : $T auto *enumInst = builder.createEnum(pbLoc, wrappedAdjoint, someEltDecl, optionalOfWrappedTanType); // store %enum to %optArgBuf builder.emitStoreValueOperation(pbLoc, enumInst, optArgBuf, StoreOwnershipQualifier::Init); } else { // %enumAddr = init_enum_data_addr %optArgBuf $Optional, // #Optional.some!enumelt auto *enumAddr = builder.createInitEnumDataAddr( pbLoc, optArgBuf, someEltDecl, wrappedTanType.getAddressType()); // copy_addr %wrappedAdjoint to [init] %enumAddr builder.createCopyAddr(pbLoc, wrappedAdjoint, enumAddr, IsNotTake, IsInitialization); // inject_enum_addr %optArgBuf : $*Optional, // #Optional.some!enumelt builder.createInjectEnumAddr(pbLoc, optArgBuf, someEltDecl); } // Apply `Optional.TangentVector.init`. 
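  // Conceptually, this performs roughly
  //   optTanAdjBuf <- Optional<T>.TangentVector(.some(wrappedAdjoint))
  // by calling the initializer directly, substituting `T`'s `Differentiable`
  // conformance into its generic signature.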
SILOptFunctionBuilder fb(getContext().getTransform()); // %init_fn = function_ref @Optional.TangentVector.init auto *initFn = fb.getOrCreateFunction(pbLoc, SILDeclRef(constructorDecl), NotForDefinition); auto *initFnRef = builder.createFunctionRef(pbLoc, initFn); auto *diffProto = builder.getASTContext().getProtocol(KnownProtocolKind::Differentiable); auto diffConf = lookupConformance(wrappedType.getASTType(), diffProto); assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); auto subMap = SubstitutionMap::get( initFn->getLoweredFunctionType()->getSubstGenericSignature(), ArrayRef(wrappedType.getASTType()), {diffConf}); // %metatype = metatype $Optional.TangentVector.Type auto metatypeType = CanMetatypeType::get(optionalTanTy.getASTType(), MetatypeRepresentation::Thin); auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType); auto metatype = builder.createMetatype(pbLoc, metatypeSILType); // apply %init_fn(%optTanAdjBuf, %optArgBuf, %metatype) builder.createApply(pbLoc, initFnRef, subMap, {optTanAdjBuf, optArgBuf, metatype}); builder.createDeallocStack(pbLoc, optArgBuf); return optTanAdjBuf; } // Accumulate adjoint for the incoming `Optional` buffer. void PullbackCloner::Implementation::accumulateAdjointForOptionalBuffer( SILBasicBlock *bb, SILValue optionalBuffer, SILValue wrappedAdjoint) { assert(getTangentValueCategory(optionalBuffer) == SILValueCategory::Address); auto pbLoc = getPullback().getLocation(); // Allocate and initialize Optional.TangentVector from // Wrapped.TangentVector AllocStackInst *optTanAdjBuf = createOptionalAdjoint(bb, wrappedAdjoint, optionalBuffer->getType()); // Accumulate into optionalBuffer addToAdjointBuffer(bb, optionalBuffer, optTanAdjBuf, pbLoc); builder.emitDestroyAddr(pbLoc, optTanAdjBuf); builder.createDeallocStack(pbLoc, optTanAdjBuf); } // Accumulate adjoint for the incoming `Optional` value. void PullbackCloner::Implementation::accumulateAdjointValueForOptional( SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) { assert(getTangentValueCategory(optionalValue) == SILValueCategory::Object); auto pbLoc = getPullback().getLocation(); // Allocate and initialize Optional.TangentVector from // Wrapped.TangentVector AllocStackInst *optTanAdjBuf = createOptionalAdjoint(bb, wrappedAdjoint, optionalValue->getType()); auto optTanAdjVal = builder.emitLoadValueOperation( pbLoc, optTanAdjBuf, LoadOwnershipQualifier::Take); recordTemporary(optTanAdjVal); builder.createDeallocStack(pbLoc, optTanAdjBuf); addAdjointValue(bb, optionalValue, makeConcreteAdjointValue(optTanAdjVal), pbLoc); } SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor( SILBasicBlock *origBB, SILBasicBlock *origPredBB, SmallDenseMap &pullbackTrampolineBlockMap) { // Get the pullback block and optional pullback trampoline block of the // predecessor block. auto *pullbackBB = getPullbackBlock(origPredBB); auto *pullbackTrampolineBB = getPullbackTrampolineBlock(origPredBB, origBB); // If the predecessor block does not have a corresponding pullback // trampoline block, then the pullback successor is the pullback block. if (!pullbackTrampolineBB) return pullbackBB; // Otherwise, the pullback successor is the pullback trampoline block, // which branches to the pullback block and propagates adjoint values of // active values. 
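  // Rough shape: for an original edge `pred -> bb`, the pullback block of `bb`
  // ends in a `switch_enum` over the branching trace; the case for `pred`
  // jumps to this trampoline, which forwards the adjoints of `pred`'s active
  // values together with `pred`'s pullback tuple to the pullback block of
  // `pred`.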
assert(pullbackTrampolineBB->getNumArguments() == 1); auto loc = origBB->getParent()->getLocation(); SmallVector trampolineArguments; // Propagate adjoint values/buffers of active values/buffers to // predecessor blocks. auto &predBBActiveValues = activeValues[origPredBB]; llvm::SmallSet, 32> propagatedAdjoints; for (auto activeValue : predBBActiveValues) { LLVM_DEBUG(getADDebugStream() << "Propagating adjoint of active value " << activeValue << "from bb" << origBB->getDebugID() << " to predecessors' (bb" << origPredBB->getDebugID() << ") pullback blocks\n"); switch (getTangentValueCategory(activeValue)) { case SILValueCategory::Object: { auto activeValueAdj = getAdjointValue(origBB, activeValue); auto concreteActiveValueAdj = materializeAdjointDirect(activeValueAdj, loc); if (!pullbackTrampolineBlockMap.count(concreteActiveValueAdj)) { concreteActiveValueAdj = builder.emitCopyValueOperation(loc, concreteActiveValueAdj); setAdjointValue(origBB, activeValue, makeConcreteAdjointValue(concreteActiveValueAdj)); } auto insertion = pullbackTrampolineBlockMap.try_emplace( concreteActiveValueAdj, TrampolineBlockSet()); auto &blockSet = insertion.first->getSecond(); blockSet.insert(pullbackTrampolineBB); trampolineArguments.push_back(concreteActiveValueAdj); // If the pullback block does not yet have a registered adjoint // value for the active value, set the adjoint value to the // forwarded adjoint value argument. // TODO: Hoist this logic out of loop over predecessor blocks to // remove the `hasAdjointValue` check. if (!hasAdjointValue(origPredBB, activeValue)) { auto *pullbackBBArg = getActiveValuePullbackBlockArgument(origPredBB, activeValue); auto forwardedArgAdj = makeConcreteAdjointValue(pullbackBBArg); setAdjointValue(origPredBB, activeValue, forwardedArgAdj); } break; } case SILValueCategory::Address: { // Propagate adjoint buffers using `copy_addr`. auto adjBuf = getAdjointBuffer(origBB, activeValue); auto predAdjBuf = getAdjointBuffer(origPredBB, activeValue); if (propagatedAdjoints.insert({adjBuf, predAdjBuf}).second) builder.createCopyAddr(loc, adjBuf, predAdjBuf, IsNotTake, IsNotInitialization); break; } } } // Propagate pullback struct argument. TangentBuilder pullbackTrampolineBBBuilder( pullbackTrampolineBB, getContext()); pullbackTrampolineBBBuilder.setCurrentDebugScope( remapScope(origPredBB->getTerminator()->getDebugScope())); auto *pullbackTrampolineBBArg = pullbackTrampolineBB->getArguments().front(); if (vjpCloner.getLoopInfo()->getLoopFor(origPredBB)) { assert(pullbackTrampolineBBArg->getType() == SILType::getRawPointerType(getASTContext())); auto pbTupleType = remapType(getPullbackInfo().getLinearMapTupleLoweredType(origPredBB)); auto predPbTupleAddr = pullbackTrampolineBBBuilder.createPointerToAddress( loc, pullbackTrampolineBBArg, pbTupleType.getAddressType(), /*isStrict*/ true); auto predPbStructVal = pullbackTrampolineBBBuilder.createLoad( loc, predPbTupleAddr, pbTupleType.isTrivial(getPullback()) ? LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Copy); trampolineArguments.push_back(predPbStructVal); } else { trampolineArguments.push_back(pullbackTrampolineBBArg); } // Branch from pullback trampoline block to pullback block. pullbackTrampolineBBBuilder.createBranch(loc, pullbackBB, trampolineArguments); return pullbackTrampolineBB; } void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { auto pbLoc = getPullback().getLocation(); // Get the corresponding pullback basic block. 
auto *pbBB = getPullbackBlock(bb); builder.setInsertionPoint(pbBB); LLVM_DEBUG({ auto &s = getADDebugStream() << "Original bb" + std::to_string(bb->getDebugID()) << ": To differentiate or not to differentiate?\n"; for (auto &inst : llvm::reverse(*bb)) { s << (getPullbackInfo().shouldDifferentiateInstruction(&inst) ? "[x] " : "[ ] ") << inst; } }); // Visit each instruction in reverse order. for (auto &inst : llvm::reverse(*bb)) { if (!getPullbackInfo().shouldDifferentiateInstruction(&inst)) continue; // Differentiate instruction. builder.setCurrentDebugScope(remapScope(inst.getDebugScope())); visit(&inst); if (errorOccurred) return; } // Emit a branching terminator for the block. // If the original block is the original entry, then the pullback block is // the pullback exit. This is handled specially in // `PullbackCloner::Implementation::run()`, so we leave the block // non-terminated. if (bb->isEntry()) return; // If the original block is a resume yield destination, then we need to yield // the adjoint buffer and do everything else in the resume destination. Unwind // destination is unreachable as the co-routine can never be aborted. if (auto *predBB = bb->getSinglePredecessorBlock()) { if (auto *yield = dyn_cast(predBB->getTerminator())) { auto *resumeBB = pbBB->split(builder.getInsertionPoint()); auto *unwindBB = getPullback().createBasicBlock(); SmallVector adjYields; for (auto yieldedVal : yield->getYieldedValues()) adjYields.push_back(getAdjointBuffer(bb, yieldedVal)); builder.createYield(yield->getLoc(), adjYields, resumeBB, unwindBB); builder.setInsertionPoint(unwindBB); builder.createUnreachable(SILLocation::invalid()); pbBB = resumeBB; builder.setInsertionPoint(pbBB); } } // Otherwise, add a `switch_enum` terminator for non-exit // pullback blocks. // 1. Get the pullback struct pullback block argument. // 2. Extract the predecessor enum value from the pullback struct value. auto *predEnum = getPullbackInfo().getBranchingTraceDecl(bb); (void)predEnum; auto predEnumVal = getPullbackPredTupleElement(bb); // Propagate adjoint values from active basic block arguments to // incoming values (predecessor terminator operands). for (auto *bbArg : bb->getArguments()) { if (!getActivityInfo().isActive(bbArg, getConfig())) continue; LLVM_DEBUG(getADDebugStream() << "Propagating adjoint value for active bb" << bb->getDebugID() << " argument: " << *bbArg); // Get predecessor terminator operands. SmallVector, 4> incomingValues; if (bbArg->getSingleTerminatorOperands(incomingValues)) { // Returns true if the given terminator instruction is a `switch_enum` on // an `Optional`-typed value. `switch_enum` instructions require // special-case adjoint value propagation for the operand. auto isSwitchEnumInstOnOptional = [&ctx = getASTContext()](TermInst *termInst) { if (!termInst) return false; if (auto *sei = dyn_cast(termInst)) { auto operandTy = sei->getOperand()->getType(); return operandTy.getASTType()->isOptional(); } return false; }; // Check the tangent value category of the active basic block argument. switch (getTangentValueCategory(bbArg)) { // If argument has a loadable tangent value category: materialize adjoint // value of the argument, create a copy, and set the copy as the adjoint // value of incoming values. 
case SILValueCategory::Object: { auto bbArgAdj = getAdjointValue(bb, bbArg); auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc); auto concreteBBArgAdjCopy = builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj); for (auto pair : incomingValues) { auto *predBB = std::get<0>(pair); auto incomingValue = std::get<1>(pair); // Handle `switch_enum` on `Optional`. auto termInst = bbArg->getSingleTerminator(); if (isSwitchEnumInstOnOptional(termInst)) { accumulateAdjointValueForOptional(bb, incomingValue, concreteBBArgAdjCopy); } else { blockTemporaries[getPullbackBlock(predBB)].insert( concreteBBArgAdjCopy); addAdjointValue(predBB, incomingValue, makeConcreteAdjointValue(concreteBBArgAdjCopy), pbLoc); } } break; } // If argument has an address tangent value category: materialize adjoint // value of the argument, create a copy, and set the copy as the adjoint // value of incoming values. case SILValueCategory::Address: { auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg); for (auto pair : incomingValues) { auto incomingValue = std::get<1>(pair); // Handle `switch_enum` on `Optional`. auto termInst = bbArg->getSingleTerminator(); if (isSwitchEnumInstOnOptional(termInst)) accumulateAdjointForOptionalBuffer(bb, incomingValue, bbArgAdjBuf); else addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc); } break; } } } else { LLVM_DEBUG(getADDebugStream() << "do not know how to handle this incoming bb argument"); if (auto term = bbArg->getSingleTerminator()) { getContext().emitNondifferentiabilityError(term, getInvoker(), diag::autodiff_expression_not_differentiable_note); } else { // This will be a bit confusing, but still better than nothing. getContext().emitNondifferentiabilityError(bbArg, getInvoker(), diag::autodiff_expression_not_differentiable_note); } errorOccurred = true; return; } } // 3. Build the pullback successor cases for the `switch_enum` // instruction. The pullback successors correspond to the predecessors // of the current original block. SmallVector, 4> pullbackSuccessorCases; // A map from active values' adjoint values to the trampoline blocks that // are using them. SmallDenseMap pullbackTrampolineBlockMap; SmallDenseMap origPredpullbackSuccBBMap; for (auto *predBB : bb->getPredecessorBlocks()) { // If there is no linear map tuple for predecessor BB, then BB is // unreachable from function entry. There is no branch tracing enum for it // as well, so we should not create any branching to it in the pullback. if (!getPullbackInfo().getLinearMapTupleType(predBB)) continue; auto *pullbackSuccBB = buildPullbackSuccessor(bb, predBB, pullbackTrampolineBlockMap); origPredpullbackSuccBBMap[predBB] = pullbackSuccBB; auto *enumEltDecl = getPullbackInfo().lookUpBranchingTraceEnumElement(predBB, bb); pullbackSuccessorCases.emplace_back(enumEltDecl, pullbackSuccBB); } // Values are trampolined by only a subset of pullback successor blocks. // Other successors blocks should destroy the value. for (auto pair : pullbackTrampolineBlockMap) { auto value = pair.getFirst(); // The set of trampoline BBs that are users of `value`. auto &userTrampolineBBSet = pair.getSecond(); // For each pullback successor block that does not trampoline the value, // release the value. 
for (auto origPredPbSuccPair : origPredpullbackSuccBBMap) { auto *origPred = origPredPbSuccPair.getFirst(); auto *pbSucc = origPredPbSuccPair.getSecond(); if (userTrampolineBBSet.count(pbSucc)) continue; TangentBuilder pullbackSuccBuilder(pbSucc->begin(), getContext()); pullbackSuccBuilder.setCurrentDebugScope( remapScope(origPred->getTerminator()->getDebugScope())); pullbackSuccBuilder.emitDestroyValueOperation(pbLoc, value); } } // Emit cleanups for all block-local temporaries. cleanUpTemporariesForBlock(pbBB, pbLoc); // Branch to pullback successor blocks. assert(pullbackSuccessorCases.size() == predEnum->getNumElements()); builder.createSwitchEnum(pbLoc, predEnumVal, /*DefaultBB*/ nullptr, pullbackSuccessorCases, std::nullopt, ProfileCounter(), OwnershipKind::Owned); } //--------------------------------------------------------------------------// // Member accessor pullback generation //--------------------------------------------------------------------------// bool PullbackCloner::Implementation::runForSemanticMemberAccessor() { auto &original = getOriginal(); auto *accessor = cast(original.getDeclContext()->getAsDecl()); switch (accessor->getAccessorKind()) { case AccessorKind::Get: case AccessorKind::DistributedGet: return runForSemanticMemberGetter(); case AccessorKind::Set: return runForSemanticMemberSetter(); case AccessorKind::Modify: return runForSemanticMemberModify(); default: llvm_unreachable("Unsupported accessor kind; inconsistent with " "`isSemanticMemberAccessor`?"); } } bool PullbackCloner::Implementation::runForSemanticMemberGetter() { auto &original = getOriginal(); auto &pullback = getPullback(); auto pbLoc = getPullback().getLocation(); auto *accessor = cast(original.getDeclContext()->getAsDecl()); assert(accessor->getAccessorKind() == AccessorKind::Get); auto *origEntry = original.getEntryBlock(); auto *pbEntry = pullback.getEntryBlock(); builder.setCurrentDebugScope( remapScope(origEntry->getScopeOfFirstNonMetaInstruction())); builder.setInsertionPoint(pbEntry); // Get getter argument and result values. // Getter type: $(Self) -> Result // Pullback type: $(Result') -> Self' assert(original.getLoweredFunctionType()->getNumParameters() == 1); assert(pullback.getLoweredFunctionType()->getNumParameters() == 1); assert(pullback.getLoweredFunctionType()->getNumResults() == 1); SILValue origSelf = original.getArgumentsWithoutIndirectResults().front(); SmallVector origFormalResults; collectAllFormalResultsInTypeOrder(original, origFormalResults); assert(getConfig().resultIndices->getNumIndices() == 1 && "Getter should have one semantic result"); auto origResult = origFormalResults[*getConfig().resultIndices->begin()]; auto tangentVectorSILTy = pullback.getConventions().getResults().front() .getSILStorageType(getModule(), pullback.getLoweredFunctionType(), TypeExpansionContext::minimal()); auto tangentVectorTy = tangentVectorSILTy.getASTType(); auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); // Look up the corresponding field in the tangent space. auto *origField = cast(accessor->getStorage()); auto baseType = remapType(origSelf->getType()).getASTType(); auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, pbLoc, getInvoker()); if (!tanField) { errorOccurred = true; return true; } // Switch based on the base tangent struct's value category. 
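  // A loadable `Self.TangentVector` is built as an aggregate adjoint value
  // with the result's adjoint in the accessed field's position and zeros
  // elsewhere; an address-only `Self.TangentVector` is materialized
  // field-by-field directly into the pullback's indirect result.
  // Illustrative example (assuming a synthesized `TangentVector`): for
  // `struct Point: Differentiable { var x, y: Float }`, the pullback of the
  // getter for `x` maps a seed `dx` to `Point.TangentVector(x: dx, y: 0)`.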
  switch (getTangentValueCategory(origSelf)) {
  case SILValueCategory::Object: {
    auto adjResult = getAdjointValue(origEntry, origResult);
    switch (adjResult.getKind()) {
    case AdjointValueKind::Zero:
      addAdjointValue(origEntry, origSelf,
                      makeZeroAdjointValue(tangentVectorSILTy), pbLoc);
      break;
    case AdjointValueKind::Concrete:
    case AdjointValueKind::Aggregate: {
      SmallVector<AdjointValue, 8> eltVals;
      for (auto *field : tangentVectorDecl->getStoredProperties()) {
        if (field == tanField) {
          eltVals.push_back(adjResult);
        } else {
          auto substMap = tangentVectorTy->getMemberSubstitutionMap(field);
          auto fieldTy = field->getInterfaceType().subst(substMap);
          auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
          assert(fieldSILTy.isObject());
          eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
        }
      }
      addAdjointValue(origEntry, origSelf,
                      makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
                      pbLoc);
      break;
    }
    case AdjointValueKind::AddElement:
      llvm_unreachable("Adjoint of an aggregate type's field cannot be of kind "
                       "`AddElement`");
    }
    break;
  }
  case SILValueCategory::Address: {
    assert(pullback.getIndirectResults().size() == 1);
    auto pbIndRes = pullback.getIndirectResults().front();
    auto *adjSelf = createFunctionLocalAllocation(
        pbIndRes->getType().getObjectType(), pbLoc);
    setAdjointBuffer(origEntry, origSelf, adjSelf);
    for (auto *field : tangentVectorDecl->getStoredProperties()) {
      auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, field);
      // Non-tangent fields get a zero.
      if (field != tanField) {
        builder.emitZeroIntoBuffer(pbLoc, adjSelfElt, IsInitialization);
        continue;
      }
      // Switch based on the property's value category.
      switch (getTangentValueCategory(origResult)) {
      case SILValueCategory::Object: {
        auto adjResult = getAdjointValue(origEntry, origResult);
        auto adjResultValue = materializeAdjointDirect(adjResult, pbLoc);
        auto adjResultValueCopy =
            builder.emitCopyValueOperation(pbLoc, adjResultValue);
        builder.emitStoreValueOperation(pbLoc, adjResultValueCopy, adjSelfElt,
                                        StoreOwnershipQualifier::Init);
        break;
      }
      case SILValueCategory::Address: {
        auto adjResult = getAdjointBuffer(origEntry, origResult);
        builder.createCopyAddr(pbLoc, adjResult, adjSelfElt, IsTake,
                               IsInitialization);
        destroyedLocalAllocations.insert(adjResult);
        break;
      }
      }
    }
    break;
  }
  }
  return false;
}

bool PullbackCloner::Implementation::runForSemanticMemberSetter() {
  auto &original = getOriginal();
  auto &pullback = getPullback();
  auto pbLoc = getPullback().getLocation();

  auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
  assert(accessor->getAccessorKind() == AccessorKind::Set);

  auto *origEntry = original.getEntryBlock();
  auto *pbEntry = pullback.getEntryBlock();
  builder.setCurrentDebugScope(
      remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
  builder.setInsertionPoint(pbEntry);

  // Get setter argument values.
  //              Setter type: $(inout Self, Argument) -> ()
  // Pullback type (wrt self): $(inout Self') -> ()
  // Pullback type (wrt both): $(inout Self') -> Argument'
  assert(original.getLoweredFunctionType()->getNumParameters() == 2);
  assert(pullback.getLoweredFunctionType()->getNumParameters() == 1);
  assert(pullback.getLoweredFunctionType()->getNumResults() == 0 ||
         pullback.getLoweredFunctionType()->getNumResults() == 1);

  SILValue origArg = original.getArgumentsWithoutIndirectResults()[0];
  SILValue origSelf = original.getArgumentsWithoutIndirectResults()[1];
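  // As an illustrative example (hypothetical type, for intuition only): given
  //   struct S: Differentiable { var a: Float }
  // the pullback of the `a` setter takes `inout S.TangentVector`, reads the
  // current adjoint of `self.a` as the adjoint of the assigned value, and then
  // resets the adjoint of `self.a` to zero, because the setter overwrites the
  // field's previous value and that value therefore contributes nothing.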
  // Look up the corresponding field in the tangent space.
  auto *origField = cast<VarDecl>(accessor->getStorage());
  auto baseType = remapType(origSelf->getType()).getASTType();
  auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
                                            pbLoc, getInvoker());
  if (!tanField) {
    errorOccurred = true;
    return true;
  }

  auto adjSelf = getAdjointBuffer(origEntry, origSelf);
  auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
  // Switch based on the property's value category.
  switch (getTangentValueCategory(origArg)) {
  case SILValueCategory::Object: {
    auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt,
                                                 LoadOwnershipQualifier::Take);
    setAdjointValue(origEntry, origArg, makeConcreteAdjointValue(adjArg));
    blockTemporaries[pbEntry].insert(adjArg);
    break;
  }
  case SILValueCategory::Address: {
    addToAdjointBuffer(origEntry, origArg, adjSelfElt, pbLoc);
    builder.emitDestroyOperation(pbLoc, adjSelfElt);
    break;
  }
  }
  builder.emitZeroIntoBuffer(pbLoc, adjSelfElt, IsInitialization);

  return false;
}

bool PullbackCloner::Implementation::runForSemanticMemberModify() {
  auto &original = getOriginal();
  auto &pullback = getPullback();
  auto pbLoc = getPullback().getLocation();

  auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl());
  assert(accessor->getAccessorKind() == AccessorKind::Modify);

  auto *origEntry = original.getEntryBlock();
  // We assume that the accessor has a simple three-block structure: a yield in
  // the entry block plus resume and unwind blocks.
  auto *yi = cast<YieldInst>(origEntry->getTerminator());
  auto *origResumeBB = yi->getResumeBB();

  auto *pbEntry = pullback.getEntryBlock();
  builder.setCurrentDebugScope(
      remapScope(origEntry->getScopeOfFirstNonMetaInstruction()));
  builder.setInsertionPoint(pbEntry);

  // Get _modify accessor argument values.
  // Accessor type: $(inout Self) -> @yields @inout Argument
  // Pullback type: $(inout Self', linear map tuple) -> @yields @inout Argument'
  //
  // Normally, pullbacks for semantic member accessors are single-block and
  // therefore have an empty linear map tuple. However, coroutines have
  // branching control flow due to the possible coroutine abort, so we need to
  // accommodate this. We keep branch tracing enums in order not to special
  // case this in many other places. As there is no way to return to the
  // coroutine via the abort exit, we essentially "linearize" the coroutine.
  auto loweredFnTy = original.getLoweredFunctionType();
  auto pullbackLoweredFnTy = pullback.getLoweredFunctionType();

  assert(loweredFnTy->getNumParameters() == 1 &&
         loweredFnTy->getNumYields() == 1);
  assert(pullbackLoweredFnTy->getNumParameters() == 2);
  assert(pullbackLoweredFnTy->getNumYields() == 1);

  SILValue origSelf = original.getArgumentsWithoutIndirectResults().front();

  SmallVector<SILValue, 8> origFormalResults;
  collectAllFormalResultsInTypeOrder(original, origFormalResults);
  assert(getConfig().resultIndices->getNumIndices() == 2 &&
         "Modify accessor should have two semantic results");
  auto origYield =
      origFormalResults[*std::next(getConfig().resultIndices->begin())];

  // Look up the corresponding field in the tangent space.
  auto *origField = cast<VarDecl>(accessor->getStorage());
  auto baseType = remapType(origSelf->getType()).getASTType();
  auto *tanField = getTangentStoredProperty(getContext(), origField, baseType,
                                            pbLoc, getInvoker());
  if (!tanField) {
    errorOccurred = true;
    return true;
  }

  auto adjSelf = getAdjointBuffer(origResumeBB, origSelf);
  auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField);
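  // Roughly, the emitted pullback is itself a yield-once coroutine of the
  // following shape (an illustrative sketch, not verbatim SIL output):
  //
  //   bb0(%self_adj : $*Self.TangentVector, %pb_tuple : $(...)):
  //     %elt_adj = struct_element_addr %self_adj, #field
  //     yield %elt_adj : $*Field.TangentVector, resume bb1, unwind bb2
  //   bb1:  // resume: accumulate into the adjoint buffer of `self`
  //     ...
  //   bb2:  // unwind: never taken, the coroutine is never aborted
  //     unreachable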
  // Modify accessors have inout yields and therefore should yield addresses.
  assert(getTangentValueCategory(origYield) == SILValueCategory::Address &&
         "Modify accessors should yield indirect");

  // Yield the adjoint buffer and do everything else in the resume destination.
  // The unwind destination is unreachable as the coroutine can never be
  // aborted.
  auto *unwindBB = getPullback().createBasicBlock();
  auto *resumeBB = getPullbackBlock(origEntry);
  builder.createYield(yi->getLoc(), {adjSelfElt}, resumeBB, unwindBB);
  builder.setInsertionPoint(unwindBB);
  builder.createUnreachable(SILLocation::invalid());
  builder.setInsertionPoint(resumeBB);
  addToAdjointBuffer(origEntry, origSelf, adjSelf, pbLoc);

  return false;
}

//--------------------------------------------------------------------------//
// Adjoint buffer mapping
//--------------------------------------------------------------------------//

SILValue PullbackCloner::Implementation::getAdjointProjection(
    SILBasicBlock *origBB, SILValue originalProjection) {
  // Handle `struct_element_addr`.
  // Adjoint projection: a `struct_element_addr` into the base adjoint buffer.
  if (auto *seai = dyn_cast<StructElementAddrInst>(originalProjection)) {
    assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
           "`@noDerivative` struct projections should never be active");
    auto adjSource = getAdjointBuffer(origBB, seai->getOperand());
    auto structType = remapType(seai->getOperand()->getType()).getASTType();
    auto *tanField =
        getTangentStoredProperty(getContext(), seai, structType, getInvoker());
    assert(tanField && "Invalid projections should have been diagnosed");
    return builder.createStructElementAddr(seai->getLoc(), adjSource, tanField);
  }
  // Handle `tuple_element_addr`.
  // Adjoint projection: a `tuple_element_addr` into the base adjoint buffer.
  if (auto *teai = dyn_cast<TupleElementAddrInst>(originalProjection)) {
    auto source = teai->getOperand();
    auto adjSource = getAdjointBuffer(origBB, source);
    if (!adjSource->getType().is<TupleType>())
      return adjSource;
    auto origTupleTy = remapType(source->getType()).castTo<TupleType>();
    unsigned adjIndex = 0;
    for (unsigned i : range(teai->getFieldIndex())) {
      if (getTangentSpace(
              origTupleTy->getElement(i).getType()->getCanonicalType()))
        ++adjIndex;
    }
    return builder.createTupleElementAddr(teai->getLoc(), adjSource, adjIndex);
  }
  // Handle `ref_element_addr`.
  // Adjoint projection: a local allocation initialized with the corresponding
  // field value from the class's base adjoint value.
  if (auto *reai = dyn_cast<RefElementAddrInst>(originalProjection)) {
    assert(!reai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
           "`@noDerivative` class projections should never be active");
    auto loc = reai->getLoc();
    // Get the class operand, stripping `begin_borrow`.
    auto classOperand = stripBorrow(reai->getOperand());
    auto classType = remapType(reai->getOperand()->getType()).getASTType();
    auto *tanField =
        getTangentStoredProperty(getContext(), reai->getField(), classType,
                                 reai->getLoc(), getInvoker());
    assert(tanField && "Invalid projections should have been diagnosed");
    // Create a local allocation for the element adjoint buffer.
    auto eltTanType = tanField->getValueInterfaceType()->getCanonicalType();
    auto eltTanSILType =
        remapType(SILType::getPrimitiveAddressType(eltTanType));
    auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc);
    // Check the class operand's `TangentVector` value category.
    switch (getTangentValueCategory(classOperand)) {
    case SILValueCategory::Object: {
      // Get the class operand's adjoint value. Currently, it must be a
      // `TangentVector` struct.
      auto adjClass =
          materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc);
      builder.emitScopedBorrowOperation(
          loc, adjClass, [&](SILValue borrowedAdjClass) {
            // Initialize the element adjoint buffer with the base adjoint
            // value.
            auto *adjElt =
                builder.createStructExtract(loc, borrowedAdjClass, tanField);
            auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt);
            builder.emitStoreValueOperation(loc, adjEltCopy, eltAdjBuffer,
                                            StoreOwnershipQualifier::Init);
          });
      return eltAdjBuffer;
    }
    case SILValueCategory::Address: {
      // Get the class operand's adjoint buffer. Currently, it must be a
      // `TangentVector` struct.
      auto adjClass = getAdjointBuffer(origBB, classOperand);
      // Initialize the element adjoint buffer with the base adjoint buffer.
      auto *adjElt = builder.createStructElementAddr(loc, adjClass, tanField);
      builder.createCopyAddr(loc, adjElt, eltAdjBuffer, IsNotTake,
                             IsInitialization);
      return eltAdjBuffer;
    }
    }
  }
  // Handle `begin_access`.
  // Adjoint projection: the base adjoint buffer itself.
  if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) {
    auto adjBase = getAdjointBuffer(origBB, bai->getOperand());
    if (errorOccurred)
      return (bufferMap[{origBB, originalProjection}] = SILValue());
    // Return the base buffer's adjoint buffer.
    return adjBase;
  }
  // Handle `array.uninitialized_intrinsic` application element addresses.
  // Adjoint projection: a local allocation initialized by applying
  // `Array.TangentVector.subscript` to the base array's adjoint value.
  auto *ai =
      getAllocateUninitializedArrayIntrinsicElementAddress(originalProjection);
  auto *definingInst = dyn_cast_or_null<SingleValueInstruction>(
      originalProjection->getDefiningInstruction());
  bool isAllocateUninitializedArrayIntrinsicElementAddress =
      ai && definingInst &&
      (isa<PointerToAddressInst>(definingInst) ||
       isa<IndexAddrInst>(definingInst));
  if (isAllocateUninitializedArrayIntrinsicElementAddress) {
    // Get the array element index of the result address.
    int eltIndex = 0;
    if (auto *iai = dyn_cast<IndexAddrInst>(definingInst)) {
      auto *ili = cast<IntegerLiteralInst>(iai->getIndex());
      eltIndex = ili->getValue().getLimitedValue();
    }
    // Get the array adjoint value.
    SILValue arrayValue = getArrayValue(ai);
    SILValue arrayAdjoint = materializeAdjointDirect(
        getAdjointValue(origBB, arrayValue), definingInst->getLoc());
    // Apply `Array.TangentVector.subscript` to get the array element adjoint
    // value.
    auto *eltAdjBuffer =
        getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc());
    return eltAdjBuffer;
  }
  return SILValue();
}

//----------------------------------------------------------------------------//
// Adjoint value accumulation
//----------------------------------------------------------------------------//

AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
    AdjointValue lhs, AdjointValue rhs, SILLocation loc) {
  LLVM_DEBUG(getADDebugStream() << "Accumulating adjoint directly.\nLHS: "
                                << lhs << "\nRHS: " << rhs << '\n');
  switch (lhs.getKind()) {
  // x
  case AdjointValueKind::Concrete: {
    auto lhsVal = lhs.getConcreteValue();
    switch (rhs.getKind()) {
    // x + y
    case AdjointValueKind::Concrete: {
      auto rhsVal = rhs.getConcreteValue();
      auto sum = recordTemporary(builder.emitAdd(loc, lhsVal, rhsVal));
      return makeConcreteAdjointValue(sum);
    }
    // x + 0 => x
    case AdjointValueKind::Zero:
      return lhs;
    // x + (y, z) => (x.0 + y, x.1 + z)
    case AdjointValueKind::Aggregate: {
      SmallVector<AdjointValue, 8> newElements;
      auto lhsTy = lhsVal->getType().getASTType();
      auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal);
      if (lhsTy->is<TupleType>()) {
        auto elts = builder.createDestructureTuple(loc, lhsValCopy);
        llvm::for_each(elts->getResults(),
                       [this](SILValue result) { recordTemporary(result); });
        for (auto i : indices(elts->getResults())) {
          auto rhsElt = rhs.getAggregateElement(i);
          newElements.push_back(accumulateAdjointsDirect(
              makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
        }
      } else if (lhsTy->getStructOrBoundGenericStruct()) {
        auto elts =
            builder.createDestructureStruct(lhsVal.getLoc(), lhsValCopy);
        llvm::for_each(elts->getResults(),
                       [this](SILValue result) { recordTemporary(result); });
        for (unsigned i : indices(elts->getResults())) {
          auto rhsElt = rhs.getAggregateElement(i);
          newElements.push_back(accumulateAdjointsDirect(
              makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
        }
      } else {
        llvm_unreachable("Not an aggregate type");
      }
      return makeAggregateAdjointValue(lhsVal->getType(), newElements);
    }
    // x + (baseAdjoint, index, eltToAdd) => (x + baseAdjoint, index, eltToAdd)
    case AdjointValueKind::AddElement: {
      auto *addElementValue = rhs.getAddElementValue();
      auto baseAdjoint = addElementValue->baseAdjoint;
      auto eltToAdd = addElementValue->eltToAdd;
      auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
      return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
                                        addElementValue->fieldLocator);
    }
    }
  }
  // 0
  case AdjointValueKind::Zero:
    // 0 + x => x
    return rhs;
  // (x, y)
  case AdjointValueKind::Aggregate: {
    switch (rhs.getKind()) {
    // (x, y) + z => (z.0 + x, z.1 + y)
    case AdjointValueKind::Concrete:
      return accumulateAdjointsDirect(rhs, lhs, loc);
    // x + 0 => x
    case AdjointValueKind::Zero:
      return lhs;
    // (x, y) + (z, w) => (x + z, y + w)
    case AdjointValueKind::Aggregate: {
      SmallVector<AdjointValue, 8> newElements;
      for (auto i : range(lhs.getNumAggregateElements()))
        newElements.push_back(accumulateAdjointsDirect(
            lhs.getAggregateElement(i), rhs.getAggregateElement(i), loc));
      return makeAggregateAdjointValue(lhs.getType(), newElements);
    }
    // (x.0, ..., x.n) + (baseAdjoint, index, eltToAdd)
    //   => (x + baseAdjoint, index, eltToAdd)
    case AdjointValueKind::AddElement: {
      auto *addElementValue = rhs.getAddElementValue();
      auto baseAdjoint = addElementValue->baseAdjoint;
      auto eltToAdd = addElementValue->eltToAdd;
      auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
      return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
                                        addElementValue->fieldLocator);
    }
    }
  }
  // (baseAdjoint, index, eltToAdd)
  case AdjointValueKind::AddElement: {
    switch (rhs.getKind()) {
    case AdjointValueKind::Zero:
      return lhs;
    // (baseAdjoint, index, eltToAdd) + x => (x + baseAdjoint, index, eltToAdd)
    case AdjointValueKind::Concrete:
    // (baseAdjoint, index, eltToAdd) + (x.0, ..., x.n)
    //   => (x + baseAdjoint, index, eltToAdd)
    case AdjointValueKind::Aggregate:
      return accumulateAdjointsDirect(rhs, lhs, loc);
    // (baseAdjoint1, index1, eltToAdd1) + (baseAdjoint2, index2, eltToAdd2)
    //   => ((baseAdjoint1 + baseAdjoint2, index1, eltToAdd1), index2, eltToAdd2)
    case AdjointValueKind::AddElement: {
      auto *addElementValueLhs = lhs.getAddElementValue();
      auto baseAdjointLhs = addElementValueLhs->baseAdjoint;
      auto eltToAddLhs = addElementValueLhs->eltToAdd;

      auto *addElementValueRhs = rhs.getAddElementValue();
      auto baseAdjointRhs = addElementValueRhs->baseAdjoint;
      auto eltToAddRhs = addElementValueRhs->eltToAdd;

      auto sumOfBaseAdjoints =
          accumulateAdjointsDirect(baseAdjointLhs, baseAdjointRhs, loc);
      auto newBaseAdjoint = makeAddElementAdjointValue(
          sumOfBaseAdjoints, eltToAddLhs, addElementValueLhs->fieldLocator);

      return makeAddElementAdjointValue(newBaseAdjoint, eltToAddRhs,
                                        addElementValueRhs->fieldLocator);
    }
    }
  }
  }
  llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
}

//----------------------------------------------------------------------------//
// Array literal initialization differentiation
//----------------------------------------------------------------------------//

void PullbackCloner::Implementation::
    accumulateArrayLiteralElementAddressAdjoints(
        SILBasicBlock *origBB, SILValue originalValue,
        AdjointValue arrayAdjointValue, SILLocation loc) {
  // Return if the original value is not the `Array` result of an
  // `array.uninitialized_intrinsic` application.
  auto *dti = dyn_cast_or_null<DestructureTupleInst>(
      originalValue->getDefiningInstruction());
  if (!dti)
    return;
  if (!ArraySemanticsCall(dti->getOperand(),
                          semantics::ARRAY_UNINITIALIZED_INTRINSIC))
    return;
  if (originalValue != dti->getResult(0))
    return;
  // Accumulate the array's adjoint value into the adjoint buffers of its
  // element addresses: `pointer_to_address` and (optionally) `index_addr`
  // instructions.
  //
  // The input code looks as follows:
  //   %17 = integer_literal $Builtin.Word, 1
  //   // function_ref _allocateUninitializedArray<A>(_:)
  //   %18 = function_ref @$ss27_allocateUninitializedArrayySayxG_BptBwlF : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
  //   %19 = apply %18<Float>(%17) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
  //   (%20, %21) = destructure_tuple %19
  //   %22 = mark_dependence %21 on %20
  //   %23 = pointer_to_address %22 to [strict] $*Float
  //   store %0 to [trivial] %23
  //   // function_ref _finalizeUninitializedArray<A>(_:)
  //   %25 = function_ref @$ss27_finalizeUninitializedArrayySayxGABnlF : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0>
  //   %26 = apply %25<Float>(%20) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> // user: %27
  //
  // Note that %20 and %21 are, in some sense, "aliases" of each other. Here,
  // our `originalValue` is %20 in the code above. We need to trace from %21
  // down to %23 and propagate the (decomposed) adjoint of `originalValue` to
  // the adjoint of %23. The generic adjoint propagation code then does its job
  // and propagates %23' to %0'. If multiple values are being initialized,
  // there are additional `index_addr` instructions, but the handling is
  // similar.
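  //
  // For intuition, a (hypothetical) Swift-level example: for
  // `let a: [Float] = [x, y]`, the adjoint of `a` is an
  // `Array<Float>.TangentVector`; its elements are extracted one by one via
  // `Array.TangentVector.subscript` and accumulated into the adjoint buffers
  // of the element addresses, from where the generic propagation code carries
  // them back to the adjoints of `x` and `y`.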
  LLVM_DEBUG(getADDebugStream()
             << "Accumulating adjoint value for array literal into element "
                "address adjoint buffers"
             << originalValue);
  auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc);
  builder.setCurrentDebugScope(remapScope(dti->getDebugScope()));
  for (auto use : dti->getResult(1)->getUses()) {
    auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser());
    assert(mdi && "Expected mark_dependence user");
    auto *ptai =
        dyn_cast_or_null<PointerToAddressInst>(getSingleNonDebugUser(mdi));
    assert(ptai && "Expected pointer_to_address user");
    auto adjBuf = getAdjointBuffer(origBB, ptai);
    auto *eltAdjBuf = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc);
    builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
    for (auto use : ptai->getUses()) {
      if (auto *iai = dyn_cast<IndexAddrInst>(use->getUser())) {
        auto *ili = cast<IntegerLiteralInst>(iai->getIndex());
        auto eltIndex = ili->getValue().getLimitedValue();
        auto adjBuf = getAdjointBuffer(origBB, iai);
        auto *eltAdjBuf =
            getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, loc);
        builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);
      }
    }
  }
}

AllocStackInst *PullbackCloner::Implementation::getArrayAdjointElementBuffer(
    SILValue arrayAdjoint, int eltIndex, SILLocation loc) {
  auto &ctx = builder.getASTContext();
  auto arrayTanType = cast<StructType>(arrayAdjoint->getType().getASTType());
  auto arrayType = arrayTanType->getParent()->castTo<BoundGenericStructType>();
  auto eltTanType = arrayType->getGenericArgs().front()->getCanonicalType();
  auto eltTanSILType = remapType(SILType::getPrimitiveAddressType(eltTanType));
  // Get `function_ref` and generic signature of
  // `Array.TangentVector.subscript.getter`.
  auto *arrayTanStructDecl = arrayTanType->getStructOrBoundGenericStruct();
  auto subscriptLookup =
      arrayTanStructDecl->lookupDirect(DeclBaseName::createSubscript());
  SubscriptDecl *subscriptDecl = nullptr;
  for (auto *candidate : subscriptLookup) {
    auto candidateModule = candidate->getModuleContext();
    if (candidateModule->getName() == ctx.Id_Differentiation ||
        candidateModule->isStdlibModule()) {
      assert(!subscriptDecl && "Multiple `Array.TangentVector.subscript`s");
      subscriptDecl = cast<SubscriptDecl>(candidate);
#ifdef NDEBUG
      break;
#endif
    }
  }
  assert(subscriptDecl && "No `Array.TangentVector.subscript`");
  auto *subscriptGetterDecl =
      subscriptDecl->getOpaqueAccessor(AccessorKind::Get);
  assert(subscriptGetterDecl && "No `Array.TangentVector.subscript` getter");
  SILOptFunctionBuilder fb(getContext().getTransform());
  auto *subscriptGetterFn = fb.getOrCreateFunction(
      loc, SILDeclRef(subscriptGetterDecl), NotForDefinition);
  // %subscript_fn = function_ref @Array.TangentVector.subscript.getter
  auto *subscriptFnRef = builder.createFunctionRef(loc, subscriptGetterFn);
  auto subscriptFnGenSig =
      subscriptGetterFn->getLoweredFunctionType()->getSubstGenericSignature();
  // Apply `Array.TangentVector.subscript.getter` to get the array element
  // adjoint buffer.
  // %index_literal = integer_literal $Builtin.IntXX, <index>
  auto builtinIntType =
      SILType::getPrimitiveObjectType(ctx.getIntDecl()
                                          ->getStoredProperties()
                                          .front()
                                          ->getInterfaceType()
                                          ->getCanonicalType());
  auto *eltIndexLiteral =
      builder.createIntegerLiteral(loc, builtinIntType, eltIndex);
  auto intType =
      SILType::getPrimitiveObjectType(ctx.getIntType()->getCanonicalType());
  // %index_int = struct $Int (%index_literal)
  auto *eltIndexInt = builder.createStruct(loc, intType, {eltIndexLiteral});
  auto *diffProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
  auto diffConf = lookupConformance(eltTanType, diffProto);
  assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`");
  auto *addArithProto = ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
  auto addArithConf = lookupConformance(eltTanType, addArithProto);
  assert(!addArithConf.isInvalid() &&
         "Missing conformance to `AdditiveArithmetic`");
  auto subMap = SubstitutionMap::get(subscriptFnGenSig, {eltTanType},
                                     {addArithConf, diffConf});
  // %elt_adj = alloc_stack $T.TangentVector
  // Create and register a local allocation.
  auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILType, loc,
                                                     /*zeroInitialize*/ true);
  // Immediately destroy the emitted zero value.
  // NOTE: It is not efficient to emit a zero value and then immediately
  // destroy it. However, this was the easiest way to avoid "lifetime mismatch
  // in predecessors" memory lifetime verification errors for control flow
  // differentiation. Perhaps we can avoid emitting a zero value if local
  // allocations are created per pullback basic block instead of all in the
  // pullback entry: TF-1075.
  builder.emitDestroyOperation(loc, eltAdjBuffer);
  // apply %subscript_fn(%elt_adj, %index_int, %array_adj)
  builder.createApply(loc, subscriptFnRef, subMap,
                      {eltAdjBuffer, eltIndexInt, arrayAdjoint});
  return eltAdjBuffer;
}

} // end namespace autodiff
} // end namespace swift