mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Some fixes for coroutines with normal results and `partial_apply` of coroutines were required. Fixes #55084
553 lines
22 KiB
C++
553 lines
22 KiB
C++
//===--- Common.cpp - Automatic differentiation common utils --*- C++ -*---===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Automatic differentiation common utilities.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "swift/Basic/STLExtras.h"
|
|
#define DEBUG_TYPE "differentiation"
|
|
|
|
#include "swift/SIL/ApplySite.h"
|
|
#include "swift/SILOptimizer/Differentiation/Common.h"
|
|
#include "swift/AST/TypeCheckRequests.h"
|
|
#include "swift/Basic/Assertions.h"
|
|
#include "swift/SILOptimizer/Differentiation/ADContext.h"
|
|
|
|
namespace swift {
|
|
namespace autodiff {
|
|
|
|
/// Returns `llvm::dbgs()` after writing the "[AD] " tag to it, so that
/// automatic differentiation debug output is distinguishable from other debug
/// output.
raw_ostream &getADDebugStream() {
  raw_ostream &stream = llvm::dbgs();
  return stream << "[AD] ";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Given an address `v` that may point into the buffer of an array allocated
/// through the `array.uninitialized_intrinsic` semantic call, returns that
/// call; returns `nullptr` when `v` does not match the expected pattern.
///
/// Recognized pattern (the `index_addr` is optional):
///   index_addr (pointer_to_address (mark_dependence V, ...)), ...
/// where V is a result of a `destructure_tuple` whose operand is passed to
/// `ArraySemanticsCall` with `semantics::ARRAY_UNINITIALIZED_INTRINSIC`.
///
/// Any value along the chain may be a block argument rather than an
/// instruction result; such values have no defining instruction and simply do
/// not match.
ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) {
  // Find the `pointer_to_address` result, peering through `index_addr`.
  auto *ptai = dyn_cast<PointerToAddressInst>(v);
  if (auto *iai = dyn_cast<IndexAddrInst>(v))
    ptai = dyn_cast<PointerToAddressInst>(iai->getOperand(0));
  if (!ptai)
    return nullptr;
  // The pointer operand may be a block argument, whose defining instruction
  // is null; `dyn_cast_or_null` treats that as "no match" instead of
  // asserting on a null pointer.
  auto *mdi = dyn_cast_or_null<MarkDependenceInst>(
      ptai->getOperand()->getDefiningInstruction());
  if (!mdi)
    return nullptr;
  // Return the `array.uninitialized_intrinsic` application, if it exists.
  // The dependent value may likewise be a block argument.
  if (auto *dti = dyn_cast_or_null<DestructureTupleInst>(
          mdi->getValue()->getDefiningInstruction()))
    return ArraySemanticsCall(dti->getOperand(),
                              semantics::ARRAY_UNINITIALIZED_INTRINSIC);
  return nullptr;
}
|
|
|
|
/// Returns true if `original` is an accessor that semantically performs a
/// plain member access: a getter, setter or `_modify` accessor of a `var`
/// that is either a stored instance property or has an attached property
/// wrapper.
bool isSemanticMemberAccessor(SILFunction *original) {
  auto *declContext = original->getDeclContext();
  auto *accessor =
      declContext ? dyn_cast_or_null<AccessorDecl>(declContext->getAsDecl())
                  : nullptr;
  if (!accessor)
    return false;

  // Only getters, setters and `_modify` accessors are currently supported.
  switch (accessor->getAccessorKind()) {
  case AccessorKind::Get:
  case AccessorKind::Set:
  case AccessorKind::Modify:
    break;
  default:
    return false;
  }

  // The accessor has to belong to a `var` declaration.
  auto *var = dyn_cast<VarDecl>(accessor->getStorage());
  if (!var)
    return false;

  // Stored instance properties and property-wrapped properties qualify.
  // User-defined accessors never do: their custom logic need not perform a
  // member access at all.
  return (var->hasStorage() && var->isInstanceMember()) ||
         var->hasAttachedPropertyWrapper();
}
|
|
|
|
/// Returns true if the callee of `applySite` is a function reference to a
/// known function that `isSemanticMemberAccessor` accepts.
bool hasSemanticMemberAccessorCallee(ApplySite applySite) {
  auto *funcRef = dyn_cast<FunctionRefBaseInst>(applySite.getCallee());
  if (!funcRef)
    return false;
  auto *callee = funcRef->getReferencedFunctionOrNull();
  return callee && isSemanticMemberAccessor(callee);
}
|
|
|
|
/// Invokes `resultCallback` on each direct result of `applySite`.
///
/// - `apply`: a non-tuple result is reported as-is. For a tuple-typed result,
///   the elements are reported only if the apply has a single
///   `destructure_tuple` user; otherwise nothing is reported.
/// - `begin_apply`: every result of the instruction is reported.
/// - `try_apply`: every argument of every successor block is reported.
void forEachApplyDirectResult(
    FullApplySite applySite,
    llvm::function_ref<void(SILValue)> resultCallback) {
  auto *inst = applySite.getInstruction();
  switch (applySite.getKind()) {
  case FullApplySiteKind::ApplyInst: {
    auto *apply = cast<ApplyInst>(inst);
    if (apply->getType().is<TupleType>()) {
      // Tuple elements are only observable through a destructure.
      if (auto *destructure =
              apply->getSingleUserOfType<DestructureTupleInst>()) {
        for (auto element : destructure->getResults())
          resultCallback(element);
      }
    } else {
      resultCallback(apply);
    }
    return;
  }
  case FullApplySiteKind::BeginApplyInst: {
    for (auto result : cast<BeginApplyInst>(inst)->getResults())
      resultCallback(result);
    return;
  }
  case FullApplySiteKind::TryApplyInst: {
    auto *tryApply = cast<TryApplyInst>(inst);
    for (auto *successor : tryApply->getSuccessorBlocks())
      for (auto *blockArg : successor->getArguments())
        resultCallback(blockArg);
    return;
  }
  }
}
|
|
|
|
/// Collects all semantic results of `function` into `results`, in this order:
///  1. formal results (direct and indirect), interleaved as they appear in the
///     lowered function type;
///  2. arguments of parameters for which `isAutoDiffSemanticResult()` holds,
///     in parameter order;
///  3. the operands of the function's `yield` instruction, if it has one.
void collectAllFormalResultsInTypeOrder(SILFunction &function,
                                        SmallVectorImpl<SILValue> &results) {
  SILFunctionConventions convs(function.getLoweredFunctionType(),
                               function.getModule());

  // Direct results are either the elements of a returned `tuple` instruction
  // or the single returned value.
  auto *returnInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
  SILValue returned = returnInst->getOperand();
  SmallVector<SILValue, 8> directResults;
  auto *tuple = dyn_cast_or_null<TupleInst>(returned->getDefiningInstruction());
  if (tuple) {
    directResults.append(tuple->getElements().begin(),
                         tuple->getElements().end());
  } else {
    directResults.push_back(returned);
  }

  // Interleave direct and indirect results following the function type.
  auto indirectResults = function.getIndirectResults();
  unsigned nextDirect = 0;
  unsigned nextIndirect = 0;
  for (auto &resultInfo : convs.getResults()) {
    if (resultInfo.isFormalDirect())
      results.push_back(directResults[nextDirect++]);
    else
      results.push_back(indirectResults[nextIndirect++]);
  }

  // Semantic result parameters are treated as semantic results and appended
  // after the formal results.
  auto params = convs.getParameters();
  auto args = function.getArgumentsWithoutIndirectResults();
  for (unsigned i = 0, e = convs.getNumParameters(); i != e; ++i) {
    if (params[i].isAutoDiffSemanticResult())
      results.push_back(args[i]);
  }

  // Yields are semantic results as well. Only `@yield_once` coroutines with
  // simple control flow can be differentiated, so the function is assumed to
  // contain a single `yield`; the first one found is used.
  for (SILBasicBlock &block : function) {
    auto *yield = dyn_cast<YieldInst>(block.getTerminator());
    if (!yield)
      continue;
    for (auto yielded : yield->getOperandValues())
      results.push_back(yielded);
    break;
  }
}
|
|
|
|
/// Collects the direct results of `function` into `results` in type order.
/// If the returned value is a `tuple` instruction, its elements are used;
/// otherwise the returned value itself is the single direct result.
void collectAllDirectResultsInTypeOrder(SILFunction &function,
                                        SmallVectorImpl<SILValue> &results) {
  // Note: the function conventions are not needed here; the return value's
  // shape alone determines the direct results.
  auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
  auto retVal = retInst->getOperand();
  if (auto *tupleInst = dyn_cast<TupleInst>(retVal))
    results.append(tupleInst->getElements().begin(),
                   tupleInst->getElements().end());
  else
    results.push_back(retVal);
}
|
|
|
|
/// Collects the results of the apply site `fai` into `results` in the callee's
/// result type order. Direct results are taken in order from
/// `extractedDirectResults`, and indirect results come from the apply's
/// indirect result operands.
void collectAllActualResultsInTypeOrder(
    FullApplySite fai, ArrayRef<SILValue> extractedDirectResults,
    SmallVectorImpl<SILValue> &results) {
  auto calleeConvs = fai.getSubstCalleeConv();
  auto indirectResults = fai.getIndirectSILResults();
  unsigned nextDirect = 0;
  unsigned nextIndirect = 0;
  for (auto &resultInfo : calleeConvs.getResults()) {
    if (resultInfo.isFormalDirect())
      results.push_back(extractedDirectResults[nextDirect++]);
    else
      results.push_back(indirectResults[nextIndirect++]);
  }
}
|
|
|
|
/// For the call `ai` inside a function being differentiated with
/// `parentConfig`, collects all semantic results of the call and the indices
/// of the call's active parameters and results.
///
/// \param ai the call (`apply`, `try_apply` or `begin_apply`).
/// \param parentConfig the caller's differentiation configuration.
/// \param activityInfo activity information for the caller.
/// \param results [out] all semantic results of the call in callee type
///        order: formal results, then arguments of semantic result
///        parameters, then yields (for `begin_apply`).
/// \param paramIndices [out] positions, among the call's arguments excluding
///        indirect results, of the arguments that are active.
/// \param resultIndices [out] semantic result indices of the results that are
///        recorded as active (yields are always recorded).
void collectMinimalIndicesForFunctionCall(
    FullApplySite ai, const AutoDiffConfig &parentConfig,
    const DifferentiableActivityInfo &activityInfo,
    SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
    SmallVectorImpl<unsigned> &resultIndices) {
  auto calleeFnTy = ai.getSubstCalleeType();
  auto calleeConvs = ai.getSubstCalleeConv();

  // Parameter indices are indices (in the callee type signature) of parameter
  // arguments that are active in the parent configuration.
  // Record all parameter indices in type order.
  unsigned currentParamIdx = 0;
  for (auto applyArg : ai.getArgumentsWithoutIndirectResults()) {
    if (activityInfo.isActive(applyArg, parentConfig))
      paramIndices.push_back(currentParamIdx);
    ++currentParamIdx;
  }

  // Result indices are indices (in the callee type signature) of results that
  // are active.
  // NOTE(review): `forEachApplyDirectResult` reports nothing for a tuple-typed
  // `apply` without a single `destructure_tuple` user, so `directResults` may
  // then be shorter than the number of formal direct results -- confirm that
  // cannot happen for calls reaching this function.
  SmallVector<SILValue, 8> directResults;
  forEachApplyDirectResult(ai, [&](SILValue directResult) {
    directResults.push_back(directResult);
  });
  auto indirectResults = ai.getIndirectSILResults();
  // Record all results and result indices in type order.
  results.reserve(calleeFnTy->getNumResults());
  unsigned dirResIdx = 0;
  // NOTE(review): `indirectResults` contains only the indirect result
  // operands, so starting at the SIL argument index of the first indirect
  // result is only correct if that index is 0 -- confirm.
  unsigned indResIdx = calleeConvs.getSILArgIndexOfFirstIndirectResult();
  for (const auto &resAndIdx : enumerate(calleeConvs.getResults())) {
    const auto &res = resAndIdx.value();
    unsigned idx = resAndIdx.index();
    if (res.isFormalDirect()) {
      SILValue dirRes = directResults[dirResIdx++];
      results.push_back(dirRes);
      // A null direct result is recorded but never considered active.
      if (dirRes && activityInfo.isActive(dirRes, parentConfig))
        resultIndices.push_back(idx);
    } else {
      SILValue indRes = indirectResults[indResIdx++];
      results.push_back(indRes);
      if (activityInfo.isActive(indRes, parentConfig))
        resultIndices.push_back(idx);
    }
  }

  // Record all semantic result parameters as results. Their result indices
  // continue after the formal results.
  auto semanticResultParamResultIndex = calleeFnTy->getNumResults();
  for (const auto &paramAndIdx : enumerate(calleeConvs.getParameters())) {
    const auto &param = paramAndIdx.value();
    if (!param.isAutoDiffSemanticResult())
      continue;
    // Offset by the indirect formal results, which precede the parameters in
    // the call's argument list.
    unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults();
    results.push_back(ai.getArgument(idx));
    resultIndices.push_back(semanticResultParamResultIndex++);
  }

  // Record all yields. While we do not have a way to represent direct yields
  // (_read accessors) we run activity analysis for them. These will be
  // diagnosed later. Yields are recorded as results regardless of activity.
  if (BeginApplyInst *bai = dyn_cast<BeginApplyInst>(*ai)) {
    for (const auto &yieldAndIdx : enumerate(calleeConvs.getYields())) {
      results.push_back(bai->getYieldedValues()[yieldAndIdx.index()]);
      resultIndices.push_back(semanticResultParamResultIndex++);
    }
  }

  // Make sure the function call has active results.
#ifndef NDEBUG
  assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults());
  assert(llvm::any_of(results, [&](SILValue result) {
    return activityInfo.isActive(result, parentConfig);
  }));
#endif
}
|
|
|
|
std::optional<std::pair<SILDebugLocation, SILDebugVariable>>
|
|
findDebugLocationAndVariable(SILValue originalValue) {
|
|
if (auto *asi = dyn_cast<AllocStackInst>(originalValue))
|
|
return swift::transform(asi->getVarInfo(false), [&](SILDebugVariable var) {
|
|
return std::make_pair(asi->getDebugLocation(), var);
|
|
});
|
|
for (auto *use : originalValue->getUses()) {
|
|
if (auto *dvi = dyn_cast<DebugValueInst>(use->getUser()))
|
|
return swift::transform(dvi->getVarInfo(false), [&](SILDebugVariable var) {
|
|
// We need to drop `op_deref` here as we're transferring debug info
|
|
// location from debug_value instruction (which describes how to get value)
|
|
// into alloc_stack (which describes the location)
|
|
if (var.DIExpr.startsWithDeref())
|
|
var.DIExpr.eraseElement(var.DIExpr.element_begin());
|
|
return std::make_pair(dvi->getDebugLocation(), var);
|
|
});
|
|
}
|
|
return std::nullopt;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Diagnostic utilities
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the location of `v` if it is non-null and has a valid source
/// location; otherwise returns the location of the function containing `v`.
SILLocation getValidLocation(SILValue v) {
  auto loc = v.getLoc();
  if (!loc.isNull() && loc.getSourceLoc().isValid())
    return loc;
  return v->getFunction()->getLocation();
}
|
|
|
|
/// Returns the location of `inst` if it is non-null and has a valid source
/// location; otherwise returns the location of the function containing
/// `inst`.
SILLocation getValidLocation(SILInstruction *inst) {
  auto loc = inst->getLoc();
  if (!loc.isNull() && loc.getSourceLoc().isValid())
    return loc;
  return inst->getFunction()->getLocation();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Tangent property lookup utilities
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the tangent stored property corresponding to `originalField` of
/// `baseType`, as computed by `TangentStoredPropertyRequest`. On failure, emits
/// a non-differentiability diagnostic at `loc` (attributed to `invoker`)
/// describing the specific reason, and returns `nullptr`.
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
                                  CanType baseType, SILLocation loc,
                                  DifferentiationInvoker invoker) {
  auto &ctx = context.getASTContext();
  auto info = evaluateOrDefault(
      ctx.evaluator, TangentStoredPropertyRequest{originalField, baseType},
      TangentPropertyInfo(nullptr));
  // Success: the request produced a tangent property.
  if (info)
    return info.tangentProperty;

  // Failure: diagnose according to the recorded error kind.
  assert(info.error);
  auto *parentDC = originalField->getDeclContext();
  assert(parentDC->isTypeContext());
  auto parentName = parentDC->getSelfNominalTypeDecl()->getNameStr();
  auto fieldName = originalField->getNameStr();
  auto sourceLoc = loc.getSourceLoc();

  using ErrorKind = TangentPropertyInfo::Error::Kind;
  switch (info.error->kind) {
  case ErrorKind::NoDerivativeOriginalProperty:
    llvm_unreachable(
        "`@noDerivative` stored property accesses should not be "
        "differentiated; activity analysis should not mark as varied");
  case ErrorKind::NominalParentNotDifferentiable:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker,
        diag::autodiff_stored_property_parent_not_differentiable, parentName,
        fieldName);
    break;
  case ErrorKind::OriginalPropertyNotDifferentiable:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_stored_property_not_differentiable,
        parentName, fieldName, originalField->getInterfaceType());
    break;
  case ErrorKind::ParentTangentVectorNotStruct:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_stored_property_tangent_not_struct,
        parentName, fieldName);
    break;
  case ErrorKind::TangentPropertyNotFound:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker,
        diag::autodiff_stored_property_no_corresponding_tangent, parentName,
        fieldName);
    break;
  case ErrorKind::TangentPropertyWrongType:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_tangent_property_wrong_type,
        parentName, fieldName, info.error->getType());
    break;
  case ErrorKind::TangentPropertyNotStored:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_tangent_property_not_stored,
        parentName, fieldName);
    break;
  }
  return nullptr;
}
|
|
|
|
/// Returns the tangent stored property for the field projected by
/// `projectionInst`, which must be a `struct_extract`, `struct_element_addr`
/// or `ref_element_addr`. Failures are diagnosed at the projection's valid
/// location and yield `nullptr`.
VarDecl *getTangentStoredProperty(ADContext &context,
                                  SingleValueInstruction *projectionInst,
                                  CanType baseType,
                                  DifferentiationInvoker invoker) {
  assert(isa<StructExtractInst>(projectionInst) ||
         isa<StructElementAddrInst>(projectionInst) ||
         isa<RefElementAddrInst>(projectionInst));
  auto diagLoc = getValidLocation(projectionInst);
  auto projectedType = projectionInst->getOperand(0)->getType();
  auto *field = Projection(projectionInst).getVarDecl(projectedType);
  return getTangentStoredProperty(context, field, baseType, diagLoc, invoker);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Code emission utilities
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the only element of `elements` when there is exactly one;
/// otherwise emits a `tuple` of all elements at `loc` and returns it.
SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
                      SILLocation loc) {
  if (elements.size() != 1)
    return builder.createTuple(loc, elements);
  return elements.front();
}
|
|
|
|
/// Appends the elements of `value` to `results`. A non-tuple value is
/// appended unchanged. A tuple value is split with a single
/// `destructure_tuple` when the builder has ownership, and with one
/// `tuple_extract` per element otherwise.
void extractAllElements(SILValue value, SILBuilder &builder,
                        SmallVectorImpl<SILValue> &results) {
  auto tupleType = value->getType().getAs<TupleType>();
  if (!tupleType) {
    results.push_back(value);
    return;
  }
  auto loc = value.getLoc();
  if (!builder.hasOwnership()) {
    for (unsigned i = 0, e = tupleType->getNumElements(); i != e; ++i)
      results.push_back(builder.createTupleExtract(loc, value, i));
    return;
  }
  auto *destructure = builder.createDestructureTuple(loc, value);
  results.append(destructure->getResults().begin(),
                 destructure->getResults().end());
}
|
|
|
|
/// Emits a call to the `Sizeof` builtin for `type` at `loc` and returns its
/// `Builtin.Word` result.
SILValue emitMemoryLayoutSize(
    SILBuilder &builder, SILLocation loc, CanType type) {
  auto &astCtx = builder.getASTContext();
  auto sizeofName =
      astCtx.getIdentifier(getBuiltinName(BuiltinValueKind::Sizeof));
  auto *sizeofDecl = cast<FuncDecl>(getBuiltinValueDecl(astCtx, sizeofName));
  auto subs = SubstitutionMap::get(sizeofDecl->getGenericSignature(),
                                   ArrayRef<Type>{type},
                                   LookUpConformanceInModule());
  // The builtin's only argument is a thin metatype of `type`.
  auto thinMetatypeTy = SILType::getPrimitiveObjectType(
      CanMetatypeType::get(type, MetatypeRepresentation::Thin));
  SILValue metatype = builder.createMetatype(loc, thinMetatypeTy);
  return builder.createBuiltin(loc, sizeofName,
                               SILType::getBuiltinWordType(astCtx), subs,
                               {metatype});
}
|
|
|
|
/// Emits the `AutoDiffProjectTopLevelSubcontext` builtin on `context`, which
/// must be a guaranteed `Builtin.NativeObject`, and converts the resulting
/// raw pointer into a strict address of `subcontextType`.
SILValue emitProjectTopLevelSubcontext(
    SILBuilder &builder, SILLocation loc, SILValue context,
    SILType subcontextType) {
  auto &astCtx = builder.getASTContext();
  assert(context->getOwnershipKind() == OwnershipKind::Guaranteed);
  assert(context->getType() == SILType::getNativeObjectType(astCtx));
  auto builtinName = astCtx.getIdentifier(
      getBuiltinName(BuiltinValueKind::AutoDiffProjectTopLevelSubcontext));
  SILValue rawPointer = builder.createBuiltin(
      loc, builtinName, SILType::getRawPointerType(astCtx), SubstitutionMap(),
      {context});
  return builder.createPointerToAddress(loc, rawPointer,
                                        subcontextType.getAddressType(),
                                        /*isStrict*/ true);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utilities for looking up derivatives of functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the AbstractFunctionDecl corresponding to `F`. If there isn't one,
|
|
/// returns `nullptr`.
|
|
static AbstractFunctionDecl *findAbstractFunctionDecl(SILFunction *F) {
|
|
auto *DC = F->getDeclContext();
|
|
if (!DC)
|
|
return nullptr;
|
|
auto *D = DC->getAsDecl();
|
|
if (!D)
|
|
return nullptr;
|
|
return dyn_cast<AbstractFunctionDecl>(D);
|
|
}
|
|
|
|
/// Returns the differentiability witness for `original` in `module` whose
/// parameter and result indices are exactly `parameterIndices` and
/// `resultIndices`, or `nullptr` if there is none.
SILDifferentiabilityWitness *
getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
                                 IndexSubset *parameterIndices,
                                 IndexSubset *resultIndices) {
  auto witnesses =
      module.lookUpDifferentiabilityWitnessesForFunction(original->getName());
  for (auto *witness : witnesses) {
    if (witness->getParameterIndices() != parameterIndices)
      continue;
    if (witness->getResultIndices() == resultIndices)
      return witness;
  }
  return nullptr;
}
|
|
|
|
std::optional<AutoDiffConfig>
|
|
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
|
|
IndexSubset *parameterIndices,
|
|
IndexSubset *&minimalASTParameterIndices) {
|
|
std::optional<AutoDiffConfig> minimalConfig = std::nullopt;
|
|
auto configs = original->getDerivativeFunctionConfigurations();
|
|
for (auto &config : configs) {
|
|
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
|
|
config.parameterIndices,
|
|
original->getInterfaceType()->castTo<AnyFunctionType>());
|
|
|
|
if (silParameterIndices->getCapacity() < parameterIndices->getCapacity()) {
|
|
silParameterIndices =
|
|
silParameterIndices->extendingCapacity(original->getASTContext(),
|
|
parameterIndices->getCapacity());
|
|
}
|
|
|
|
// If all indices in `parameterIndices` are in `daParameterIndices`, and
|
|
// it has fewer indices than our current candidate and a primitive VJP,
|
|
// then `attr` is our new candidate.
|
|
//
|
|
// NOTE(TF-642): `attr` may come from a un-partial-applied function and
|
|
// have larger capacity than the desired indices. We expect this logic to
|
|
// go away when `partial_apply` supports `@differentiable` callees.
|
|
if (silParameterIndices->isSupersetOf(parameterIndices->extendingCapacity(
|
|
original->getASTContext(), silParameterIndices->getCapacity())) &&
|
|
// fewer parameters than before
|
|
(!minimalConfig ||
|
|
silParameterIndices->getNumIndices() <
|
|
minimalConfig->parameterIndices->getNumIndices())) {
|
|
minimalASTParameterIndices = config.parameterIndices;
|
|
minimalConfig =
|
|
AutoDiffConfig(silParameterIndices, config.resultIndices,
|
|
autodiff::getDifferentiabilityWitnessGenericSignature(
|
|
original->getGenericSignature(),
|
|
config.derivativeGenericSignature));
|
|
}
|
|
}
|
|
return minimalConfig;
|
|
}
|
|
|
|
/// Returns the differentiability witness of `original` for the minimal
/// AST-declared derivative configuration covering `parameterIndices`. If
/// `module` has no such witness yet, creates an external declaration for it.
///
/// Returns `nullptr` in any of these cases:
/// - `original` does not come from an AST function;
/// - no registered derivative configuration covers `parameterIndices`;
/// - `original` requires a foreign entry point that is not present in
///   `module`.
///
/// Note: `resultIndices` is currently not consulted; the witness's result
/// indices come from the minimal configuration.
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
    SILModule &module, SILFunction *original, DifferentiabilityKind kind,
    IndexSubset *parameterIndices, IndexSubset *resultIndices) {
  // Explicit differentiability witnesses only exist on SIL functions that come
  // from AST functions.
  auto *originalAFD = findAbstractFunctionDecl(original);
  if (!originalAFD)
    return nullptr;

  IndexSubset *minimalASTParameterIndices = nullptr;
  auto minimalConfig = findMinimalDerivativeConfiguration(
      originalAFD, parameterIndices, minimalASTParameterIndices);
  if (!minimalConfig)
    return nullptr;

  std::string originalName = original->getName().str();
  // If original function requires a foreign entry point, use the foreign SIL
  // function to get or create the minimal differentiability witness.
  if (requiresForeignEntryPoint(originalAFD)) {
    SILDeclRef foreignRef = SILDeclRef(originalAFD).asForeign();
    originalName = foreignRef.mangle();
    original = module.lookUpFunction(foreignRef);
  }

  auto *existingWitness = module.lookUpDifferentiabilityWitness(
      {originalName, kind, *minimalConfig});
  if (existingWitness)
    return existingWitness;

  // The foreign entry point may be absent from this module; without a SIL
  // function there is nothing to attach a witness declaration to.
  if (!original)
    return nullptr;

  assert(original->isExternalDeclaration() &&
         "SILGen should create differentiability witnesses for all function "
         "definitions with explicit differentiable attributes");

  return SILDifferentiabilityWitness::createDeclaration(
      module,
      // Witness for @_alwaysEmitIntoClient original function must be emitted,
      // otherwise a linker error would occur due to undefined reference to the
      // witness symbol.
      original->markedAsAlwaysEmitIntoClient() ? SILLinkage::PublicNonABI
                                               : SILLinkage::PublicExternal,
      original, kind, minimalConfig->parameterIndices,
      minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
}
|
|
|
|
} // end namespace autodiff
|
|
} // end namespace swift
|