mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
561 lines
24 KiB
C++
561 lines
24 KiB
C++
//===--- LinearMapInfo.cpp ------------------------------------*- C++ -*---===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Linear map tuple and branching trace enum information for differentiation.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define DEBUG_TYPE "differentiation"
|
|
|
|
#include "swift/SIL/ApplySite.h"
|
|
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
|
|
#include "swift/SILOptimizer/Differentiation/ADContext.h"
|
|
|
|
#include "swift/AST/DeclContext.h"
|
|
#include "swift/AST/ParameterList.h"
|
|
#include "swift/AST/SourceFile.h"
|
|
#include "swift/Basic/Assertions.h"
|
|
#include "swift/SIL/LoopInfo.h"
|
|
|
|
namespace swift {
|
|
namespace autodiff {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Local helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Clone the generic parameters of the given generic signature and return a new
|
|
/// `GenericParamList`.
|
|
static GenericParamList *cloneGenericParameters(ASTContext &ctx,
|
|
DeclContext *dc,
|
|
CanGenericSignature sig) {
|
|
SmallVector<GenericTypeParamDecl *, 2> clonedParams;
|
|
for (auto paramType : sig.getGenericParams()) {
|
|
auto *clonedParam = GenericTypeParamDecl::createImplicit(
|
|
dc, paramType->getName(), paramType->getDepth(), paramType->getIndex(),
|
|
paramType->getParamKind());
|
|
clonedParam->setDeclContext(dc);
|
|
clonedParams.push_back(clonedParam);
|
|
}
|
|
return GenericParamList::create(ctx, SourceLoc(), clonedParams, SourceLoc());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LinearMapInfo methods
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
|
|
SILFunction *original, SILFunction *derivative,
|
|
const AutoDiffConfig &config,
|
|
const DifferentiableActivityInfo &activityInfo,
|
|
SILLoopInfo *loopInfo)
|
|
: kind(kind), original(original), derivative(derivative),
|
|
activityInfo(activityInfo), loopInfo(loopInfo), config(config),
|
|
synthesizedFile(context.getOrCreateSynthesizedFile(original)),
|
|
typeConverter(context.getTypeConverter()) {
|
|
generateDifferentiationDataStructures(context, derivative);
|
|
}
|
|
|
|
SILType LinearMapInfo::remapTypeInDerivative(SILType ty) {
|
|
if (ty.hasArchetype())
|
|
return derivative->mapTypeIntoContext(ty.mapTypeOutOfContext());
|
|
return derivative->mapTypeIntoContext(ty);
|
|
}
|
|
|
|
EnumDecl *
|
|
LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB,
|
|
CanGenericSignature genericSig) {
|
|
assert(originalBB->getParent() == original);
|
|
auto &astCtx = original->getASTContext();
|
|
auto &file = getSynthesizedFile();
|
|
// Create a branching trace enum.
|
|
Mangle::ASTMangler mangler(astCtx);
|
|
auto config = this->config.withGenericSignature(genericSig);
|
|
auto enumName = mangler.mangleAutoDiffGeneratedDeclaration(
|
|
AutoDiffGeneratedDeclarationKind::BranchingTraceEnum,
|
|
original->getName().str(), originalBB->getDebugID(), kind, config);
|
|
auto enumId = astCtx.getIdentifier(enumName);
|
|
auto loc = original->getLocation().getSourceLoc();
|
|
GenericParamList *genericParams = nullptr;
|
|
if (genericSig)
|
|
genericParams = cloneGenericParameters(astCtx, &file, genericSig);
|
|
auto *branchingTraceDecl = new (astCtx) EnumDecl(
|
|
/*EnumLoc*/ SourceLoc(), /*Name*/ enumId, /*NameLoc*/ loc,
|
|
/*Inherited*/ {}, /*GenericParams*/ genericParams, /*DC*/ &file);
|
|
// Note: must mark enum as implicit to satisfy assertion in
|
|
// `Parser::parseDeclListDelayed`.
|
|
branchingTraceDecl->setImplicit();
|
|
if (genericSig)
|
|
branchingTraceDecl->setGenericSignature(genericSig);
|
|
switch (original->getEffectiveSymbolLinkage()) {
|
|
case swift::SILLinkage::Public:
|
|
case swift::SILLinkage::PublicNonABI:
|
|
// Branching trace enums shall not be resilient.
|
|
branchingTraceDecl->getAttrs().add(new (astCtx) FrozenAttr(/*implicit*/ true));
|
|
branchingTraceDecl->getAttrs().add(new (astCtx) UsableFromInlineAttr(/*Implicit*/ true));
|
|
LLVM_FALLTHROUGH;
|
|
case swift::SILLinkage::Hidden:
|
|
case swift::SILLinkage::Shared:
|
|
branchingTraceDecl->setAccess(AccessLevel::Internal);
|
|
break;
|
|
case swift::SILLinkage::Private:
|
|
branchingTraceDecl->setAccess(AccessLevel::FilePrivate);
|
|
break;
|
|
default:
|
|
// When the original function has external linkage, we create an internal
|
|
// struct for use by our own module. This is necessary for cross-cell
|
|
// differentiation in Jupyter.
|
|
// TODO: Add a test in the compiler that exercises a similar situation as
|
|
// cross-cell differentiation in Jupyter.
|
|
branchingTraceDecl->setAccess(AccessLevel::Internal);
|
|
}
|
|
file.addTopLevelDecl(branchingTraceDecl);
|
|
file.getParentModule()->clearLookupCache();
|
|
|
|
return branchingTraceDecl;
|
|
}
|
|
|
|
void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
|
|
SILLoopInfo *loopInfo) {
|
|
auto &astCtx = original->getASTContext();
|
|
auto *moduleDecl = original->getModule().getSwiftModule();
|
|
auto loc = original->getLocation().getSourceLoc();
|
|
auto *branchingTraceDecl = getBranchingTraceDecl(originalBB);
|
|
|
|
// Add basic block enum cases.
|
|
for (auto *predBB : originalBB->getPredecessorBlocks()) {
|
|
// Create dummy declaration representing enum case parameter.
|
|
auto *decl = new (astCtx)
|
|
ParamDecl(loc, loc, Identifier(), loc, Identifier(), moduleDecl);
|
|
decl->setSpecifier(ParamDecl::Specifier::Default);
|
|
// If predecessor block is in a loop, its linear map tuple will be
|
|
// indirectly referenced in memory owned by the context object. The payload
|
|
// is just a raw pointer.
|
|
if (loopInfo->getLoopFor(predBB)) {
|
|
heapAllocatedContext = true;
|
|
decl->setInterfaceType(astCtx.TheRawPointerType);
|
|
} else { // Otherwise the payload is the linear map tuple.
|
|
auto *linearMapTupleTy = getLinearMapTupleType(predBB);
|
|
// Do not create entries for unreachable predecessors
|
|
if (!linearMapTupleTy)
|
|
continue;
|
|
|
|
auto canLinearMapTupleTy = linearMapTupleTy->getCanonicalType();
|
|
decl->setInterfaceType(
|
|
canLinearMapTupleTy->hasArchetype()
|
|
? canLinearMapTupleTy->mapTypeOutOfContext() : canLinearMapTupleTy);
|
|
}
|
|
// Create enum element and enum case declarations.
|
|
auto *paramList = ParameterList::create(astCtx, {decl});
|
|
auto bbId = "bb" + std::to_string(predBB->getDebugID());
|
|
auto *enumEltDecl = new (astCtx) EnumElementDecl(
|
|
/*IdentifierLoc*/ loc, DeclName(astCtx.getIdentifier(bbId)), paramList,
|
|
loc, /*RawValueExpr*/ nullptr, branchingTraceDecl);
|
|
enumEltDecl->setImplicit();
|
|
auto *enumCaseDecl = EnumCaseDecl::create(
|
|
/*CaseLoc*/ loc, {enumEltDecl}, branchingTraceDecl);
|
|
enumCaseDecl->setImplicit();
|
|
branchingTraceDecl->addMember(enumEltDecl);
|
|
branchingTraceDecl->addMember(enumCaseDecl);
|
|
// Record enum element declaration.
|
|
branchingTraceEnumCases.insert({{predBB, originalBB}, enumEltDecl});
|
|
}
|
|
}
|
|
|
|
|
|
Type LinearMapInfo::getLinearMapType(ADContext &context, FullApplySite fai) {
|
|
SmallVector<SILValue, 4> allResults;
|
|
SmallVector<unsigned, 8> activeParamIndices;
|
|
SmallVector<unsigned, 8> activeResultIndices;
|
|
collectMinimalIndicesForFunctionCall(fai, config, activityInfo, allResults,
|
|
activeParamIndices, activeResultIndices);
|
|
|
|
// Check if there are any active results or arguments. If not, skip
|
|
// this instruction.
|
|
auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) {
|
|
return activityInfo.isActive(res, config);
|
|
});
|
|
|
|
bool hasActiveSemanticResultArgument = false;
|
|
bool hasActiveArguments = false;
|
|
auto numIndirectResults = fai.getNumIndirectSILResults();
|
|
for (auto argIdx : range(fai.getSubstCalleeConv().getNumParameters())) {
|
|
auto arg = fai.getArgumentsWithoutIndirectResults()[argIdx];
|
|
if (activityInfo.isActive(arg, config)) {
|
|
hasActiveArguments = true;
|
|
auto paramInfo = fai.getSubstCalleeConv().getParamInfoForSILArg(
|
|
numIndirectResults + argIdx);
|
|
if (paramInfo.isAutoDiffSemanticResult())
|
|
hasActiveSemanticResultArgument = true;
|
|
}
|
|
}
|
|
if (!hasActiveArguments)
|
|
return {};
|
|
if (!hasActiveResults && !hasActiveSemanticResultArgument)
|
|
return {};
|
|
|
|
// Compute differentiability parameters.
|
|
// - If the callee has `@differentiable` function type, use differentiation
|
|
// parameters from the function type.
|
|
// - Otherwise, use the active parameters.
|
|
IndexSubset *parameters;
|
|
auto origFnSubstTy = fai.getSubstCalleeType();
|
|
auto remappedOrigFnSubstTy =
|
|
remapTypeInDerivative(SILType::getPrimitiveObjectType(origFnSubstTy))
|
|
.castTo<SILFunctionType>()
|
|
->getUnsubstitutedType(original->getModule());
|
|
if (remappedOrigFnSubstTy->isDifferentiable()) {
|
|
parameters = remappedOrigFnSubstTy->getDifferentiabilityParameterIndices();
|
|
} else {
|
|
parameters = IndexSubset::get(
|
|
original->getASTContext(),
|
|
fai.getArgumentsWithoutIndirectResults().size(), activeParamIndices);
|
|
}
|
|
// Compute differentiability results.
|
|
auto *results = IndexSubset::get(original->getASTContext(),
|
|
remappedOrigFnSubstTy->getNumAutoDiffSemanticResults(),
|
|
activeResultIndices);
|
|
// Create autodiff indices for the `apply` instruction.
|
|
AutoDiffConfig applyConfig(parameters, results);
|
|
|
|
// Check for non-differentiable original function type.
|
|
auto checkNondifferentiableOriginalFunctionType =
|
|
[&](CanSILFunctionType origFnTy) {
|
|
// Check non-differentiable arguments.
|
|
for (auto paramIndex : applyConfig.parameterIndices->getIndices()) {
|
|
auto remappedParamType =
|
|
origFnTy->getParameters()[paramIndex].getSILStorageInterfaceType();
|
|
if (!remappedParamType.isDifferentiable(derivative->getModule()))
|
|
return true;
|
|
}
|
|
// Check non-differentiable results.
|
|
unsigned firstSemanticParamResultIdx = origFnTy->getNumResults();
|
|
unsigned firstYieldResultIndex = origFnTy->getNumResults() +
|
|
origFnTy->getNumAutoDiffSemanticResultsParameters();
|
|
for (auto resultIndex : applyConfig.resultIndices->getIndices()) {
|
|
SILType remappedResultType;
|
|
if (resultIndex >= firstYieldResultIndex) {
|
|
auto yieldResultIdx = resultIndex - firstYieldResultIndex;
|
|
const auto& yield = origFnTy->getYields()[yieldResultIdx];
|
|
// We do not have a good way to differentiate direct yields
|
|
if (!yield.isAutoDiffSemanticResult())
|
|
return true;
|
|
remappedResultType = yield.getSILStorageInterfaceType();
|
|
} else if (resultIndex >= firstSemanticParamResultIdx) {
|
|
auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx;
|
|
auto semanticResultArg =
|
|
*std::next(fai.getAutoDiffSemanticResultArguments().begin(),
|
|
semanticResultArgIdx);
|
|
remappedResultType = semanticResultArg->getType();
|
|
} else {
|
|
remappedResultType =
|
|
origFnTy->getResults()[resultIndex].getSILStorageInterfaceType();
|
|
}
|
|
if (!remappedResultType.isDifferentiable(derivative->getModule()))
|
|
return true;
|
|
}
|
|
return false;
|
|
};
|
|
if (checkNondifferentiableOriginalFunctionType(remappedOrigFnSubstTy))
|
|
return nullptr;
|
|
|
|
AutoDiffDerivativeFunctionKind derivativeFnKind(kind);
|
|
auto derivativeFnType =
|
|
remappedOrigFnSubstTy
|
|
->getAutoDiffDerivativeFunctionType(
|
|
parameters, results, derivativeFnKind, context.getTypeConverter(),
|
|
LookUpConformanceInModule())
|
|
->getUnsubstitutedType(original->getModule());
|
|
|
|
auto linearMapSILType = derivativeFnType->getAllResultsInterfaceType();
|
|
if (auto tupleType = linearMapSILType.getAs<TupleType>()) {
|
|
linearMapSILType = SILType::getPrimitiveObjectType(
|
|
tupleType.getElementType(tupleType->getElements().size() - 1));
|
|
}
|
|
if (auto fnTy = linearMapSILType.getAs<SILFunctionType>()) {
|
|
linearMapSILType = SILType::getPrimitiveObjectType(
|
|
fnTy->getUnsubstitutedType(original->getModule()));
|
|
}
|
|
|
|
// IRGen requires decls to have AST types (not `SILFunctionType`), so we
|
|
// convert the `SILFunctionType` of the linear map to a `FunctionType` with
|
|
// the same parameters and results.
|
|
auto silFnTy = linearMapSILType.castTo<SILFunctionType>();
|
|
SmallVector<AnyFunctionType::Param, 8> params;
|
|
for (auto ¶m : silFnTy->getParameters()) {
|
|
ParameterTypeFlags flags;
|
|
if (param.isAutoDiffSemanticResult())
|
|
flags = flags.withInOut(true);
|
|
|
|
params.push_back(
|
|
AnyFunctionType::Param(param.getInterfaceType(), Identifier(), flags));
|
|
}
|
|
|
|
AnyFunctionType *astFnTy;
|
|
if (auto genSig = silFnTy->getSubstGenericSignature()) {
|
|
// FIXME: Verify ExtInfo state is correct, not working by accident.
|
|
GenericFunctionType::ExtInfo info;
|
|
astFnTy = GenericFunctionType::get(
|
|
genSig, params, silFnTy->getAllResultsInterfaceType().getASTType(),
|
|
info);
|
|
} else {
|
|
FunctionType::ExtInfo info;
|
|
astFnTy = FunctionType::get(
|
|
params, silFnTy->getAllResultsInterfaceType().getASTType(), info);
|
|
}
|
|
|
|
Type resultType = astFnTy->hasArchetype() ? astFnTy->mapTypeOutOfContext() : astFnTy;
|
|
if (fai.getKind() == FullApplySiteKind::TryApplyInst)
|
|
resultType = resultType->wrapInOptionalType();
|
|
|
|
return resultType;
|
|
}
|
|
|
|
void LinearMapInfo::generateDifferentiationDataStructures(
|
|
ADContext &context, SILFunction *derivativeFn) {
|
|
auto &astCtx = original->getASTContext();
|
|
// Get the derivative function generic signature.
|
|
CanGenericSignature derivativeFnGenSig = nullptr;
|
|
if (auto *derivativeFnGenEnv = derivativeFn->getGenericEnvironment())
|
|
derivativeFnGenSig =
|
|
derivativeFnGenEnv->getGenericSignature().getCanonicalSignature();
|
|
|
|
// Create branching trace enum for each original block and add it as a field
|
|
// in the corresponding struct.
|
|
StringRef traceEnumFieldName;
|
|
switch (kind) {
|
|
case AutoDiffLinearMapKind::Differential:
|
|
traceEnumFieldName = "successor";
|
|
break;
|
|
case AutoDiffLinearMapKind::Pullback:
|
|
traceEnumFieldName = "predecessor";
|
|
break;
|
|
}
|
|
|
|
for (auto &origBB : *original) {
|
|
auto *traceEnum =
|
|
createBranchingTraceDecl(&origBB, derivativeFnGenSig);
|
|
branchingTraceDecls.insert({&origBB, traceEnum});
|
|
}
|
|
|
|
// Add linear map fields to the linear map tuples.
|
|
//
|
|
// Now we need to be very careful as we're having a very subtle
|
|
// chicken-and-egg problem. We need lowered branch trace enum type for the
|
|
// linear map typle type. However branch trace enum type lowering depends on
|
|
// the lowering of its elements (at very least, the type classification of
|
|
// being trivial / non-trivial). As the lowering is cached we need to ensure
|
|
// we compute lowered type for the branch trace enum when the corresponding
|
|
// EnumDecl is fully complete: we cannot add more entries without causing some
|
|
// very subtle issues later on. However, the elements of the enum are linear
|
|
// map tuples of predecessors, that correspondingly may contain branch trace
|
|
// enums of corresponding predecessor BBs.
|
|
//
|
|
// Traverse all BBs in reverse post-order traversal order to ensure we process
|
|
// each BB before its predecessors.
|
|
llvm::ReversePostOrderTraversal<SILFunction *> RPOT(original);
|
|
for (auto Iter = RPOT.begin(), E = RPOT.end(); Iter != E; ++Iter) {
|
|
auto *origBB = *Iter;
|
|
SmallVector<TupleTypeElt, 4> linearTupleTypes;
|
|
if (!origBB->isEntry()) {
|
|
populateBranchingTraceDecl(origBB, loopInfo);
|
|
|
|
CanType traceEnumType = getBranchingTraceEnumLoweredType(origBB).getASTType();
|
|
linearTupleTypes.emplace_back(traceEnumType,
|
|
astCtx.getIdentifier(traceEnumFieldName));
|
|
}
|
|
|
|
if (isSemanticMemberAccessor(original)) {
|
|
// Do not add linear map fields for semantic member accessors, which have
|
|
// special-case pullback generation. Linear map tuples should be empty.
|
|
} else {
|
|
for (auto &inst : *origBB) {
|
|
if (auto *ai = dyn_cast<ApplyInst>(&inst))
|
|
// Skip array literal intrinsic applications since array literal
|
|
// initialization is linear and handled separately.
|
|
if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC) ||
|
|
ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC))
|
|
continue;
|
|
|
|
if (!isa<FullApplySite>(&inst))
|
|
continue;
|
|
|
|
FullApplySite fai(&inst);
|
|
// Add linear map field to struct for active apply sites instructions.
|
|
if (!shouldDifferentiateApplySite(fai))
|
|
continue;
|
|
|
|
LLVM_DEBUG(getADDebugStream()
|
|
<< "Adding linear map tuple field for " << inst);
|
|
if (Type linearMapType = getLinearMapType(context, fai)) {
|
|
LLVM_DEBUG(getADDebugStream() << "Computed type: " << linearMapType << '\n');
|
|
linearMapIndexMap.insert({fai, linearTupleTypes.size()});
|
|
linearTupleTypes.emplace_back(linearMapType);
|
|
}
|
|
}
|
|
}
|
|
|
|
linearMapTuples.insert({origBB, TupleType::get(linearTupleTypes, astCtx)});
|
|
}
|
|
|
|
// Print generated linear map structs and branching trace enums.
|
|
// These declarations do not show up with `-emit-sil` because they are
|
|
// implicit. Instead, use `-Xllvm -debug-only=differentiation` to test
|
|
// declarations with FileCheck.
|
|
LLVM_DEBUG({
|
|
auto &s = getADDebugStream();
|
|
PrintOptions printOptions;
|
|
printOptions.TypeDefinitions = true;
|
|
printOptions.ExplodePatternBindingDecls = true;
|
|
printOptions.SkipImplicit = false;
|
|
s << "Generated linear map tuples and branching trace enums for @"
|
|
<< original->getName() << ":\n";
|
|
for (auto &origBB : *original) {
|
|
auto *linearMapTuple = getLinearMapTupleType(&origBB);
|
|
linearMapTuple->print(s, printOptions);
|
|
s << '\n';
|
|
}
|
|
|
|
for (auto &origBB : *original) {
|
|
auto *traceEnum = getBranchingTraceDecl(&origBB);
|
|
traceEnum->print(s, printOptions);
|
|
s << '\n';
|
|
}
|
|
});
|
|
}
|
|
|
|
/// Returns a flag that indicates whether the `apply` instruction should be
|
|
/// differentiated, given the differentiation indices of the instruction's
|
|
/// parent function. Whether the `apply` should be differentiated is determined
|
|
/// sequentially from the following conditions:
|
|
/// 1. The instruction has an active `inout` argument.
|
|
/// 2. The instruction is a call to the array literal initialization intrinsic
|
|
/// ("array.uninitialized_intrinsic"), where the result is active and where
|
|
/// there is a `store` of an active value into the array's buffer.
|
|
/// 3. The instruction has both an active result (direct or indirect) and an
|
|
/// active argument.
|
|
bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) {
|
|
// Function applications with an active inout argument should be
|
|
// differentiated.
|
|
for (auto inoutArg : applySite.getInoutArguments())
|
|
if (activityInfo.isActive(inoutArg, config))
|
|
return true;
|
|
|
|
bool hasActiveDirectResults = false;
|
|
forEachApplyDirectResult(applySite, [&](SILValue directResult) {
|
|
hasActiveDirectResults |= activityInfo.isActive(directResult, config);
|
|
});
|
|
bool hasActiveIndirectResults =
|
|
llvm::any_of(applySite.getIndirectSILResults(), [&](SILValue result) {
|
|
return activityInfo.isActive(result, config);
|
|
});
|
|
bool hasActiveResults = hasActiveDirectResults || hasActiveIndirectResults;
|
|
|
|
// TODO: Pattern match to make sure there is at least one `store` to the
|
|
// array's active buffer.
|
|
if (ArraySemanticsCall(applySite.getInstruction(),
|
|
semantics::ARRAY_UNINITIALIZED_INTRINSIC) &&
|
|
hasActiveResults)
|
|
return true;
|
|
|
|
auto arguments = applySite.getArgumentsWithoutIndirectResults();
|
|
bool hasActiveArguments = llvm::any_of(arguments, [&](SILValue arg) {
|
|
return activityInfo.isActive(arg, config);
|
|
});
|
|
return hasActiveResults && hasActiveArguments;
|
|
}
|
|
|
|
static bool shouldDifferentiateInjectEnumAddr(
|
|
const InjectEnumAddrInst &inject,
|
|
const DifferentiableActivityInfo &activityInfo,
|
|
const AutoDiffConfig &config) {
|
|
SILValue en = inject.getOperand();
|
|
for (auto use : en->getUses()) {
|
|
auto *init = dyn_cast<InitEnumDataAddrInst>(use->getUser());
|
|
if (init && activityInfo.isActive(init, config))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Returns a flag indicating whether the instruction should be differentiated,
|
|
/// given the differentiation indices of the instruction's parent function.
|
|
/// Whether the instruction should be differentiated is determined sequentially
|
|
/// from any of the following conditions:
|
|
/// 1. The instruction is a full apply site and `shouldDifferentiateApplyInst`
|
|
/// returns true.
|
|
/// 2. The instruction has a source operand and a destination operand, both
|
|
/// being active.
|
|
/// 3. The instruction is an allocation instruction and has an active result.
|
|
/// 4. The instruction performs reference counting, lifetime ending, access
|
|
/// ending, or destroying on an active operand.
|
|
/// 5. The instruction creates an SSA copy of an active operand.
|
|
bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
|
|
// A full apply site with an active argument and an active result (direct or
|
|
// indirect) should be differentiated.
|
|
if (FullApplySite::isa(inst))
|
|
return shouldDifferentiateApplySite(FullApplySite(inst));
|
|
// Anything with an active result and an active operand should be
|
|
// differentiated.
|
|
auto hasActiveOperands =
|
|
llvm::any_of(inst->getAllOperands(), [&](Operand &op) {
|
|
return activityInfo.isActive(op.get(), config);
|
|
});
|
|
auto hasActiveResults = llvm::any_of(inst->getResults(), [&](SILValue val) {
|
|
return activityInfo.isActive(val, config);
|
|
});
|
|
if (hasActiveOperands && hasActiveResults)
|
|
return true;
|
|
// `store`-like instructions do not have an SSA result, but have two
|
|
// operands that represent the source and the destination. We treat them as
|
|
// the input and the output, respectively.
|
|
// For `store`-like instructions whose destination is an element address
|
|
// from an `array.uninitialized_intrinsic` application, return true if the
|
|
// intrinsic application (representing the semantic destination) is active.
|
|
#define CHECK_INST_TYPE_ACTIVE_DEST(INST) \
|
|
if (auto *castInst = dyn_cast<INST##Inst>(inst)) \
|
|
return activityInfo.isActive(castInst->getDest(), config);
|
|
CHECK_INST_TYPE_ACTIVE_DEST(Store)
|
|
CHECK_INST_TYPE_ACTIVE_DEST(StoreBorrow)
|
|
CHECK_INST_TYPE_ACTIVE_DEST(CopyAddr)
|
|
CHECK_INST_TYPE_ACTIVE_DEST(UnconditionalCheckedCastAddr)
|
|
#undef CHECK_INST_TYPE_ACTIVE_DEST
|
|
// Should differentiate any allocation instruction that has an active result.
|
|
if ((isa<AllocationInst>(inst) && hasActiveResults))
|
|
return true;
|
|
// Should differentiate end_apply if the corresponding begin_apply is
|
|
// differentiable
|
|
if (auto *eai = dyn_cast<EndApplyInst>(inst))
|
|
return shouldDifferentiateApplySite(eai->getBeginApply());
|
|
if (hasActiveOperands) {
|
|
// Should differentiate any instruction that performs reference counting,
|
|
// lifetime ending, access ending, or destroying on an active operand.
|
|
if (isa<RefCountingInst>(inst) || isa<EndAccessInst>(inst) ||
|
|
isa<EndBorrowInst>(inst) || isa<DeallocationInst>(inst) ||
|
|
isa<DestroyValueInst>(inst) || isa<DestroyAddrInst>(inst) ||
|
|
isa<AbortApplyInst>(inst) ||
|
|
isa<YieldInst>(inst))
|
|
return true;
|
|
}
|
|
|
|
// Should differentiate `inject_enum_addr` if the corresponding
|
|
// `init_enum_addr` has an active operand.
|
|
if (auto inject = dyn_cast<InjectEnumAddrInst>(inst))
|
|
if (shouldDifferentiateInjectEnumAddr(*inject, activityInfo, config))
|
|
return true;
|
|
|
|
return false;
|
|
}
|
|
|
|
} // end namespace autodiff
|
|
} // end namespace swift
|