//===--- LinearMapInfo.cpp ------------------------------------*- C++ -*---===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // Linear map tuple and branching trace enum information for differentiation. // //===----------------------------------------------------------------------===// #define DEBUG_TYPE "differentiation" #include "swift/SIL/ApplySite.h" #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" #include "swift/AST/DeclContext.h" #include "swift/AST/ParameterList.h" #include "swift/AST/SourceFile.h" #include "swift/Basic/Assertions.h" #include "swift/SIL/LoopInfo.h" namespace swift { namespace autodiff { //===----------------------------------------------------------------------===// // Local helpers //===----------------------------------------------------------------------===// /// Clone the generic parameters of the given generic signature and return a new /// `GenericParamList`. static GenericParamList *cloneGenericParameters(ASTContext &ctx, DeclContext *dc, CanGenericSignature sig) { SmallVector clonedParams; for (auto paramType : sig.getGenericParams()) { auto *clonedParam = GenericTypeParamDecl::createImplicit( dc, paramType->getName(), paramType->getDepth(), paramType->getIndex(), paramType->getParamKind()); clonedParam->setDeclContext(dc); clonedParams.push_back(clonedParam); } return GenericParamList::create(ctx, SourceLoc(), clonedParams, SourceLoc()); } //===----------------------------------------------------------------------===// // LinearMapInfo methods //===----------------------------------------------------------------------===// LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind, SILFunction *original, SILFunction *derivative, const AutoDiffConfig &config, const DifferentiableActivityInfo &activityInfo, SILLoopInfo *loopInfo) : kind(kind), original(original), derivative(derivative), activityInfo(activityInfo), loopInfo(loopInfo), config(config), synthesizedFile(context.getOrCreateSynthesizedFile(original)), typeConverter(context.getTypeConverter()) { generateDifferentiationDataStructures(context, derivative); } SILType LinearMapInfo::remapTypeInDerivative(SILType ty) { if (ty.hasArchetype()) return derivative->mapTypeIntoContext(ty.mapTypeOutOfContext()); return derivative->mapTypeIntoContext(ty); } EnumDecl * LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB, CanGenericSignature genericSig) { assert(originalBB->getParent() == original); auto &astCtx = original->getASTContext(); auto &file = getSynthesizedFile(); // Create a branching trace enum. Mangle::ASTMangler mangler(astCtx); auto config = this->config.withGenericSignature(genericSig); auto enumName = mangler.mangleAutoDiffGeneratedDeclaration( AutoDiffGeneratedDeclarationKind::BranchingTraceEnum, original->getName().str(), originalBB->getDebugID(), kind, config); auto enumId = astCtx.getIdentifier(enumName); auto loc = original->getLocation().getSourceLoc(); GenericParamList *genericParams = nullptr; if (genericSig) genericParams = cloneGenericParameters(astCtx, &file, genericSig); auto *branchingTraceDecl = new (astCtx) EnumDecl( /*EnumLoc*/ SourceLoc(), /*Name*/ enumId, /*NameLoc*/ loc, /*Inherited*/ {}, /*GenericParams*/ genericParams, /*DC*/ &file); // Note: must mark enum as implicit to satisfy assertion in // `Parser::parseDeclListDelayed`. branchingTraceDecl->setImplicit(); if (genericSig) branchingTraceDecl->setGenericSignature(genericSig); switch (original->getEffectiveSymbolLinkage()) { case swift::SILLinkage::Public: case swift::SILLinkage::PublicNonABI: // Branching trace enums shall not be resilient. branchingTraceDecl->getAttrs().add(new (astCtx) FrozenAttr(/*implicit*/ true)); branchingTraceDecl->getAttrs().add(new (astCtx) UsableFromInlineAttr(/*Implicit*/ true)); LLVM_FALLTHROUGH; case swift::SILLinkage::Hidden: case swift::SILLinkage::Shared: branchingTraceDecl->setAccess(AccessLevel::Internal); break; case swift::SILLinkage::Private: branchingTraceDecl->setAccess(AccessLevel::FilePrivate); break; default: // When the original function has external linkage, we create an internal // struct for use by our own module. This is necessary for cross-cell // differentiation in Jupyter. // TODO: Add a test in the compiler that exercises a similar situation as // cross-cell differentiation in Jupyter. branchingTraceDecl->setAccess(AccessLevel::Internal); } file.addTopLevelDecl(branchingTraceDecl); file.getParentModule()->clearLookupCache(); return branchingTraceDecl; } void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB, SILLoopInfo *loopInfo) { auto &astCtx = original->getASTContext(); auto *moduleDecl = original->getModule().getSwiftModule(); auto loc = original->getLocation().getSourceLoc(); auto *branchingTraceDecl = getBranchingTraceDecl(originalBB); // Add basic block enum cases. for (auto *predBB : originalBB->getPredecessorBlocks()) { // Create dummy declaration representing enum case parameter. auto *decl = new (astCtx) ParamDecl(loc, loc, Identifier(), loc, Identifier(), moduleDecl); decl->setSpecifier(ParamDecl::Specifier::Default); // If predecessor block is in a loop, its linear map tuple will be // indirectly referenced in memory owned by the context object. The payload // is just a raw pointer. if (loopInfo->getLoopFor(predBB)) { heapAllocatedContext = true; decl->setInterfaceType(astCtx.TheRawPointerType); } else { // Otherwise the payload is the linear map tuple. auto *linearMapStructTy = getLinearMapTupleType(predBB); // Do not create entries for unreachable predecessors if (!linearMapStructTy) continue; auto canLinearMapStructTy = linearMapStructTy->getCanonicalType(); decl->setInterfaceType( canLinearMapStructTy->hasArchetype() ? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy); } // Create enum element and enum case declarations. auto *paramList = ParameterList::create(astCtx, {decl}); auto bbId = "bb" + std::to_string(predBB->getDebugID()); auto *enumEltDecl = new (astCtx) EnumElementDecl( /*IdentifierLoc*/ loc, DeclName(astCtx.getIdentifier(bbId)), paramList, loc, /*RawValueExpr*/ nullptr, branchingTraceDecl); enumEltDecl->setImplicit(); auto *enumCaseDecl = EnumCaseDecl::create( /*CaseLoc*/ loc, {enumEltDecl}, branchingTraceDecl); enumCaseDecl->setImplicit(); branchingTraceDecl->addMember(enumEltDecl); branchingTraceDecl->addMember(enumCaseDecl); // Record enum element declaration. branchingTraceEnumCases.insert({{predBB, originalBB}, enumEltDecl}); } } Type LinearMapInfo::getLinearMapType(ADContext &context, FullApplySite fai) { SmallVector allResults; SmallVector activeParamIndices; SmallVector activeResultIndices; collectMinimalIndicesForFunctionCall(fai, config, activityInfo, allResults, activeParamIndices, activeResultIndices); // Check if there are any active results or arguments. If not, skip // this instruction. auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) { return activityInfo.isActive(res, config); }); bool hasActiveSemanticResultArgument = false; bool hasActiveArguments = false; auto numIndirectResults = fai.getNumIndirectSILResults(); for (auto argIdx : range(fai.getSubstCalleeConv().getNumParameters())) { auto arg = fai.getArgumentsWithoutIndirectResults()[argIdx]; if (activityInfo.isActive(arg, config)) { hasActiveArguments = true; auto paramInfo = fai.getSubstCalleeConv().getParamInfoForSILArg( numIndirectResults + argIdx); if (paramInfo.isAutoDiffSemanticResult()) hasActiveSemanticResultArgument = true; } } if (!hasActiveArguments) return {}; if (!hasActiveResults && !hasActiveSemanticResultArgument) return {}; // Compute differentiability parameters. // - If the callee has `@differentiable` function type, use differentiation // parameters from the function type. // - Otherwise, use the active parameters. IndexSubset *parameters; auto origFnSubstTy = fai.getSubstCalleeType(); auto remappedOrigFnSubstTy = remapTypeInDerivative(SILType::getPrimitiveObjectType(origFnSubstTy)) .castTo() ->getUnsubstitutedType(original->getModule()); if (remappedOrigFnSubstTy->isDifferentiable()) { parameters = remappedOrigFnSubstTy->getDifferentiabilityParameterIndices(); } else { parameters = IndexSubset::get( original->getASTContext(), fai.getArgumentsWithoutIndirectResults().size(), activeParamIndices); } // Compute differentiability results. auto *results = IndexSubset::get(original->getASTContext(), remappedOrigFnSubstTy->getNumAutoDiffSemanticResults(), activeResultIndices); // Create autodiff indices for the `apply` instruction. AutoDiffConfig applyConfig(parameters, results); // Check for non-differentiable original function type. auto checkNondifferentiableOriginalFunctionType = [&](CanSILFunctionType origFnTy) { // Check non-differentiable arguments. for (auto paramIndex : applyConfig.parameterIndices->getIndices()) { auto remappedParamType = origFnTy->getParameters()[paramIndex].getSILStorageInterfaceType(); if (!remappedParamType.isDifferentiable(derivative->getModule())) return true; } // Check non-differentiable results. unsigned firstSemanticParamResultIdx = origFnTy->getNumResults(); unsigned firstYieldResultIndex = origFnTy->getNumResults() + origFnTy->getNumAutoDiffSemanticResultsParameters(); for (auto resultIndex : applyConfig.resultIndices->getIndices()) { SILType remappedResultType; if (resultIndex >= firstYieldResultIndex) { auto yieldResultIdx = resultIndex - firstYieldResultIndex; const auto& yield = origFnTy->getYields()[yieldResultIdx]; // We do not have a good way to differentiate direct yields if (!yield.isAutoDiffSemanticResult()) return true; remappedResultType = yield.getSILStorageInterfaceType(); } else if (resultIndex >= firstSemanticParamResultIdx) { auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx; auto semanticResultArg = *std::next(fai.getAutoDiffSemanticResultArguments().begin(), semanticResultArgIdx); remappedResultType = semanticResultArg->getType(); } else { remappedResultType = origFnTy->getResults()[resultIndex].getSILStorageInterfaceType(); } if (!remappedResultType.isDifferentiable(derivative->getModule())) return true; } return false; }; if (checkNondifferentiableOriginalFunctionType(remappedOrigFnSubstTy)) return nullptr; AutoDiffDerivativeFunctionKind derivativeFnKind(kind); auto derivativeFnType = remappedOrigFnSubstTy ->getAutoDiffDerivativeFunctionType( parameters, results, derivativeFnKind, context.getTypeConverter(), LookUpConformanceInModule()) ->getUnsubstitutedType(original->getModule()); auto linearMapSILType = derivativeFnType->getAllResultsInterfaceType(); if (auto tupleType = linearMapSILType.getAs()) { linearMapSILType = SILType::getPrimitiveObjectType( tupleType.getElementType(tupleType->getElements().size() - 1)); } if (auto fnTy = linearMapSILType.getAs()) { linearMapSILType = SILType::getPrimitiveObjectType( fnTy->getUnsubstitutedType(original->getModule())); } // IRGen requires decls to have AST types (not `SILFunctionType`), so we // convert the `SILFunctionType` of the linear map to a `FunctionType` with // the same parameters and results. auto silFnTy = linearMapSILType.castTo(); SmallVector params; for (auto ¶m : silFnTy->getParameters()) { ParameterTypeFlags flags; if (param.isAutoDiffSemanticResult()) flags = flags.withInOut(true); params.push_back( AnyFunctionType::Param(param.getInterfaceType(), Identifier(), flags)); } AnyFunctionType *astFnTy; if (auto genSig = silFnTy->getSubstGenericSignature()) { // FIXME: Verify ExtInfo state is correct, not working by accident. GenericFunctionType::ExtInfo info; astFnTy = GenericFunctionType::get( genSig, params, silFnTy->getAllResultsInterfaceType().getASTType(), info); } else { FunctionType::ExtInfo info; astFnTy = FunctionType::get( params, silFnTy->getAllResultsInterfaceType().getASTType(), info); } if (astFnTy->hasArchetype()) return astFnTy->mapTypeOutOfContext(); return astFnTy; } void LinearMapInfo::generateDifferentiationDataStructures( ADContext &context, SILFunction *derivativeFn) { auto &astCtx = original->getASTContext(); // Get the derivative function generic signature. CanGenericSignature derivativeFnGenSig = nullptr; if (auto *derivativeFnGenEnv = derivativeFn->getGenericEnvironment()) derivativeFnGenSig = derivativeFnGenEnv->getGenericSignature().getCanonicalSignature(); // Create branching trace enum for each original block and add it as a field // in the corresponding struct. StringRef traceEnumFieldName; switch (kind) { case AutoDiffLinearMapKind::Differential: traceEnumFieldName = "successor"; break; case AutoDiffLinearMapKind::Pullback: traceEnumFieldName = "predecessor"; break; } for (auto &origBB : *original) { auto *traceEnum = createBranchingTraceDecl(&origBB, derivativeFnGenSig); branchingTraceDecls.insert({&origBB, traceEnum}); } // Add linear map fields to the linear map tuples. // // Now we need to be very careful as we're having a very subtle // chicken-and-egg problem. We need lowered branch trace enum type for the // linear map typle type. However branch trace enum type lowering depends on // the lowering of its elements (at very least, the type classification of // being trivial / non-trivial). As the lowering is cached we need to ensure // we compute lowered type for the branch trace enum when the corresponding // EnumDecl is fully complete: we cannot add more entries without causing some // very subtle issues later on. However, the elements of the enum are linear // map tuples of predecessors, that correspondingly may contain branch trace // enums of corresponding predecessor BBs. // // Traverse all BBs in reverse post-order traversal order to ensure we process // each BB before its predecessors. llvm::ReversePostOrderTraversal RPOT(original); for (auto Iter = RPOT.begin(), E = RPOT.end(); Iter != E; ++Iter) { auto *origBB = *Iter; SmallVector linearTupleTypes; if (!origBB->isEntry()) { populateBranchingTraceDecl(origBB, loopInfo); CanType traceEnumType = getBranchingTraceEnumLoweredType(origBB).getASTType(); linearTupleTypes.emplace_back(traceEnumType, astCtx.getIdentifier(traceEnumFieldName)); } if (isSemanticMemberAccessor(original)) { // Do not add linear map fields for semantic member accessors, which have // special-case pullback generation. Linear map tuples should be empty. } else { for (auto &inst : *origBB) { if (auto *ai = dyn_cast(&inst)) // Skip array literal intrinsic applications since array literal // initialization is linear and handled separately. if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC) || ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) continue; if (!isa(&inst)) continue; FullApplySite fai(&inst); // Add linear map field to struct for active apply sites instructions. if (!shouldDifferentiateApplySite(fai)) continue; LLVM_DEBUG(getADDebugStream() << "Adding linear map tuple field for " << inst); if (Type linearMapType = getLinearMapType(context, fai)) { LLVM_DEBUG(getADDebugStream() << "Computed type: " << linearMapType << '\n'); linearMapIndexMap.insert({fai, linearTupleTypes.size()}); linearTupleTypes.emplace_back(linearMapType); } } } linearMapTuples.insert({origBB, TupleType::get(linearTupleTypes, astCtx)}); } // Print generated linear map structs and branching trace enums. // These declarations do not show up with `-emit-sil` because they are // implicit. Instead, use `-Xllvm -debug-only=differentiation` to test // declarations with FileCheck. LLVM_DEBUG({ auto &s = getADDebugStream(); PrintOptions printOptions; printOptions.TypeDefinitions = true; printOptions.ExplodePatternBindingDecls = true; printOptions.SkipImplicit = false; s << "Generated linear map tuples and branching trace enums for @" << original->getName() << ":\n"; for (auto &origBB : *original) { auto *linearMapTuple = getLinearMapTupleType(&origBB); linearMapTuple->print(s, printOptions); s << '\n'; } for (auto &origBB : *original) { auto *traceEnum = getBranchingTraceDecl(&origBB); traceEnum->print(s, printOptions); s << '\n'; } }); } /// Returns a flag that indicates whether the `apply` instruction should be /// differentiated, given the differentiation indices of the instruction's /// parent function. Whether the `apply` should be differentiated is determined /// sequentially from the following conditions: /// 1. The instruction has an active `inout` argument. /// 2. The instruction is a call to the array literal initialization intrinsic /// ("array.uninitialized_intrinsic"), where the result is active and where /// there is a `store` of an active value into the array's buffer. /// 3. The instruction has both an active result (direct or indirect) and an /// active argument. bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) { // Function applications with an active inout argument should be // differentiated. for (auto inoutArg : applySite.getInoutArguments()) if (activityInfo.isActive(inoutArg, config)) return true; bool hasActiveDirectResults = false; forEachApplyDirectResult(applySite, [&](SILValue directResult) { hasActiveDirectResults |= activityInfo.isActive(directResult, config); }); bool hasActiveIndirectResults = llvm::any_of(applySite.getIndirectSILResults(), [&](SILValue result) { return activityInfo.isActive(result, config); }); bool hasActiveResults = hasActiveDirectResults || hasActiveIndirectResults; // TODO: Pattern match to make sure there is at least one `store` to the // array's active buffer. if (ArraySemanticsCall(applySite.getInstruction(), semantics::ARRAY_UNINITIALIZED_INTRINSIC) && hasActiveResults) return true; auto arguments = applySite.getArgumentsWithoutIndirectResults(); bool hasActiveArguments = llvm::any_of(arguments, [&](SILValue arg) { return activityInfo.isActive(arg, config); }); return hasActiveResults && hasActiveArguments; } static bool shouldDifferentiateInjectEnumAddr( const InjectEnumAddrInst &inject, const DifferentiableActivityInfo &activityInfo, const AutoDiffConfig &config) { SILValue en = inject.getOperand(); for (auto use : en->getUses()) { auto *init = dyn_cast(use->getUser()); if (init && activityInfo.isActive(init, config)) return true; } return false; } /// Returns a flag indicating whether the instruction should be differentiated, /// given the differentiation indices of the instruction's parent function. /// Whether the instruction should be differentiated is determined sequentially /// from any of the following conditions: /// 1. The instruction is a full apply site and `shouldDifferentiateApplyInst` /// returns true. /// 2. The instruction has a source operand and a destination operand, both /// being active. /// 3. The instruction is an allocation instruction and has an active result. /// 4. The instruction performs reference counting, lifetime ending, access /// ending, or destroying on an active operand. /// 5. The instruction creates an SSA copy of an active operand. bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) { // A full apply site with an active argument and an active result (direct or // indirect) should be differentiated. if (FullApplySite::isa(inst)) return shouldDifferentiateApplySite(FullApplySite(inst)); // Anything with an active result and an active operand should be // differentiated. auto hasActiveOperands = llvm::any_of(inst->getAllOperands(), [&](Operand &op) { return activityInfo.isActive(op.get(), config); }); auto hasActiveResults = llvm::any_of(inst->getResults(), [&](SILValue val) { return activityInfo.isActive(val, config); }); if (hasActiveOperands && hasActiveResults) return true; // `store`-like instructions do not have an SSA result, but have two // operands that represent the source and the destination. We treat them as // the input and the output, respectively. // For `store`-like instructions whose destination is an element address // from an `array.uninitialized_intrinsic` application, return true if the // intrinsic application (representing the semantic destination) is active. #define CHECK_INST_TYPE_ACTIVE_DEST(INST) \ if (auto *castInst = dyn_cast(inst)) \ return activityInfo.isActive(castInst->getDest(), config); CHECK_INST_TYPE_ACTIVE_DEST(Store) CHECK_INST_TYPE_ACTIVE_DEST(StoreBorrow) CHECK_INST_TYPE_ACTIVE_DEST(CopyAddr) CHECK_INST_TYPE_ACTIVE_DEST(UnconditionalCheckedCastAddr) #undef CHECK_INST_TYPE_ACTIVE_DEST // Should differentiate any allocation instruction that has an active result. if ((isa(inst) && hasActiveResults)) return true; // Should differentiate end_apply if the corresponding begin_apply is // differentiable if (auto *eai = dyn_cast(inst)) return shouldDifferentiateApplySite(eai->getBeginApply()); if (hasActiveOperands) { // Should differentiate any instruction that performs reference counting, // lifetime ending, access ending, or destroying on an active operand. if (isa(inst) || isa(inst) || isa(inst) || isa(inst) || isa(inst) || isa(inst) || isa(inst) || isa(inst)) return true; } // Should differentiate `inject_enum_addr` if the corresponding // `init_enum_addr` has an active operand. if (auto inject = dyn_cast(inst)) if (shouldDifferentiateInjectEnumAddr(*inject, activityInfo, config)) return true; return false; } } // end namespace autodiff } // end namespace swift