//===--- VJPCloner.cpp - VJP function generation --------------*- C++ -*---===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file defines a helper class for generating VJP functions for automatic // differentiation. // //===----------------------------------------------------------------------===// #define DEBUG_TYPE "differentiation" #include "swift/AST/Types.h" #include "swift/Basic/Assertions.h" #include "swift/SILOptimizer/Differentiation/VJPCloner.h" #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" #include "swift/SILOptimizer/Differentiation/PullbackCloner.h" #include "swift/SILOptimizer/Differentiation/Thunk.h" #include "swift/SIL/TerminatorUtils.h" #include "swift/SIL/TypeSubstCloner.h" #include "swift/SILOptimizer/Analysis/LoopAnalysis.h" #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" #include "swift/SILOptimizer/Utils/CFGOptUtils.h" #include "swift/SILOptimizer/Utils/DifferentiationMangler.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" #include "llvm/ADT/DenseMap.h" namespace swift { namespace autodiff { class VJPCloner::Implementation final : public TypeSubstCloner { friend class VJPCloner; friend class PullbackCloner; /// The parent VJP cloner. VJPCloner &cloner; /// The global context. ADContext &context; /// The original function. SILFunction *const original; /// The differentiability witness. SILDifferentiabilityWitness *const witness; /// The VJP function. SILFunction *const vjp; /// The pullback function. SILFunction *pullback; /// The differentiation invoker. DifferentiationInvoker invoker; /// Info from activity analysis on the original function. const DifferentiableActivityInfo &activityInfo; /// The loop info. SILLoopInfo *loopInfo; /// The linear map info. LinearMapInfo pullbackInfo; /// Caches basic blocks whose phi arguments have been remapped (adding a /// predecessor enum argument). SmallPtrSet remappedBasicBlocks; /// The `AutoDiffLinearMapContext` object. If null, no explicit context is /// needed (no loops). SILValue pullbackContextValue; /// The unique, borrowed context object. This is valid until the exit block. SILValue borrowedPullbackContextValue; /// The generic signature of the `Builtin.autoDiffAllocateSubcontext(_:_:)` /// declaration. It is used for creating a builtin call. GenericSignature builtinAutoDiffAllocateSubcontextGenericSignature; bool errorOccurred = false; /// Mapping from original blocks to pullback values. Used to build pullback /// struct instances. llvm::DenseMap> pullbackValues; ASTContext &getASTContext() const { return vjp->getASTContext(); } SILModule &getModule() const { return vjp->getModule(); } const AutoDiffConfig &getConfig() const { return witness->getConfig(); } Implementation(VJPCloner &parent, ADContext &context, SILDifferentiabilityWitness *witness, SILFunction *vjp, DifferentiationInvoker invoker); /// Creates an empty pullback function, to be filled in by `PullbackCloner`. SILFunction *createEmptyPullback(); /// Run VJP generation. Returns true on error. bool run(); /// Initializes a context object if needed. void emitLinearMapContextInitializationIfNeeded() { if (!pullbackInfo.hasHeapAllocatedContext()) return; // Get linear map struct size. auto *returnBB = &*original->findReturnBB(); auto pullbackTupleType = remapASTType(pullbackInfo.getLinearMapTupleType(returnBB)->getCanonicalType()); Builder.setInsertionPoint(vjp->getEntryBlock()); auto pbTupleMetatypeType = CanMetatypeType::get(pullbackTupleType, MetatypeRepresentation::Thick); auto pbTupleMetatypeSILType = SILType::getPrimitiveObjectType(pbTupleMetatypeType); auto pbTupleMetatype = Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType); // Create an context. pullbackContextValue = Builder.createBuiltin( original->getLocation(), getASTContext().getIdentifier(getBuiltinName( BuiltinValueKind::AutoDiffCreateLinearMapContextWithType)), SILType::getNativeObjectType(getASTContext()), SubstitutionMap(), {pbTupleMetatype}); borrowedPullbackContextValue = Builder.createBeginBorrow( original->getLocation(), pullbackContextValue); LLVM_DEBUG(getADDebugStream() << "Context object initialized because there are loops\n" << *vjp->getEntryBlock() << '\n' << "pullback tuple type: " << pullbackTupleType << '\n'); } /// Get the lowered SIL type of the given AST type. SILType getLoweredType(Type type) { auto vjpGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature(); Lowering::AbstractionPattern pattern(vjpGenSig, type->getReducedType(vjpGenSig)); return vjp->getLoweredType(pattern, type); } SILType getPullbackType() { auto vjpFuncTy = vjp->getLoweredFunctionType(); const auto &conv = vjp->getConventions(); return conv.getSILType(vjpFuncTy->getResults().back(), vjp->getTypeExpansionContext()); } GenericSignature getBuiltinAutoDiffAllocateSubcontextDecl() { if (builtinAutoDiffAllocateSubcontextGenericSignature) return builtinAutoDiffAllocateSubcontextGenericSignature; auto &ctx = getASTContext(); auto *decl = cast(getBuiltinValueDecl( ctx, ctx.getIdentifier(getBuiltinName( BuiltinValueKind::AutoDiffAllocateSubcontextWithType)))); builtinAutoDiffAllocateSubcontextGenericSignature = decl->getGenericSignature(); assert(builtinAutoDiffAllocateSubcontextGenericSignature); return builtinAutoDiffAllocateSubcontextGenericSignature; } // Creates a trampoline block for given original terminator instruction, the // pullback struct value for its parent block, and a successor basic block. // // The trampoline block has the same arguments as and branches to the remapped // successor block, but drops the last predecessor enum argument. // // Used for cloning branching terminator instructions with specific // requirements on successor block arguments, where an additional predecessor // enum argument is not acceptable. SILBasicBlock *createTrampolineBasicBlock(TermInst *termInst, TupleInst *pbTupleVal, SILBasicBlock *succBB); /// Build a pullback tuple value for the given original terminator /// instruction. TupleInst *buildPullbackValueTupleValue(TermInst *termInst); llvm::SmallVector getPullbackValues(SILBasicBlock *origBB); /// Build a predecessor enum instance using the given builder for the given /// original predecessor/successor blocks and pullback struct value. EnumInst *buildPredecessorEnumValue(SILBuilder &builder, SILBasicBlock *predBB, SILBasicBlock *succBB, SILValue pbTupleVal); public: /// Remap original basic blocks, adding predecessor enum arguments. SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) { auto *vjpBB = BBMap[bb]; // If error has occurred, or if block has already been remapped, return // remapped, return remapped block. if (errorOccurred || remappedBasicBlocks.count(bb)) return vjpBB; // Add predecessor enum argument to the remapped block. auto *predEnum = pullbackInfo.getBranchingTraceDecl(bb); auto enumTy = getOpASTType(predEnum->getDeclaredInterfaceType()->getCanonicalType()); auto enumLoweredTy = context.getTypeConverter().getLoweredType( enumTy, TypeExpansionContext::minimal()); vjpBB->createPhiArgument(enumLoweredTy, OwnershipKind::Owned); remappedBasicBlocks.insert(bb); return vjpBB; } /// General visitor for all instructions. If any error is emitted by previous /// visits, bail out. void visit(SILInstruction *inst) { if (errorOccurred) return; TypeSubstCloner::visit(inst); } void visitSILInstruction(SILInstruction *inst) { context.emitNondifferentiabilityError( inst, invoker, diag::autodiff_expression_not_differentiable_note); errorOccurred = true; } void postProcess(SILInstruction *orig, SILInstruction *cloned) { if (errorOccurred) return; SILClonerWithScopes::postProcess(orig, cloned); } void visitReturnInst(ReturnInst *ri) { Builder.setCurrentDebugScope(getOpScope(ri->getDebugScope())); auto loc = ri->getOperand().getLoc(); // Build pullback tuple value for original block. auto *origExit = ri->getParent(); // Get the value in the VJP corresponding to the original result. auto *origRetInst = cast(origExit->getTerminator()); auto origResult = getOpValue(origRetInst->getOperand()); SmallVector origResults; extractAllElements(origResult, Builder, origResults); // Get and partially apply the pullback. auto vjpSubstMap = vjp->getForwardingSubstitutionMap(); auto *pullbackRef = Builder.createFunctionRef(loc, pullback); // Prepare partial application arguments. SILValue partialApplyArg; PartialApplyInst *pullbackPartialApply; if (borrowedPullbackContextValue) { auto *pbTupleVal = buildPullbackValueTupleValue(ri); // Initialize the top-level subcontext buffer with the top-level pullback // tuple. auto addr = emitProjectTopLevelSubcontext( Builder, loc, borrowedPullbackContextValue, pbTupleVal->getType()); Builder.createStore( loc, pbTupleVal, addr, pbTupleVal->getType().isTrivial(*pullback) ? StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init); Builder.createEndBorrow(loc, borrowedPullbackContextValue); pullbackPartialApply = Builder.createPartialApply( loc, pullbackRef, vjpSubstMap, {pullbackContextValue}, ParameterConvention::Direct_Guaranteed); } else { pullbackPartialApply = Builder.createPartialApply( loc, pullbackRef, vjpSubstMap, getPullbackValues(origExit), ParameterConvention::Direct_Guaranteed); } auto pullbackType = vjp->mapTypeIntoContext(getPullbackType()); auto pullbackFnType = pullbackType.castTo(); auto pullbackSubstType = pullbackPartialApply->getType().castTo(); // If necessary, convert the pullback value to the returned pullback // function type. SILValue pullbackValue; if (pullbackSubstType == pullbackFnType) { pullbackValue = pullbackPartialApply; } else if (pullbackSubstType->isABICompatibleWith(pullbackFnType, *vjp) .isCompatible()) { pullbackValue = Builder.createConvertFunction(loc, pullbackPartialApply, pullbackType, /*withoutActuallyEscaping*/ false); } else { llvm::report_fatal_error("Pullback value type is not ABI-compatible " "with the returned pullback type"); } // Return a tuple of the original result and pullback. SmallVector directResults; directResults.append(origResults.begin(), origResults.end()); directResults.push_back(pullbackValue); Builder.createReturn(ri->getLoc(), joinElements(directResults, Builder, loc)); } void visitUnwindInst(UnwindInst *ui) { Builder.setCurrentDebugScope(getOpScope(ui->getDebugScope())); auto loc = ui->getLoc(); auto *origExit = ui->getParent(); // Consume unused pullback values if (borrowedPullbackContextValue) { auto *pbTupleVal = buildPullbackValueTupleValue(ui); // Initialize the top-level subcontext buffer with the top-level pullback // tuple. auto addr = emitProjectTopLevelSubcontext( Builder, loc, borrowedPullbackContextValue, pbTupleVal->getType()); Builder.createStore( loc, pbTupleVal, addr, pbTupleVal->getType().isTrivial(*pullback) ? StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init); Builder.createEndBorrow(loc, borrowedPullbackContextValue); Builder.emitDestroyValueOperation(loc, pullbackContextValue); } else { for (SILValue val : getPullbackValues(origExit)) Builder.emitDestroyValueOperation(loc, val); } Builder.createUnwind(loc); } void visitBranchInst(BranchInst *bi) { Builder.setCurrentDebugScope(getOpScope(bi->getDebugScope())); // Build pullback struct value for original block. // Build predecessor enum value for destination block. auto *origBB = bi->getParent(); auto *pbTupleVal = buildPullbackValueTupleValue(bi); auto *enumVal = buildPredecessorEnumValue(getBuilder(), origBB, bi->getDestBB(), pbTupleVal); // Remap arguments, appending the new enum values. SmallVector args; for (auto origArg : bi->getArgs()) args.push_back(getOpValue(origArg)); args.push_back(enumVal); // Create a new `br` instruction. getBuilder().createBranch(bi->getLoc(), getOpBasicBlock(bi->getDestBB()), args); } void visitYieldInst(YieldInst *yi) { Builder.setCurrentDebugScope(getOpScope(yi->getDebugScope())); // Build pullback struct value for original block. auto *pbTupleVal = buildPullbackValueTupleValue(yi); // Create a new `yield` instruction. Note that resume / unwind blocks cannot // have arguments, so we're building trampolines with branch tracing enum // values. getBuilder().createYield( yi->getLoc(), getOpValueArray<1>(yi->getOperandValues()), createTrampolineBasicBlock(yi, pbTupleVal, yi->getResumeBB()), createTrampolineBasicBlock(yi, pbTupleVal, yi->getUnwindBB())); } void visitCondBranchInst(CondBranchInst *cbi) { Builder.setCurrentDebugScope(getOpScope(cbi->getDebugScope())); // Build pullback struct value for original block. auto *pbTupleVal = buildPullbackValueTupleValue(cbi); // Create a new `cond_br` instruction. getBuilder().createCondBranch( cbi->getLoc(), getOpValue(cbi->getCondition()), createTrampolineBasicBlock(cbi, pbTupleVal, cbi->getTrueBB()), createTrampolineBasicBlock(cbi, pbTupleVal, cbi->getFalseBB())); } void visitSwitchEnumTermInst(SwitchEnumTermInst inst) { Builder.setCurrentDebugScope(getOpScope(inst->getDebugScope())); // Build pullback tuple value for original block. auto *pbTupleVal = buildPullbackValueTupleValue(*inst); // Create trampoline successor basic blocks. SmallVector, 4> caseBBs; for (unsigned i : range(inst.getNumCases())) { auto caseBB = inst.getCase(i); auto *trampolineBB = createTrampolineBasicBlock(inst, pbTupleVal, caseBB.second); caseBBs.push_back({caseBB.first, trampolineBB}); } // Create trampoline default basic block. SILBasicBlock *newDefaultBB = nullptr; if (auto *defaultBB = inst.getDefaultBBOrNull().getPtrOrNull()) newDefaultBB = createTrampolineBasicBlock(inst, pbTupleVal, defaultBB); // Create a new `switch_enum` instruction. switch (inst->getKind()) { case SILInstructionKind::SwitchEnumInst: getBuilder().createSwitchEnum( inst->getLoc(), getOpValue(inst.getOperand()), newDefaultBB, caseBBs); break; case SILInstructionKind::SwitchEnumAddrInst: getBuilder().createSwitchEnumAddr( inst->getLoc(), getOpValue(inst.getOperand()), newDefaultBB, caseBBs); break; default: llvm_unreachable("Expected `switch_enum` or `switch_enum_addr`"); } } void visitSwitchEnumInst(SwitchEnumInst *sei) { visitSwitchEnumTermInst(sei); } void visitSwitchEnumAddrInst(SwitchEnumAddrInst *seai) { visitSwitchEnumTermInst(seai); } void visitCheckedCastBranchInst(CheckedCastBranchInst *ccbi) { Builder.setCurrentDebugScope(getOpScope(ccbi->getDebugScope())); // Build pullback struct value for original block. auto *pbTupleVal = buildPullbackValueTupleValue(ccbi); // Create a new `checked_cast_branch` instruction. getBuilder().createCheckedCastBranch( ccbi->getLoc(), ccbi->isExact(), ccbi->getCheckedCastOptions(), getOpValue(ccbi->getOperand()), getOpASTType(ccbi->getSourceFormalType()), getOpType(ccbi->getTargetLoweredType()), getOpASTType(ccbi->getTargetFormalType()), createTrampolineBasicBlock(ccbi, pbTupleVal, ccbi->getSuccessBB()), createTrampolineBasicBlock(ccbi, pbTupleVal, ccbi->getFailureBB()), ccbi->getTrueBBCount(), ccbi->getFalseBBCount()); } void visitCheckedCastAddrBranchInst(CheckedCastAddrBranchInst *ccabi) { Builder.setCurrentDebugScope(getOpScope(ccabi->getDebugScope())); // Build pullback struct value for original block. auto *pbTupleVal = buildPullbackValueTupleValue(ccabi); // Create a new `checked_cast_addr_branch` instruction. getBuilder().createCheckedCastAddrBranch( ccabi->getLoc(), ccabi->getCheckedCastOptions(), ccabi->getConsumptionKind(), getOpValue(ccabi->getSrc()), getOpASTType(ccabi->getSourceFormalType()), getOpValue(ccabi->getDest()), getOpASTType(ccabi->getTargetFormalType()), createTrampolineBasicBlock(ccabi, pbTupleVal, ccabi->getSuccessBB()), createTrampolineBasicBlock(ccabi, pbTupleVal, ccabi->getFailureBB()), ccabi->getTrueBBCount(), ccabi->getFalseBBCount()); } void visitEndApplyInst(EndApplyInst *eai) { BeginApplyInst *bai = eai->getBeginApply(); // If callee should not be differentiated, do standard cloning. if (!pullbackInfo.shouldDifferentiateApplySite(bai)) { LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *bai << '\n'); TypeSubstCloner::visitEndApplyInst(eai); return; } Builder.setCurrentDebugScope(getOpScope(eai->getDebugScope())); auto loc = eai->getLoc(); auto &builder = getBuilder(); auto token = getMappedValue(bai->getTokenResult()); LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *eai << '\n'); FullApplySite fai(token->getDefiningInstruction()); auto vjpResult = builder.createEndApply(loc, token, fai.getType()); LLVM_DEBUG(getADDebugStream() << "Created end_apply\n" << *vjpResult); builder.emitDestroyValueOperation(loc, fai.getCallee()); // Checkpoint the pullback. SmallVector vjpDirectResults; extractAllElements(vjpResult, getBuilder(), vjpDirectResults); ArrayRef originalDirectResults = ArrayRef(vjpDirectResults).drop_back(1); SILValue originalDirectResult = joinElements(originalDirectResults, getBuilder(), loc); SILValue pullback = vjpDirectResults.back(); { auto pullbackFnType = pullback->getType().castTo(); auto pullbackUnsubstFnType = pullbackFnType->getUnsubstitutedType(getModule()); if (pullbackFnType != pullbackUnsubstFnType) { pullback = builder.createConvertFunction( loc, pullback, SILType::getPrimitiveObjectType(pullbackUnsubstFnType), /*withoutActuallyEscaping*/ false); } } // Store the original result to the value map. mapValue(eai, originalDirectResult); auto pullbackType = pullbackInfo.lookUpLinearMapType(bai); // If actual pullback type does not match lowered pullback type, reabstract // the pullback using a thunk. auto actualPullbackType = getOpType(pullback->getType()).getAs(); auto loweredPullbackType = getOpType(getLoweredType(pullbackType)).castTo(); auto applyInfoIt = context.getNestedApplyInfo().find(bai); assert(applyInfoIt != context.getNestedApplyInfo().end()); if (!loweredPullbackType->isEqual(actualPullbackType)) { // Set non-reabstracted original pullback type in nested apply info. applyInfoIt->second.originalPullbackType = actualPullbackType; SILOptFunctionBuilder fb(context.getTransform()); pullback = reabstractCoroutine( getBuilder(), fb, loc, pullback, loweredPullbackType, [this](SubstitutionMap subs) -> SubstitutionMap { return this->getOpSubstitutionMap(subs); }); } unsigned pullbackIdx = applyInfoIt->second.pullbackIdx; pullbackValues[bai->getParent()][pullbackIdx] = pullback; // Some instructions that produce the callee may have been cloned. // If the original callee did not have any users beyond this `apply`, // recursively kill the cloned callee. if (auto *origCallee = cast_or_null( bai->getCallee()->getDefiningInstruction())) if (origCallee->hasOneUse()) recursivelyDeleteTriviallyDeadInstructions( getOpValue(origCallee)->getDefiningInstruction()); } // Check and diagnose non-differentiable original function type. bool diagnoseNondifferentiableOriginalFunctionType(CanSILFunctionType originalFnTy, FullApplySite fai, SILValue origCallee, const AutoDiffConfig &config) const { // Check and diagnose non-differentiable arguments. for (auto paramIndex : config.parameterIndices->getIndices()) { if (!originalFnTy->getParameters()[paramIndex] .getSILStorageInterfaceType() .isDifferentiable(getModule())) { auto arg = fai.getArgumentsWithoutIndirectResults()[paramIndex]; // FIXME: This shouldn't be necessary and might indicate a bug in // the transformation. RegularLocation nonAutoGenLoc(arg.getLoc()); nonAutoGenLoc.markNonAutoGenerated(); auto startLoc = nonAutoGenLoc.getStartSourceLoc(); auto endLoc = nonAutoGenLoc.getEndSourceLoc(); context.emitNondifferentiabilityError( arg, invoker, diag::autodiff_nondifferentiable_argument) .fixItInsert(startLoc, "withoutDerivative(at: ") .fixItInsertAfter(endLoc, ")"); return true; } } // Check and diagnose non-differentiable results. unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults(); unsigned firstYieldResultIndex = originalFnTy->getNumResults() + originalFnTy->getNumAutoDiffSemanticResultsParameters(); for (auto resultIndex : config.resultIndices->getIndices()) { SILType remappedResultType; if (resultIndex >= firstYieldResultIndex) { auto yieldResultIdx = resultIndex - firstYieldResultIndex; const auto& yield = originalFnTy->getYields()[yieldResultIdx]; // We do not have a good way to differentiate direct yields if (yield.isAutoDiffSemanticResult()) remappedResultType = yield.getSILStorageInterfaceType(); else { context.emitNondifferentiabilityError( origCallee, invoker, diag::autodiff_cannot_differentiate_through_direct_yield); return true; } } else if (resultIndex >= firstSemanticParamResultIdx) { auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx; auto semanticResultArg = *std::next(fai.getAutoDiffSemanticResultArguments().begin(), semanticResultArgIdx); remappedResultType = semanticResultArg->getType(); } else { remappedResultType = originalFnTy->getResults()[resultIndex] .getSILStorageInterfaceType(); } if (!remappedResultType || !remappedResultType.isDifferentiable(getModule())) { auto startLoc = fai.getLoc().getStartSourceLoc(); auto endLoc = fai.getLoc().getEndSourceLoc(); context.emitNondifferentiabilityError( origCallee, invoker, diag::autodiff_nondifferentiable_result) .fixItInsert(startLoc, "withoutDerivative(at: ") .fixItInsertAfter(endLoc, ")"); return true; } } return false; } void visitBeginApplyInst(BeginApplyInst *bai) { // If callee should not be differentiated, do standard cloning. if (!pullbackInfo.shouldDifferentiateApplySite(bai)) { LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *bai << '\n'); TypeSubstCloner::visitBeginApplyInst(bai); return; } Builder.setCurrentDebugScope(getOpScope(bai->getDebugScope())); auto loc = bai->getLoc(); auto &builder = getBuilder(); auto origCallee = getOpValue(bai->getCallee()); auto originalFnTy = origCallee->getType().castTo(); LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *bai << '\n'); SmallVector allResults; SmallVector activeParamIndices; SmallVector activeResultIndices; collectMinimalIndicesForFunctionCall(bai, getConfig(), activityInfo, allResults, activeParamIndices, activeResultIndices); assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); assert(!activeResultIndices.empty() && "Result indices cannot be empty"); LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=("; llvm::interleave( activeParamIndices.begin(), activeParamIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); s << "), results=("; llvm::interleave( activeResultIndices.begin(), activeResultIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); s << ")\n";); // Form expected indices. AutoDiffConfig config( IndexSubset::get(getASTContext(), bai->getArgumentsWithoutIndirectResults().size(), activeParamIndices), IndexSubset::get(getASTContext(), bai->getSubstCalleeType()->getNumAutoDiffSemanticResults(), activeResultIndices)); if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy, bai, origCallee, config)) { errorOccurred = true; return; } // Emit the VJP. SILValue vjpValue; // If the original `apply` instruction has a substitution map, then the // applied function is specialized. // In the VJP, specialization is also necessary for parity. The original // function operand is specialized with a remapped version of same // substitution map using an argument-less `partial_apply`. if (bai->getSubstitutionMap().empty()) { origCallee = builder.emitCopyValueOperation(loc, origCallee); } else { auto substMap = getOpSubstitutionMap(bai->getSubstitutionMap()); auto vjpPartialApply = getBuilder().createPartialApply( bai->getLoc(), origCallee, substMap, {}, ParameterConvention::Direct_Guaranteed); origCallee = vjpPartialApply; originalFnTy = origCallee->getType().castTo(); // Diagnose if new original function type is non-differentiable. if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy, bai, origCallee, config)) { errorOccurred = true; return; } } auto *diffFuncInst = context.createDifferentiableFunction(getBuilder(), loc, config.parameterIndices, config.resultIndices, origCallee); // Record the `differentiable_function` instruction. context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst); builder.emitScopedBorrowOperation( loc, diffFuncInst, [&](SILValue borrowedADFunc) { auto extractedVJP = getBuilder().createDifferentiableFunctionExtract( loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedADFunc); vjpValue = builder.emitCopyValueOperation(loc, extractedVJP); }); builder.emitDestroyValueOperation(loc, diffFuncInst); // Record desired/actual VJP indices. // Temporarily set original pullback type to `None`. NestedApplyInfo info{config, /*originalPullbackType*/ std::nullopt}; auto insertion = context.getNestedApplyInfo().try_emplace(bai, info); auto &nestedApplyInfo = insertion.first->getSecond(); nestedApplyInfo = info; // Call the VJP using the original parameters. SmallVector vjpArgs; auto vjpFnTy = getOpType(vjpValue->getType()).castTo(); auto numVJPArgs = vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults(); vjpArgs.reserve(numVJPArgs); // Collect substituted arguments. for (auto origArg : bai->getArguments()) vjpArgs.push_back(getOpValue(origArg)); // Apply the VJP. // The VJP should be specialized, so no substitution map is necessary. auto *vjpCall = getBuilder().createBeginApply(loc, vjpValue, SubstitutionMap(), vjpArgs, bai->getApplyOptions()); LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall); // Note that vjpValue is destroyed after end_apply // Store all the results (yields and token) to the value map. assert(bai->getNumResults() == vjpCall->getNumResults()); for (unsigned i = 0; i < vjpCall->getNumResults(); ++i) mapValue(bai->getResult(i), vjpCall->getResult(i)); // Checkpoint the pullback. nestedApplyInfo.pullbackIdx = pullbackValues[bai->getParent()].size(); pullbackValues[bai->getParent()].push_back(SILValue()); // The rest of the cloning magic happens during `end_apply` cloning. } // If an `apply` has active results or active inout arguments, replace it // with an `apply` of its VJP. void visitApplyInst(ApplyInst *ai) { // If callee should not be differentiated, do standard cloning. if (!pullbackInfo.shouldDifferentiateApplySite(ai)) { LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); TypeSubstCloner::visitApplyInst(ai); return; } // If callee is `array.uninitialized_intrinsic`, do standard cloning. // `array.uninitialized_intrinsic` differentiation is handled separately. if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) { LLVM_DEBUG(getADDebugStream() << "Cloning `array.uninitialized_intrinsic` `apply`:\n" << *ai << '\n'); TypeSubstCloner::visitApplyInst(ai); return; } // If callee is `array.finalize_intrinsic`, do standard cloning. // `array.finalize_intrinsic` has special-case pullback generation. if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) { LLVM_DEBUG(getADDebugStream() << "Cloning `array.finalize_intrinsic` `apply`:\n" << *ai << '\n'); TypeSubstCloner::visitApplyInst(ai); return; } // If the original function is a semantic member accessor, do standard // cloning. Semantic member accessors have special pullback generation // logic, so all `apply` instructions can be directly cloned to the VJP. if (isSemanticMemberAccessor(original)) { LLVM_DEBUG(getADDebugStream() << "Cloning `apply` in semantic member accessor:\n" << *ai << '\n'); TypeSubstCloner::visitApplyInst(ai); return; } Builder.setCurrentDebugScope(getOpScope(ai->getDebugScope())); auto loc = ai->getLoc(); auto &builder = getBuilder(); auto origCallee = getOpValue(ai->getCallee()); auto originalFnTy = origCallee->getType().castTo(); LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n'); // Get the minimal parameter and result indices required for differentiating // this `apply`. SmallVector allResults; SmallVector activeParamIndices; SmallVector activeResultIndices; collectMinimalIndicesForFunctionCall(ai, getConfig(), activityInfo, allResults, activeParamIndices, activeResultIndices); assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); assert(!activeResultIndices.empty() && "Result indices cannot be empty"); LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=("; llvm::interleave( activeParamIndices.begin(), activeParamIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); s << "), results=("; llvm::interleave( activeResultIndices.begin(), activeResultIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); s << ")\n";); // Form expected indices. AutoDiffConfig config( IndexSubset::get(getASTContext(), ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices), IndexSubset::get(getASTContext(), ai->getSubstCalleeType()->getNumAutoDiffSemanticResults(), activeResultIndices)); // Emit the VJP. SILValue vjpValue; // If functionSource is a `@differentiable` function, just extract it. if (originalFnTy->isDifferentiable()) { auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices(); for (auto i : config.parameterIndices->getIndices()) { if (!paramIndices->contains(i)) { context.emitNondifferentiabilityError( origCallee, invoker, diag:: autodiff_function_noderivative_parameter_not_differentiable); errorOccurred = true; return; } } builder.emitScopedBorrowOperation( loc, origCallee, [&](SILValue borrowedDiffFunc) { auto origFnType = origCallee->getType().castTo(); auto origFnUnsubstType = origFnType->getUnsubstitutedType(getModule()); if (origFnType != origFnUnsubstType) { borrowedDiffFunc = builder.createConvertFunction( loc, borrowedDiffFunc, SILType::getPrimitiveObjectType(origFnUnsubstType), /*withoutActuallyEscaping*/ false); } vjpValue = builder.createDifferentiableFunctionExtract( loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedDiffFunc); vjpValue = builder.emitCopyValueOperation(loc, vjpValue); }); auto vjpFnType = vjpValue->getType().castTo(); auto vjpFnUnsubstType = vjpFnType->getUnsubstitutedType(getModule()); if (vjpFnType != vjpFnUnsubstType) { vjpValue = builder.createConvertFunction( loc, vjpValue, SILType::getPrimitiveObjectType(vjpFnUnsubstType), /*withoutActuallyEscaping*/ false); } } if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy, ai, origCallee, config)) { errorOccurred = true; return; } // If VJP has not yet been found, emit an `differentiable_function` // instruction on the remapped original function operand and // an `differentiable_function_extract` instruction to get the VJP. // The `differentiable_function` instruction will be canonicalized during // the transform main loop. if (!vjpValue) { // FIXME: Handle indirect differentiation invokers. This may require some // redesign: currently, each original function + witness pair is mapped // only to one invoker. /* DifferentiationInvoker indirect(ai, attr); auto insertion = context.getInvokers().try_emplace({original, attr}, indirect); auto &invoker = insertion.first->getSecond(); invoker = indirect; */ // If the original `apply` instruction has a substitution map, then the // applied function is specialized. // In the VJP, specialization is also necessary for parity. The original // function operand is specialized with a remapped version of same // substitution map using an argument-less `partial_apply`. if (ai->getSubstitutionMap().empty()) { origCallee = builder.emitCopyValueOperation(loc, origCallee); } else { auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); auto vjpPartialApply = getBuilder().createPartialApply( ai->getLoc(), origCallee, substMap, {}, ParameterConvention::Direct_Guaranteed); origCallee = vjpPartialApply; originalFnTy = origCallee->getType().castTo(); // Diagnose if new original function type is non-differentiable. if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy, ai, origCallee, config)) { errorOccurred = true; return; } } auto *diffFuncInst = context.createDifferentiableFunction( getBuilder(), loc, config.parameterIndices, config.resultIndices, origCallee); // Record the `differentiable_function` instruction. context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst); builder.emitScopedBorrowOperation( loc, diffFuncInst, [&](SILValue borrowedADFunc) { auto extractedVJP = getBuilder().createDifferentiableFunctionExtract( loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedADFunc); vjpValue = builder.emitCopyValueOperation(loc, extractedVJP); }); builder.emitDestroyValueOperation(loc, diffFuncInst); } // Record desired/actual VJP indices. // Temporarily set original pullback type to `None`. NestedApplyInfo info{config, /*originalPullbackType*/ std::nullopt}; auto insertion = context.getNestedApplyInfo().try_emplace(ai, info); auto &nestedApplyInfo = insertion.first->getSecond(); nestedApplyInfo = info; // Call the VJP using the original parameters. SmallVector vjpArgs; auto vjpFnTy = getOpType(vjpValue->getType()).castTo(); auto numVJPArgs = vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults(); vjpArgs.reserve(numVJPArgs); // Collect substituted arguments. for (auto origArg : ai->getArguments()) vjpArgs.push_back(getOpValue(origArg)); assert(vjpArgs.size() == numVJPArgs); // Apply the VJP. // The VJP should be specialized, so no substitution map is necessary. auto *vjpCall = getBuilder().createApply(loc, vjpValue, SubstitutionMap(), vjpArgs, ai->getApplyOptions()); LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall); builder.emitDestroyValueOperation(loc, vjpValue); // Get the VJP results (original results and pullback). SmallVector vjpDirectResults; extractAllElements(vjpCall, getBuilder(), vjpDirectResults); ArrayRef originalDirectResults = ArrayRef(vjpDirectResults).drop_back(1); SILValue originalDirectResult = joinElements(originalDirectResults, getBuilder(), vjpCall->getLoc()); SILValue pullback = vjpDirectResults.back(); { auto pullbackFnType = pullback->getType().castTo(); auto pullbackUnsubstFnType = pullbackFnType->getUnsubstitutedType(getModule()); if (pullbackFnType != pullbackUnsubstFnType) { pullback = builder.createConvertFunction( loc, pullback, SILType::getPrimitiveObjectType(pullbackUnsubstFnType), /*withoutActuallyEscaping*/ false); } } // Store the original result to the value map. mapValue(ai, originalDirectResult); // Checkpoint the pullback. auto pullbackType = pullbackInfo.lookUpLinearMapType(ai); // If actual pullback type does not match lowered pullback type, reabstract // the pullback using a thunk. auto actualPullbackType = getOpType(pullback->getType()).getAs(); auto loweredPullbackType = getOpType(getLoweredType(pullbackType)).castTo(); if (!loweredPullbackType->isEqual(actualPullbackType)) { // Set non-reabstracted original pullback type in nested apply info. nestedApplyInfo.originalPullbackType = actualPullbackType; SILOptFunctionBuilder fb(context.getTransform()); pullback = reabstractFunction( getBuilder(), fb, ai->getLoc(), pullback, loweredPullbackType, [this](SubstitutionMap subs) -> SubstitutionMap { return this->getOpSubstitutionMap(subs); }); } nestedApplyInfo.pullbackIdx = pullbackValues[ai->getParent()].size(); pullbackValues[ai->getParent()].push_back(pullback); // Some instructions that produce the callee may have been cloned. // If the original callee did not have any users beyond this `apply`, // recursively kill the cloned callee. if (auto *origCallee = cast_or_null( ai->getCallee()->getDefiningInstruction())) if (origCallee->hasOneUse()) recursivelyDeleteTriviallyDeadInstructions( getOpValue(origCallee)->getDefiningInstruction()); } void visitTryApplyInst(TryApplyInst *tai) { Builder.setCurrentDebugScope(getOpScope(tai->getDebugScope())); // If callee should not be differentiated, do standard cloning. if (!pullbackInfo.shouldDifferentiateApplySite(tai)) { LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *tai << '\n'); // Build pullback struct value for original block. auto *pbTupleVal = buildPullbackValueTupleValue(tai); // Create a new `try_apply` instruction. auto args = getOpValueArray<8>(tai->getArguments()); getBuilder().createTryApply( tai->getLoc(), getOpValue(tai->getCallee()), getOpSubstitutionMap(tai->getSubstitutionMap()), args, createTrampolineBasicBlock(tai, pbTupleVal, tai->getNormalBB()), createTrampolineBasicBlock(tai, pbTupleVal, tai->getErrorBB()), tai->getApplyOptions()); return; } auto loc = tai->getLoc(); auto &builder = getBuilder(); auto origCallee = getOpValue(tai->getCallee()); auto originalFnTy = origCallee->getType().castTo(); auto *origBB = tai->getParent(); LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *tai << '\n'); // Get the minimal parameter and result indices required for differentiating // this `apply`. SmallVector allResults; SmallVector activeParamIndices; SmallVector activeResultIndices; collectMinimalIndicesForFunctionCall(tai, getConfig(), activityInfo, allResults, activeParamIndices, activeResultIndices); assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); assert(!activeResultIndices.empty() && "Result indices cannot be empty"); LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=("; llvm::interleave( activeParamIndices.begin(), activeParamIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); s << "), results=("; llvm::interleave( activeResultIndices.begin(), activeResultIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); s << ")\n";); // Form expected indices. AutoDiffConfig config( IndexSubset::get(getASTContext(), tai->getArgumentsWithoutIndirectResults().size(), activeParamIndices), IndexSubset::get(getASTContext(), tai->getSubstCalleeType()->getNumAutoDiffSemanticResults(), activeResultIndices)); // Emit the VJP. SILValue vjpValue; if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy, tai, origCallee, config)) { errorOccurred = true; return; } // If the original `apply` instruction has a substitution map, then the // applied function is specialized. // In the VJP, specialization is also necessary for parity. The original // function operand is specialized with a remapped version of same // substitution map using an argument-less `partial_apply`. if (tai->getSubstitutionMap().empty()) { origCallee = builder.emitCopyValueOperation(loc, origCallee); } else { auto substMap = getOpSubstitutionMap(tai->getSubstitutionMap()); auto vjpPartialApply = builder.createPartialApply( tai->getLoc(), origCallee, substMap, {}, ParameterConvention::Direct_Guaranteed); origCallee = vjpPartialApply; originalFnTy = origCallee->getType().castTo(); // Diagnose if new original function type is non-differentiable. if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy, tai, origCallee, config)) { errorOccurred = true; return; } } auto *diffFuncInst = context.createDifferentiableFunction( builder, loc, config.parameterIndices, config.resultIndices, origCallee); // Record the `differentiable_function` instruction. context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst); builder.emitScopedBorrowOperation( loc, diffFuncInst, [&](SILValue borrowedADFunc) { auto extractedVJP = builder.createDifferentiableFunctionExtract( loc, NormalDifferentiableFunctionTypeComponent::VJP, borrowedADFunc); vjpValue = builder.emitCopyValueOperation(loc, extractedVJP); }); builder.emitDestroyValueOperation(loc, diffFuncInst); // Record desired/actual VJP indices. // Temporarily set original pullback type to `None`. NestedApplyInfo info{config, /*originalPullbackType*/ std::nullopt}; auto insertion = context.getNestedApplyInfo().try_emplace(tai, info); auto &nestedApplyInfo = insertion.first->getSecond(); nestedApplyInfo = info; // Call the VJP using the original parameters. SmallVector vjpArgs; auto vjpFnTy = getOpType(vjpValue->getType()).castTo(); auto numVJPArgs = vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults() + (vjpFnTy->hasIndirectErrorResult() ? 1 : 0); vjpArgs.reserve(numVJPArgs); // Collect substituted arguments. for (auto origArg : tai->getArguments()) vjpArgs.push_back(getOpValue(origArg)); assert(vjpArgs.size() == numVJPArgs); // Create placeholder trampoline BB for error destination auto *errorBB = vjp->createBasicBlockBefore(getOpBasicBlock(tai->getErrorBB())); for (auto *arg : getOpBasicBlock(tai->getErrorBB())->getArguments().drop_back()) errorBB->createPhiArgument(arg->getType(), arg->getOwnershipKind()); // Create placeholder trampoline BB for normal destination auto *normalBB = vjp->createBasicBlockBefore(getOpBasicBlock(tai->getNormalBB())); normalBB->createPhiArgument( vjpFnTy->getDirectFormalResultsType(getModule(), TypeExpansionContext::minimal()), tai->getNormalBB()->getArgument(0)->getOwnershipKind()); // Apply the VJP. // The VJP should be specialized, so no substitution map is necessary. auto args = getOpValueArray<8>(tai->getArguments()); auto *vjpCall = builder.createTryApply(loc, vjpValue, SubstitutionMap(), vjpArgs, normalBB, errorBB, tai->getApplyOptions()); LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall); // Perform necessary cleanup on error path and forward the error result. // There is no pullback here. { SILBuilder trampolineBuilder(errorBB); trampolineBuilder.setCurrentDebugScope(getOpScope(tai->getDebugScope())); trampolineBuilder.emitDestroyValueOperation(loc, vjpValue); auto pullbackType = pullbackInfo.lookUpLinearMapType(tai); auto pullbackFnType = pullbackType->getOptionalObjectType(); auto loweredPullbackType = getOpType(getLoweredType(pullbackFnType)); auto tupleLoweredTy = remapType(pullbackInfo.getLinearMapTupleLoweredType(origBB)); // Find `Optional.none` EnumElementDecl. auto noneEltDecl = getASTContext().getOptionalNoneDecl(); // %enum = enum $Optional, #Optional.none!enumelt auto bbPullbackValues = getPullbackValues(origBB); bbPullbackValues.push_back( trampolineBuilder.createEnum(loc, SILValue(), noneEltDecl, SILType::getOptionalType(loweredPullbackType))); auto pbTupleVal = trampolineBuilder.createTuple(loc, tupleLoweredTy, bbPullbackValues); auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB, tai->getErrorBB(), pbTupleVal); SmallVector forwardedArguments( errorBB->getArguments().begin(), errorBB->getArguments().end()); forwardedArguments.push_back(succEnumVal); trampolineBuilder.createBranch(loc, getOpBasicBlock(tai->getErrorBB()), forwardedArguments); } // Capture the pullback on normal path and forward result { SILBuilder trampolineBuilder(normalBB); trampolineBuilder.setCurrentDebugScope(getOpScope(tai->getDebugScope())); trampolineBuilder.emitDestroyValueOperation(loc, vjpValue); // Get the VJP results (original results and pullback). SmallVector vjpDirectResults; extractAllElements(normalBB->getArgument(0), trampolineBuilder, vjpDirectResults); ArrayRef originalDirectResults = ArrayRef(vjpDirectResults).drop_back(1); SILValue originalDirectResult = joinElements(originalDirectResults, trampolineBuilder, vjpCall->getLoc()); SILValue pullback = vjpDirectResults.back(); { auto pullbackFnType = pullback->getType().castTo(); auto pullbackUnsubstFnType = pullbackFnType->getUnsubstitutedType(getModule()); if (pullbackFnType != pullbackUnsubstFnType) { pullback = trampolineBuilder.createConvertFunction( loc, pullback, SILType::getPrimitiveObjectType(pullbackUnsubstFnType), /*withoutActuallyEscaping*/ false); } } // Checkpoint the pullback. auto pullbackType = pullbackInfo.lookUpLinearMapType(tai); auto pullbackFnType = pullbackType->getOptionalObjectType(); // If actual pullback type does not match lowered pullback type, reabstract // the pullback using a thunk. auto actualPullbackType = getOpType(pullback->getType()).getAs(); auto loweredPullbackType = getOpType(getLoweredType(pullbackFnType)).castTo(); if (!loweredPullbackType->isEqual(actualPullbackType)) { // Set non-reabstracted original pullback type in nested apply info. nestedApplyInfo.originalPullbackType = actualPullbackType; SILOptFunctionBuilder fb(context.getTransform()); pullback = reabstractFunction( trampolineBuilder, fb, loc, pullback, loweredPullbackType, [this](SubstitutionMap subs) -> SubstitutionMap { return this->getOpSubstitutionMap(subs); }); } // Technically, the pullback value is not available in originalBB, // however, we record it for the try_apply's BB. This is safe as try_apply // is a terminator and we are emitting the pullback manually nestedApplyInfo.pullbackIdx = pullbackValues[origBB].size(); // Find `Optional.some` EnumElementDecl. auto someEltDecl = getASTContext().getOptionalSomeDecl(); // %enum = enum $Optional, #Optional.some!enumelt, // %pullback : $PullbackType pullback = trampolineBuilder.createEnum(loc, pullback, someEltDecl, SILType::getOptionalType(pullback->getType()), OwnershipKind::Owned); pullbackValues[origBB].push_back(pullback); auto tupleLoweredTy = remapType(pullbackInfo.getLinearMapTupleLoweredType(origBB)); auto pbTupleVal = trampolineBuilder.createTuple(loc, tupleLoweredTy, getPullbackValues(origBB)); auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, tai->getParent(), tai->getNormalBB(), pbTupleVal); SmallVector forwardedArguments{originalDirectResult, succEnumVal}; trampolineBuilder.createBranch(loc, getOpBasicBlock(tai->getNormalBB()), forwardedArguments); } // Some instructions that produce the callee may have been cloned. // If the original callee did not have any users beyond this `apply`, // recursively kill the cloned callee. if (auto *origCallee = cast_or_null( tai->getCallee()->getDefiningInstruction())) if (origCallee->hasOneUse()) recursivelyDeleteTriviallyDeadInstructions( getOpValue(origCallee)->getDefiningInstruction()); } void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { // Clone `differentiable_function` from original to VJP, then add the cloned // instruction to the `differentiable_function` worklist. TypeSubstCloner::visitDifferentiableFunctionInst(dfi); auto *newDFI = cast(getOpValue(dfi)); context.getDifferentiableFunctionInstWorklist().push_back(newDFI); } void visitLinearFunctionInst(LinearFunctionInst *lfi) { // Clone `linear_function` from original to VJP, then add the cloned // instruction to the `linear_function` worklist. TypeSubstCloner::visitLinearFunctionInst(lfi); auto *newLFI = cast(getOpValue(lfi)); context.getLinearFunctionInstWorklist().push_back(newLFI); } }; /// Initialization helper function. /// /// Returns the substitution map used for type remapping. static SubstitutionMap getSubstitutionMap(SILFunction *original, SILFunction *vjp) { auto substMap = original->getForwardingSubstitutionMap(); if (auto *vjpGenEnv = vjp->getGenericEnvironment()) { auto vjpSubstMap = vjpGenEnv->getForwardingSubstitutionMap(); substMap = SubstitutionMap::get( vjpGenEnv->getGenericSignature(), QuerySubstitutionMap{vjpSubstMap}, LookUpConformanceInSubstitutionMap(vjpSubstMap)); } return substMap; } /// Initialization helper function. /// /// Returns the activity info for the given original function, autodiff indices, /// and VJP generic signature. static const DifferentiableActivityInfo & getActivityInfoHelper(ADContext &context, SILFunction *original, const AutoDiffConfig &config, SILFunction *vjp) { // Get activity info of the original function. auto &passManager = context.getPassManager(); auto *activityAnalysis = passManager.getAnalysis(); auto &activityCollection = *activityAnalysis->get(original); auto &activityInfo = activityCollection.getActivityInfo( vjp->getLoweredFunctionType()->getSubstGenericSignature(), AutoDiffDerivativeFunctionKind::VJP); LLVM_DEBUG(activityInfo.dump(config, getADDebugStream())); return activityInfo; } VJPCloner::Implementation::Implementation(VJPCloner &cloner, ADContext &context, SILDifferentiabilityWitness *witness, SILFunction *vjp, DifferentiationInvoker invoker) : TypeSubstCloner(*vjp, *witness->getOriginalFunction(), getSubstitutionMap(witness->getOriginalFunction(), vjp)), cloner(cloner), context(context), original(witness->getOriginalFunction()), witness(witness), vjp(vjp), invoker(invoker), activityInfo(getActivityInfoHelper( context, original, witness->getConfig(), vjp)), loopInfo(context.getPassManager().getAnalysis() ->get(original)), pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp, witness->getConfig(), activityInfo, loopInfo) { // Create empty pullback function. pullback = createEmptyPullback(); context.recordGeneratedFunction(pullback); } VJPCloner::VJPCloner(ADContext &context, SILDifferentiabilityWitness *witness, SILFunction *vjp, DifferentiationInvoker invoker) : impl(*new Implementation(*this, context, witness, vjp, invoker)) {} VJPCloner::~VJPCloner() { delete &impl; } ADContext &VJPCloner::getContext() const { return impl.context; } SILModule &VJPCloner::getModule() const { return impl.getModule(); } SILFunction &VJPCloner::getOriginal() const { return *impl.original; } SILFunction &VJPCloner::getVJP() const { return *impl.vjp; } SILFunction &VJPCloner::getPullback() const { return *impl.pullback; } SILDifferentiabilityWitness *VJPCloner::getWitness() const { return impl.witness; } const AutoDiffConfig &VJPCloner::getConfig() const { return impl.getConfig(); } DifferentiationInvoker VJPCloner::getInvoker() const { return impl.invoker; } LinearMapInfo &VJPCloner::getPullbackInfo() const { return impl.pullbackInfo; } SILLoopInfo *VJPCloner::getLoopInfo() const { return impl.loopInfo; } const DifferentiableActivityInfo &VJPCloner::getActivityInfo() const { return impl.activityInfo; } SILFunction *VJPCloner::Implementation::createEmptyPullback() { auto origTy = original->getLoweredFunctionType(); // Get witness generic signature for remapping types. // Witness generic signature may have more requirements than VJP generic // signature: when witness generic signature has same-type requirements // binding all generic parameters to concrete types, VJP function type uses // all the concrete types and VJP generic signature is null. auto witnessCanGenSig = witness->getDerivativeGenericSignature().getCanonicalSignature(); auto lookupConformance = LookUpConformanceInModule(); // Given a type, returns its formal SIL parameter info. auto getTangentParameterInfoForOriginalResult = [&](CanType tanType, ResultConvention origResConv) -> SILParameterInfo { tanType = tanType->getReducedType(witnessCanGenSig); Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType); auto &tl = context.getTypeConverter().getTypeLowering( pattern, tanType, TypeExpansionContext::minimal()); ParameterConvention conv; switch (origResConv) { case ResultConvention::Unowned: case ResultConvention::UnownedInnerPointer: case ResultConvention::Owned: case ResultConvention::Autoreleased: if (tl.isAddressOnly()) { conv = ParameterConvention::Indirect_In_Guaranteed; } else { conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned : ParameterConvention::Direct_Guaranteed; } break; case ResultConvention::Indirect: conv = ParameterConvention::Indirect_In_Guaranteed; break; case ResultConvention::Pack: conv = ParameterConvention::Pack_Guaranteed; break; } return {tanType, conv}; }; // Given a type, returns its formal SIL result info. auto getTangentResultInfoForOriginalParameter = [&](CanType tanType, ParameterConvention origParamConv) -> SILResultInfo { tanType = tanType->getReducedType(witnessCanGenSig); Lowering::AbstractionPattern pattern(witnessCanGenSig, tanType); auto &tl = context.getTypeConverter().getTypeLowering( pattern, tanType, TypeExpansionContext::minimal()); ResultConvention conv; switch (origParamConv) { case ParameterConvention::Direct_Owned: case ParameterConvention::Direct_Guaranteed: case ParameterConvention::Direct_Unowned: if (tl.isAddressOnly()) { conv = ResultConvention::Indirect; } else { conv = tl.isTrivial() ? ResultConvention::Unowned : ResultConvention::Owned; } break; case ParameterConvention::Indirect_In: case ParameterConvention::Indirect_Inout: case ParameterConvention::Indirect_In_Guaranteed: case ParameterConvention::Indirect_InoutAliasable: case ParameterConvention::Indirect_In_CXX: conv = ResultConvention::Indirect; break; case ParameterConvention::Pack_Guaranteed: case ParameterConvention::Pack_Owned: case ParameterConvention::Pack_Inout: conv = ResultConvention::Pack; break; } return {tanType, conv}; }; // Parameters of the pullback are: // - the tangent vectors of the original results, and // - a pullback struct. // Results of the pullback are in the tangent space of the original // parameters. SmallVector pbParams; SmallVector pbYields; SmallVector adjResults; auto origParams = origTy->getParameters(); auto config = witness->getConfig(); // Add pullback parameters based on original result indices. SmallVector semanticResultParamIndices; for (auto i : range(origTy->getNumParameters())) { auto origParam = origParams[i]; if (!origParam.isAutoDiffSemanticResult()) continue; semanticResultParamIndices.push_back(i); } unsigned firstSemanticParamResultIdx = origTy->getNumResults(); unsigned firstYieldResultIndex = firstSemanticParamResultIdx + origTy->getNumAutoDiffSemanticResultsParameters(); for (auto resultIndex : config.resultIndices->getIndices()) { // Handle formal result. if (resultIndex < firstSemanticParamResultIdx) { auto origResult = origTy->getResults()[resultIndex]; origResult = origResult.getWithInterfaceType( origResult.getInterfaceType()->getReducedType(witnessCanGenSig)); auto paramInfo = getTangentParameterInfoForOriginalResult( origResult.getInterfaceType() ->getAutoDiffTangentSpace(lookupConformance) ->getType() ->getReducedType(witnessCanGenSig), origResult.getConvention()); pbParams.push_back(paramInfo); } else if (resultIndex < firstYieldResultIndex) { // Handle semantic result parameter. unsigned paramIndex = 0; unsigned resultParamIndex = 0; for (auto i : range(origTy->getNumParameters())) { auto origParam = origTy->getParameters()[i]; if (!origParam.isAutoDiffSemanticResult()) { ++paramIndex; continue; } if (resultParamIndex == resultIndex - firstSemanticParamResultIdx) break; ++paramIndex; ++resultParamIndex; } auto resultParam = origParams[paramIndex]; auto origResult = resultParam.getWithInterfaceType( resultParam.getInterfaceType()->getReducedType(witnessCanGenSig)); auto resultParamTanConvention = resultParam.getConvention(); if (!config.isWrtParameter(paramIndex)) resultParamTanConvention = ParameterConvention::Indirect_In_Guaranteed; pbParams.emplace_back(origResult.getInterfaceType() ->getAutoDiffTangentSpace(lookupConformance) ->getType() ->getReducedType(witnessCanGenSig), resultParamTanConvention); } else { assert(origTy->isCoroutine()); assert(origTy->getCoroutineKind() == SILCoroutineKind::YieldOnce); auto yieldResultIndex = resultIndex - firstYieldResultIndex; auto yieldResult = origTy->getYields()[yieldResultIndex]; auto origYield = yieldResult.getWithInterfaceType( yieldResult.getInterfaceType()->getReducedType(witnessCanGenSig)); assert(yieldResult.getConvention() == ParameterConvention::Indirect_Inout); pbYields.emplace_back(origYield.getInterfaceType() ->getAutoDiffTangentSpace(lookupConformance) ->getType() ->getReducedType(witnessCanGenSig), yieldResult.getConvention()); } } if (pullbackInfo.hasHeapAllocatedContext()) { // Accept a `AutoDiffLinarMapContext` heap object if there are loops. pbParams.push_back({ getASTContext().TheNativeObjectType, ParameterConvention::Direct_Guaranteed }); } else { // Accept a pullback tuple in the pullback parameter list. This is the // returned pullback's closure context. auto *origExit = &*original->findReturnBB(); auto pbTupleType = pullbackInfo.getLinearMapTupleLoweredType(origExit).getAs(); for (Type eltTy : pbTupleType->getElementTypes()) pbParams.emplace_back(CanType(eltTy), ParameterConvention::Direct_Owned); } // Add pullback results for the requested wrt parameters. for (auto i : config.parameterIndices->getIndices()) { auto origParam = origParams[i]; if (origParam.isAutoDiffSemanticResult()) continue; origParam = origParam.getWithInterfaceType( origParam.getInterfaceType()->getReducedType(witnessCanGenSig)); adjResults.push_back(getTangentResultInfoForOriginalParameter( origParam.getInterfaceType() ->getAutoDiffTangentSpace(lookupConformance) ->getType() ->getReducedType(witnessCanGenSig), origParam.getConvention())); } Mangle::DifferentiationMangler mangler(getASTContext()); auto pbName = mangler.mangleLinearMap( original->getName(), AutoDiffLinearMapKind::Pullback, config); // Set pullback generic signature equal to VJP generic signature. // Do not use witness generic signature, which may have same-type requirements // binding all generic parameters to concrete types. auto pbGenericSig = vjp->getLoweredFunctionType()->getSubstGenericSignature(); auto *pbGenericEnv = pbGenericSig.getGenericEnvironment(); auto pbType = SILFunctionType::get( pbGenericSig, SILExtInfo::getThin(), origTy->getCoroutineKind(), origTy->getCalleeConvention(), pbParams, pbYields, adjResults, std::nullopt, origTy->getPatternSubstitutions(), origTy->getInvocationSubstitutions(), original->getASTContext()); SILOptFunctionBuilder fb(context.getTransform()); auto linkage = vjp->isSerialized() ? SILLinkage::Public : SILLinkage::Private; auto *pullback = fb.createFunction( linkage, context.getASTContext().getIdentifier(pbName).str(), pbType, pbGenericEnv, original->getLocation(), original->isBare(), IsNotTransparent, vjp->getSerializedKind(), original->isDynamicallyReplaceable(), original->isDistributed(), original->isRuntimeAccessible()); auto &module = context.getModule(); pullback->setDebugScope(new (module) SILDebugScope(original->getLocation(), pullback)); return pullback; } SILBasicBlock *VJPCloner::Implementation::createTrampolineBasicBlock( TermInst *termInst, TupleInst *pbTupleVal, SILBasicBlock *succBB) { assert(llvm::find(termInst->getSuccessorBlocks(), succBB) != termInst->getSuccessorBlocks().end() && "Basic block is not a successor of terminator instruction"); // Create the trampoline block. auto *vjpSuccBB = getOpBasicBlock(succBB); auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); for (auto *arg : vjpSuccBB->getArguments().drop_back()) trampolineBB->createPhiArgument(arg->getType(), arg->getOwnershipKind()); // In the trampoline block, build predecessor enum value for VJP successor // block and branch to it. SILBuilder trampolineBuilder(trampolineBB); trampolineBuilder.setCurrentDebugScope(getOpScope(termInst->getDebugScope())); auto *origBB = termInst->getParent(); auto *succEnumVal = buildPredecessorEnumValue(trampolineBuilder, origBB, succBB, pbTupleVal); SmallVector forwardedArguments( trampolineBB->getArguments().begin(), trampolineBB->getArguments().end()); forwardedArguments.push_back(succEnumVal); trampolineBuilder.createBranch(termInst->getLoc(), vjpSuccBB, forwardedArguments); return trampolineBB; } llvm::SmallVector VJPCloner::Implementation::getPullbackValues(SILBasicBlock *origBB) { auto *vjpBB = BBMap[origBB]; auto bbPullbackValues = pullbackValues[origBB]; if (!origBB->isEntry()) { auto *predEnumArg = vjpBB->getArguments().back(); bbPullbackValues.insert(bbPullbackValues.begin(), predEnumArg); } return bbPullbackValues; } TupleInst * VJPCloner::Implementation::buildPullbackValueTupleValue(TermInst *termInst) { assert(termInst->getFunction() == original); auto loc = RegularLocation::getAutoGeneratedLocation(); auto origBB = termInst->getParent(); auto tupleLoweredTy = remapType(pullbackInfo.getLinearMapTupleLoweredType(origBB)); auto bbPullbackValues = getPullbackValues(origBB); return getBuilder().createTuple(loc, tupleLoweredTy, bbPullbackValues); } EnumInst *VJPCloner::Implementation::buildPredecessorEnumValue( SILBuilder &builder, SILBasicBlock *predBB, SILBasicBlock *succBB, SILValue pbTupleVal) { auto loc = RegularLocation::getAutoGeneratedLocation(); auto enumLoweredTy = remapType(pullbackInfo.getBranchingTraceEnumLoweredType(succBB)); auto *enumEltDecl = pullbackInfo.lookUpBranchingTraceEnumElement(predBB, succBB); auto enumEltType = getOpType(enumLoweredTy.getEnumElementType( enumEltDecl, getModule(), TypeExpansionContext::minimal())); // If the predecessor block is in a loop, its predecessor enum payload is a // `Builtin.RawPointer`. if (loopInfo->getLoopFor(predBB)) { auto rawPtrType = SILType::getRawPointerType(getASTContext()); assert(enumEltType == rawPtrType); auto pbTupleType = remapASTType(pullbackInfo.getLinearMapTupleType(predBB)->getCanonicalType()); auto pbTupleMetatypeType = CanMetatypeType::get(pbTupleType, MetatypeRepresentation::Thick); auto pbTupleMetatypeSILType = SILType::getPrimitiveObjectType(pbTupleMetatypeType); auto pbTupleMetatype = Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType); auto rawBufferValue = builder.createBuiltin( loc, getASTContext().getIdentifier(getBuiltinName( BuiltinValueKind::AutoDiffAllocateSubcontextWithType)), rawPtrType, SubstitutionMap(), {borrowedPullbackContextValue, pbTupleMetatype}); auto typedBufferValue = builder.createPointerToAddress( loc, rawBufferValue, pbTupleVal->getType().getAddressType(), /*isStrict*/ true); builder.createStore( loc, pbTupleVal, typedBufferValue, pbTupleVal->getType().isTrivial(*pullback) ? StoreOwnershipQualifier::Trivial : StoreOwnershipQualifier::Init); return builder.createEnum(loc, rawBufferValue, enumEltDecl, enumLoweredTy); } return builder.createEnum(loc, pbTupleVal, enumEltDecl, enumLoweredTy); } bool VJPCloner::Implementation::run() { PrettyStackTraceSILFunction trace("generating VJP for", original); LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName() << " to vjp @" << vjp->getName() << '\n'); // Create entry BB and arguments. auto *entry = vjp->createBasicBlock(); createEntryArguments(vjp); emitLinearMapContextInitializationIfNeeded(); // Clone. SmallVector entryArgs; entryArgs.assign(entry->getArguments().begin(), entry->getArguments().end()); cloneFunctionBody(original, entry, entryArgs); // If errors occurred, back out. if (errorOccurred) return true; // Merge VJP basic blocks. This is significant for control flow // differentiation: trampoline destination bbs are merged into trampoline bbs. // NOTE(TF-990): Merging basic blocks ensures that `@guaranteed` trampoline // bb arguments have a lifetime-ending `end_borrow` use, and is robust when // `-enable-strip-ownership-after-serialization` is true. mergeBasicBlocks(vjp); LLVM_DEBUG(getADDebugStream() << "Generated VJP for " << original->getName() << ":\n" << *vjp); // Generate pullback code. PullbackCloner PullbackCloner(cloner); if (PullbackCloner.run()) { errorOccurred = true; return true; } return errorOccurred; } bool VJPCloner::run() { bool foundError = impl.run(); #ifndef NDEBUG if (!foundError) getVJP().verify(); #endif return foundError; } } // end namespace autodiff } // end namespace swift