Files
swift-mirror/lib/SILOptimizer/Differentiation/Common.cpp
T
Doug Gregor f2eb7cb1a8 [SIL] Model @export(interface) and @export(implementation) on SIL functions
The `@export(interface)` and `@export(implementation)` attributes from
SE-0497 are queried directly on AST nodes in several places within the
SIL pipeline. However, they don't persist when SIL functions are
serialized, meaning that clients of the original module might make
different assumptions about the availability of a given function's
definition.

Represent these attributes in a SIL function (as an optional
CodeGenerationModel), (de-)serialize them into the module, and add a
textual representation as SIL function attributes `[export_interface]`
and `[export_implementation]`.
2026-04-15 13:04:10 -07:00

564 lines
22 KiB
C++

//===--- Common.cpp - Automatic differentiation common utils --*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Automatic differentiation common utilities.
//
//===----------------------------------------------------------------------===//
#include "swift/Basic/STLExtras.h"
#define DEBUG_TYPE "differentiation"
#include "swift/SIL/ApplySite.h"
#include "swift/SILOptimizer/Differentiation/Common.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/Basic/Assertions.h"
#include "swift/SILOptimizer/Differentiation/ADContext.h"
namespace swift {
namespace autodiff {
/// Writes the "[AD] " prefix to the LLVM debug stream and returns that stream
/// so the caller can append its own message.
raw_ostream &getADDebugStream() {
  raw_ostream &stream = llvm::dbgs();
  stream << "[AD] ";
  return stream;
}
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
/// Starting from `v`, repeatedly steps to operand 0 of the defining
/// instruction for as long as the current value is an `index_addr`,
/// `ref_tail_addr`, `unchecked_ref_cast`, `struct_extract` or `begin_borrow`.
/// Returns the first value that is none of those.
static SILValue getArrayValueOfElementAddress(SILValue v) {
  // Instruction kinds that are looked through on the way to the base value.
  auto isLookedThrough = [](ValueKind kind) -> bool {
    switch (kind) {
    case ValueKind::IndexAddrInst:
    case ValueKind::RefTailAddrInst:
    case ValueKind::UncheckedRefCastInst:
    case ValueKind::StructExtractInst:
    case ValueKind::BeginBorrowInst:
      return true;
    default:
      return false;
    }
  };

  while (isLookedThrough(v->getKind()))
    v = cast<SingleValueInstruction>(v)->getOperand(0);
  return v;
}
/// Looks from `v` through the instructions handled by
/// `getArrayValueOfElementAddress`; if the value reached is a result of a
/// `destructure_tuple`, returns the `array.uninitialized_intrinsic`
/// application feeding that instruction, if it exists. Returns `nullptr`
/// otherwise.
ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) {
  SILValue base = getArrayValueOfElementAddress(v);

  // The base must be one of the results of a multiple-value instruction...
  auto *multiResult = dyn_cast<MultipleValueInstructionResult>(base);
  if (!multiResult)
    return nullptr;

  // ...and that instruction must be a `destructure_tuple`.
  auto *destructure = dyn_cast<DestructureTupleInst>(multiResult->getParent());
  if (!destructure)
    return nullptr;

  return ArraySemanticsCall(destructure->getOperand(),
                            semantics::ARRAY_UNINITIALIZED_INTRINSIC);
}
/// Returns true if `original` is a getter, setter or `_modify` accessor of a
/// `var` that is either a stored instance property or has an attached
/// property wrapper.
bool isSemanticMemberAccessor(SILFunction *original) {
  auto *declContext = original->getDeclContext();
  if (!declContext)
    return false;

  // The function must be backed by an accessor declaration.
  auto *accessorDecl =
      dyn_cast_or_null<AccessorDecl>(declContext->getAsDecl());
  if (!accessorDecl)
    return false;

  // Only getters, setters and `_modify` accessors are supported for now.
  switch (accessorDecl->getAccessorKind()) {
  case AccessorKind::Get:
  case AccessorKind::Set:
  case AccessorKind::Modify:
    break;
  default:
    return false;
  }

  // The accessed storage has to be a `var` declaration.
  auto *property = dyn_cast<VarDecl>(accessorDecl->getStorage());
  if (!property)
    return false;

  // Accept stored instance properties and wrapped properties. Anything else
  // is a user-defined accessor, which can run custom logic that is not
  // semantically a member access.
  const bool isStoredInstanceProperty =
      property->hasStorage() && property->isInstanceMember();
  return isStoredInstanceProperty || property->hasAttachedPropertyWrapper();
}
/// Returns true if the callee of `applySite` is a function reference whose
/// referenced function is known and is a semantic member accessor.
bool hasSemanticMemberAccessorCallee(ApplySite applySite) {
  auto *calleeRef = dyn_cast<FunctionRefBaseInst>(applySite.getCallee());
  if (!calleeRef)
    return false;
  SILFunction *callee = calleeRef->getReferencedFunctionOrNull();
  return callee && isSemanticMemberAccessor(callee);
}
/// Invokes `resultCallback` on each direct result of `applySite`:
/// - `apply`: the instruction itself when its type is not a tuple; otherwise
///   the results of its single `destructure_tuple` user, if there is one.
/// - `begin_apply`: every value in the instruction's result list.
/// - `try_apply`: every argument of every successor block.
void forEachApplyDirectResult(
    FullApplySite applySite,
    llvm::function_ref<void(SILValue)> resultCallback) {
  switch (applySite.getKind()) {
  case FullApplySiteKind::ApplyInst: {
    auto *apply = cast<ApplyInst>(applySite.getInstruction());
    if (apply->getType().is<TupleType>()) {
      // Tuple-typed result: report the destructured elements, if any.
      auto *destructure = apply->getSingleUserOfType<DestructureTupleInst>();
      if (destructure)
        for (auto element : destructure->getResults())
          resultCallback(element);
    } else {
      resultCallback(apply);
    }
    return;
  }
  case FullApplySiteKind::BeginApplyInst: {
    auto *beginApply = cast<BeginApplyInst>(applySite.getInstruction());
    for (auto result : beginApply->getResults())
      resultCallback(result);
    return;
  }
  case FullApplySiteKind::TryApplyInst: {
    auto *tryApply = cast<TryApplyInst>(applySite.getInstruction());
    for (auto *successor : tryApply->getSuccessorBlocks())
      for (auto *blockArg : successor->getArguments())
        resultCallback(blockArg);
    return;
  }
  }
}
/// Appends the semantic results of `function` to `results`: first the formal
/// results in the order of the function's lowered type, then the parameters
/// flagged as autodiff semantic results, and finally the operands of the
/// first `yield` terminator found in the function, if any.
void collectAllFormalResultsInTypeOrder(SILFunction &function,
                                        SmallVectorImpl<SILValue> &results) {
  SILFunctionConventions conventions(function.getLoweredFunctionType(),
                                     function.getModule());
  auto indirectResults = function.getIndirectResults();
  auto *returnInst =
      cast<ReturnInst>(function.findReturnBB()->getTerminator());
  auto returnedValue = returnInst->getOperand();

  // The direct results are the elements of the returned `tuple` instruction
  // when there is one; otherwise the returned value is the only one.
  SmallVector<SILValue, 8> directResults;
  auto *returnedTuple =
      dyn_cast_or_null<TupleInst>(returnedValue->getDefiningInstruction());
  if (!returnedTuple) {
    directResults.push_back(returnedValue);
  } else {
    auto elements = returnedTuple->getElements();
    directResults.append(elements.begin(), elements.end());
  }

  // Interleave direct and indirect results following the formal result order.
  unsigned nextIndirect = 0, nextDirect = 0;
  for (auto &resultInfo : conventions.getResults()) {
    if (resultInfo.isFormalDirect())
      results.push_back(directResults[nextDirect++]);
    else
      results.push_back(indirectResults[nextIndirect++]);
  }

  // Semantic result parameters count as semantic results; append their
  // arguments after the formal results.
  for (unsigned paramIdx = 0, paramCount = conventions.getNumParameters();
       paramIdx != paramCount; ++paramIdx) {
    if (!conventions.getParameters()[paramIdx].isAutoDiffSemanticResult())
      continue;
    results.push_back(function.getArgumentsWithoutIndirectResults()[paramIdx]);
  }

  // Yields count as semantic results too. Only @yield_once functions with
  // simple control flow can be differentiated, so the function is assumed to
  // contain a single `yield`; the first one found is used.
  for (SILBasicBlock &block : function) {
    auto *yieldInst = dyn_cast<YieldInst>(block.getTerminator());
    if (!yieldInst)
      continue;
    for (auto yielded : yieldInst->getOperandValues())
      results.push_back(yielded);
    break;
  }
}
/// Appends the direct results of `function` to `results`, based on the
/// operand of the `return` terminator in the function's return block: if
/// that operand is a `tuple` instruction its elements are appended in order,
/// otherwise the operand itself is appended.
void collectAllDirectResultsInTypeOrder(SILFunction &function,
                                        SmallVectorImpl<SILValue> &results) {
  // Note: no `SILFunctionConventions` is needed here; the results are read
  // straight off the returned value.
  auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
  auto retVal = retInst->getOperand();
  if (auto *tupleInst = dyn_cast<TupleInst>(retVal))
    results.append(tupleInst->getElements().begin(),
                   tupleInst->getElements().end());
  else
    results.push_back(retVal);
}
/// Appends the results of the call `fai` to `results` in the order of the
/// callee's formal results. Direct results are taken, in order, from
/// `extractedDirectResults`; indirect results are taken, in order, from the
/// call's indirect result operands.
void collectAllActualResultsInTypeOrder(
    FullApplySite fai, ArrayRef<SILValue> extractedDirectResults,
    SmallVectorImpl<SILValue> &results) {
  auto conventions = fai.getSubstCalleeConv();
  unsigned nextIndirect = 0;
  unsigned nextDirect = 0;
  for (auto &resultInfo : conventions.getResults()) {
    if (resultInfo.isFormalDirect())
      results.push_back(extractedDirectResults[nextDirect++]);
    else
      results.push_back(fai.getIndirectSILResults()[nextIndirect++]);
  }
}
/// For the call `ai`, collects:
/// - `paramIndices`: positions (among the arguments without indirect results)
///   of arguments that `activityInfo` reports as active under `parentConfig`;
/// - `results`: the formal results in type order, followed by the arguments
///   for semantic result parameters, followed by the yielded values when the
///   call is a `begin_apply`;
/// - `resultIndices`: indices of the active formal results, plus the indices
///   of all semantic result parameters and yields.
void collectMinimalIndicesForFunctionCall(
    FullApplySite ai, const AutoDiffConfig &parentConfig,
    const DifferentiableActivityInfo &activityInfo,
    SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
    SmallVectorImpl<unsigned> &resultIndices) {
  auto calleeType = ai.getSubstCalleeType();
  auto calleeConventions = ai.getSubstCalleeConv();

  // Record, in type order, the index of every active argument.
  unsigned argPosition = 0;
  for (auto argument : ai.getArgumentsWithoutIndirectResults()) {
    if (activityInfo.isActive(argument, parentConfig))
      paramIndices.push_back(argPosition);
    ++argPosition;
  }

  // Gather the direct results of the call up front.
  SmallVector<SILValue, 8> directResults;
  forEachApplyDirectResult(ai, [&directResults](SILValue value) {
    directResults.push_back(value);
  });
  auto indirectResults = ai.getIndirectSILResults();

  // Record every formal result in type order, along with the index of each
  // one that is active.
  results.reserve(calleeType->getNumResults());
  unsigned nextDirect = 0;
  unsigned nextIndirect =
      calleeConventions.getSILArgIndexOfFirstIndirectResult();
  unsigned formalResultIdx = 0;
  for (const auto &resultInfo : calleeConventions.getResults()) {
    if (resultInfo.isFormalDirect()) {
      SILValue directResult = directResults[nextDirect];
      results.push_back(directResult);
      // A null direct result is recorded but never marked active.
      if (directResult && activityInfo.isActive(directResult, parentConfig))
        resultIndices.push_back(formalResultIdx);
      ++nextDirect;
    } else {
      results.push_back(indirectResults[nextIndirect]);
      if (activityInfo.isActive(indirectResults[nextIndirect], parentConfig))
        resultIndices.push_back(formalResultIdx);
      ++nextIndirect;
    }
    ++formalResultIdx;
  }

  // Semantic result parameters are all recorded as results; their result
  // indices continue after the formal results.
  auto nextSemanticResultIdx = calleeType->getNumResults();
  unsigned paramPosition = 0;
  for (const auto &paramInfo : calleeConventions.getParameters()) {
    const unsigned currentParam = paramPosition++;
    if (!paramInfo.isAutoDiffSemanticResult())
      continue;
    // Skip past the indirect formal results to reach the argument.
    unsigned argIdx = currentParam + calleeType->getNumIndirectFormalResults();
    results.push_back(ai.getArgument(argIdx));
    resultIndices.push_back(nextSemanticResultIdx++);
  }

  // Record all yields. There is no way to represent direct yields (_read
  // accessors) yet, but activity analysis still runs on them; they are
  // diagnosed later.
  if (BeginApplyInst *beginApply = dyn_cast<BeginApplyInst>(*ai)) {
    for (unsigned yieldIdx = 0,
                  yieldCount = calleeConventions.getYields().size();
         yieldIdx != yieldCount; ++yieldIdx) {
      results.push_back(beginApply->getYieldedValues()[yieldIdx]);
      resultIndices.push_back(nextSemanticResultIdx++);
    }
  }

  // Sanity checks: one entry per semantic result, and at least one of them
  // must be active.
#ifndef NDEBUG
  assert(results.size() == calleeType->getNumAutoDiffSemanticResults());
  assert(llvm::any_of(results, [&](SILValue result) {
    return activityInfo.isActive(result, parentConfig);
  }));
#endif
}
/// Finds a debug location and debug variable for `originalValue`.
///
/// If the value is an `alloc_stack`, its own variable info is used. Otherwise
/// the first `debug_value` user encountered decides the result. In both cases
/// `std::nullopt` is returned when that instruction has no variable info, and
/// also when no such instruction exists.
std::optional<std::pair<SILDebugLocation, SILDebugVariable>>
findDebugLocationAndVariable(SILValue originalValue) {
  if (auto *allocStack = dyn_cast<AllocStackInst>(originalValue)) {
    auto varInfo = allocStack->getVarInfo(false);
    if (!varInfo)
      return std::nullopt;
    SILDebugVariable variable = *varInfo;
    return std::make_pair(allocStack->getDebugLocation(), variable);
  }

  for (auto *use : originalValue->getUses()) {
    auto *debugValue = dyn_cast<DebugValueInst>(use->getUser());
    if (!debugValue)
      continue;

    auto varInfo = debugValue->getVarInfo(false);
    if (!varInfo)
      return std::nullopt;
    SILDebugVariable variable = *varInfo;
    // Drop a leading `op_deref`: the debug info moves from a debug_value
    // instruction (which describes how to get the value) to an alloc_stack
    // (which describes the location).
    if (variable.DIExpr.startsWithDeref())
      variable.DIExpr.eraseElement(variable.DIExpr.element_begin());
    return std::make_pair(debugValue->getDebugLocation(), variable);
  }

  return std::nullopt;
}
//===----------------------------------------------------------------------===//
// Diagnostic utilities
//===----------------------------------------------------------------------===//
/// Returns the location of `v`, unless it is null or has an invalid source
/// location, in which case the location of the enclosing function is
/// returned instead.
SILLocation getValidLocation(SILValue v) {
  SILLocation valueLoc = v.getLoc();
  const bool isUsable =
      !valueLoc.isNull() && !valueLoc.getSourceLoc().isInvalid();
  if (isUsable)
    return valueLoc;
  return v->getFunction()->getLocation();
}
/// Returns the location of `inst`, unless it is null or has an invalid source
/// location, in which case the location of the enclosing function is
/// returned instead.
SILLocation getValidLocation(SILInstruction *inst) {
  SILLocation instLoc = inst->getLoc();
  const bool isUsable =
      !instLoc.isNull() && !instLoc.getSourceLoc().isInvalid();
  if (isUsable)
    return instLoc;
  return inst->getFunction()->getLocation();
}
//===----------------------------------------------------------------------===//
// Tangent property lookup utilities
//===----------------------------------------------------------------------===//
/// Evaluates `TangentStoredPropertyRequest` for `originalField` in `baseType`
/// and returns the resulting tangent property. When the request reports an
/// error, emits the matching non-differentiability diagnostic at `loc` for
/// `invoker` and returns `nullptr`.
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
                                  CanType baseType, SILLocation loc,
                                  DifferentiationInvoker invoker) {
  auto &astContext = context.getASTContext();
  auto tangentInfo = evaluateOrDefault(
      astContext.evaluator,
      TangentStoredPropertyRequest{originalField, baseType},
      TangentPropertyInfo(nullptr));

  // Success: hand back the tangent property.
  if (tangentInfo)
    return tangentInfo.tangentProperty;

  // Failure: gather the names used by the diagnostics below.
  assert(tangentInfo.error);
  auto *fieldParent = originalField->getDeclContext();
  assert(fieldParent->isTypeContext());
  auto parentName = fieldParent->getSelfNominalTypeDecl()->getNameStr();
  auto propertyName = originalField->getNameStr();
  auto diagLoc = loc.getSourceLoc();

  switch (tangentInfo.error->kind) {
  case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
    llvm_unreachable(
        "`@noDerivative` stored property accesses should not be "
        "differentiated; activity analysis should not mark as varied");
  case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
    context.emitNondifferentiabilityError(
        diagLoc, invoker,
        diag::autodiff_stored_property_parent_not_differentiable, parentName,
        propertyName);
    break;
  case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
    context.emitNondifferentiabilityError(
        diagLoc, invoker, diag::autodiff_stored_property_not_differentiable,
        parentName, propertyName, originalField->getInterfaceType());
    break;
  case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
    context.emitNondifferentiabilityError(
        diagLoc, invoker, diag::autodiff_stored_property_tangent_not_struct,
        parentName, propertyName);
    break;
  case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
    context.emitNondifferentiabilityError(
        diagLoc, invoker,
        diag::autodiff_stored_property_no_corresponding_tangent, parentName,
        propertyName);
    break;
  case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
    context.emitNondifferentiabilityError(
        diagLoc, invoker, diag::autodiff_tangent_property_wrong_type,
        parentName, propertyName, tangentInfo.error->getType());
    break;
  case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
    context.emitNondifferentiabilityError(
        diagLoc, invoker, diag::autodiff_tangent_property_not_stored,
        parentName, propertyName);
    break;
  }
  return nullptr;
}
/// Convenience overload for a `struct_extract`, `struct_element_addr` or
/// `ref_element_addr` instruction: resolves the projected field from the
/// instruction and its operand type, then defers to the field-based overload
/// using a valid location for the instruction.
VarDecl *getTangentStoredProperty(ADContext &context,
                                  SingleValueInstruction *projectionInst,
                                  CanType baseType,
                                  DifferentiationInvoker invoker) {
  assert(isa<StructExtractInst>(projectionInst) ||
         isa<StructElementAddrInst>(projectionInst) ||
         isa<RefElementAddrInst>(projectionInst));
  Projection projection(projectionInst);
  auto diagLoc = getValidLocation(projectionInst);
  auto operandType = projectionInst->getOperand(0)->getType();
  auto *projectedField = projection.getVarDecl(operandType);
  return getTangentStoredProperty(context, projectedField, baseType, diagLoc,
                                  invoker);
}
//===----------------------------------------------------------------------===//
// Code emission utilities
//===----------------------------------------------------------------------===//
/// Returns the sole element when `elements` has exactly one entry; otherwise
/// builds and returns a `tuple` of all elements at `loc`.
SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
                      SILLocation loc) {
  if (elements.size() != 1)
    return builder.createTuple(loc, elements);
  return elements.front();
}
/// Appends the elements of `value` to `results`. A non-tuple value is
/// appended as is. A tuple is split with a single `destructure_tuple` when
/// the builder has ownership, and with one `tuple_extract` per element
/// otherwise.
void extractAllElements(SILValue value, SILBuilder &builder,
                        SmallVectorImpl<SILValue> &results) {
  auto tupleType = value->getType().getAs<TupleType>();
  if (!tupleType) {
    results.push_back(value);
    return;
  }

  if (!builder.hasOwnership()) {
    for (auto elementIdx : range(tupleType->getNumElements()))
      results.push_back(
          builder.createTupleExtract(value.getLoc(), value, elementIdx));
    return;
  }

  auto *destructure = builder.createDestructureTuple(value.getLoc(), value);
  auto elements = destructure->getResults();
  results.append(elements.begin(), elements.end());
}
/// Emits a thin metatype value for `type` followed by a call to the `Sizeof`
/// builtin on it, and returns the builtin's word-typed result.
SILValue emitMemoryLayoutSize(
    SILBuilder &builder, SILLocation loc, CanType type) {
  auto &astContext = builder.getASTContext();
  auto sizeofId =
      astContext.getIdentifier(getBuiltinName(BuiltinValueKind::Sizeof));
  auto *sizeofDecl = cast<FuncDecl>(getBuiltinValueDecl(astContext, sizeofId));

  // The builtin takes the thin metatype of the type being measured.
  auto thinMetatype = CanMetatypeType::get(type, MetatypeRepresentation::Thin);
  auto metatypeValue = builder.createMetatype(
      loc, SILType::getPrimitiveObjectType(thinMetatype));

  // Substitute `type` for the builtin's generic parameter.
  auto substitutions = SubstitutionMap::get(
      sizeofDecl->getGenericSignature(), ArrayRef<Type>{type},
      LookUpConformanceInModule());
  return builder.createBuiltin(loc, sizeofId,
                               SILType::getBuiltinWordType(astContext),
                               substitutions, {metatypeValue});
}
/// Emits a call to the `AutoDiffProjectTopLevelSubcontext` builtin on
/// `context` and converts the raw pointer it returns into a strict address of
/// `subcontextType`. `context` must be a guaranteed native object value.
SILValue emitProjectTopLevelSubcontext(
    SILBuilder &builder, SILLocation loc, SILValue context,
    SILType subcontextType) {
  assert(context->getOwnershipKind() == OwnershipKind::Guaranteed);
  auto &astContext = builder.getASTContext();
  auto builtinId = astContext.getIdentifier(
      getBuiltinName(BuiltinValueKind::AutoDiffProjectTopLevelSubcontext));
  assert(context->getType() == SILType::getNativeObjectType(astContext));

  auto *rawSubcontext =
      builder.createBuiltin(loc, builtinId,
                            SILType::getRawPointerType(astContext),
                            SubstitutionMap(), {context});
  auto addressType = subcontextType.getAddressType();
  return builder.createPointerToAddress(loc, rawSubcontext, addressType,
                                        /*isStrict*/ true);
}
//===----------------------------------------------------------------------===//
// Utilities for looking up derivatives of functions
//===----------------------------------------------------------------------===//
/// Looks up the AbstractFunctionDecl behind `F` via its declaration context.
/// Yields `nullptr` when `F` has no declaration context, when the context is
/// not a declaration, or when that declaration is not an
/// AbstractFunctionDecl.
static AbstractFunctionDecl *findAbstractFunctionDecl(SILFunction *F) {
  if (auto *declContext = F->getDeclContext())
    return dyn_cast_or_null<AbstractFunctionDecl>(declContext->getAsDecl());
  return nullptr;
}
/// Returns the differentiability witness registered in `module` under the
/// name of `original` whose parameter and result index subsets are exactly
/// `parameterIndices` and `resultIndices` (compared by pointer), or `nullptr`
/// if there is none.
SILDifferentiabilityWitness *
getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
                                 IndexSubset *parameterIndices,
                                 IndexSubset *resultIndices) {
  auto candidates =
      module.lookUpDifferentiabilityWitnessesForFunction(original->getName());
  for (auto *candidate : candidates) {
    if (candidate->getParameterIndices() != parameterIndices)
      continue;
    if (candidate->getResultIndices() != resultIndices)
      continue;
    return candidate;
  }
  return nullptr;
}
/// Searches the derivative function configurations of `original` for the one
/// whose lowered parameter indices are a superset of `parameterIndices` and
/// contain the fewest indices; on ties the earliest configuration wins.
///
/// On success, returns the configuration (with lowered parameter indices) and
/// stores the matching AST parameter indices in `minimalASTParameterIndices`.
/// Returns `std::nullopt`, leaving `minimalASTParameterIndices` untouched,
/// when no configuration qualifies.
std::optional<AutoDiffConfig>
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
                                   IndexSubset *parameterIndices,
                                   IndexSubset *&minimalASTParameterIndices) {
  std::optional<AutoDiffConfig> bestConfig = std::nullopt;
  auto candidates = original->getDerivativeFunctionConfigurations();
  for (auto &candidate : candidates) {
    auto *loweredIndices = autodiff::getLoweredParameterIndices(
        candidate.parameterIndices,
        original->getInterfaceType()->castTo<AnyFunctionType>());
    // Widen the candidate's index set if the requested one is larger.
    if (loweredIndices->getCapacity() < parameterIndices->getCapacity()) {
      loweredIndices = loweredIndices->extendingCapacity(
          original->getASTContext(), parameterIndices->getCapacity());
    }

    // NOTE(TF-642): the candidate may come from a un-partial-applied function
    // and have larger capacity than the desired indices, so the requested
    // indices are widened to match before comparing. We expect this logic to
    // go away when `partial_apply` supports `@differentiable` callees.
    auto *requestedIndices = parameterIndices->extendingCapacity(
        original->getASTContext(), loweredIndices->getCapacity());

    // The candidate must cover every requested index...
    if (!loweredIndices->isSupersetOf(requestedIndices))
      continue;
    // ...and must have strictly fewer indices than the current best, if any.
    if (bestConfig && loweredIndices->getNumIndices() >=
                          bestConfig->parameterIndices->getNumIndices())
      continue;

    minimalASTParameterIndices = candidate.parameterIndices;
    bestConfig =
        AutoDiffConfig(loweredIndices, candidate.resultIndices,
                       autodiff::getDifferentiabilityWitnessGenericSignature(
                           original->getGenericSignature(),
                           candidate.derivativeGenericSignature));
  }
  return bestConfig;
}
/// Returns the differentiability witness for `original` that matches the
/// minimal derivative configuration covering `parameterIndices`, creating a
/// witness declaration when the module does not contain one yet. Returns
/// `nullptr` when `original` has no AbstractFunctionDecl or no derivative
/// configuration covers `parameterIndices`.
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
    SILModule &module, SILFunction *original, DifferentiabilityKind kind,
    IndexSubset *parameterIndices, IndexSubset *resultIndices) {
  // Explicit differentiability witnesses only exist on SIL functions that come
  // from AST functions.
  auto *functionDecl = findAbstractFunctionDecl(original);
  if (!functionDecl)
    return nullptr;

  IndexSubset *minimalASTParameterIndices = nullptr;
  auto config = findMinimalDerivativeConfiguration(
      functionDecl, parameterIndices, minimalASTParameterIndices);
  if (!config)
    return nullptr;

  // When the original function requires a foreign entry point, both the name
  // and the SIL function used for the witness come from the foreign variant.
  std::string witnessedName = original->getName().str();
  if (requiresForeignEntryPoint(functionDecl)) {
    SILDeclRef foreignRef = SILDeclRef(functionDecl).asForeign();
    witnessedName = foreignRef.mangle();
    // NOTE(review): the lookup result is dereferenced below without a null
    // check; confirm the foreign function always exists at this point.
    original = module.lookUpFunction(foreignRef);
  }

  if (auto *found = module.lookUpDifferentiabilityWitness(
          {witnessedName, kind, *config}))
    return found;

  assert(original->isExternalDeclaration() &&
         "SILGen should create differentiability witnesses for all function "
         "definitions with explicit differentiable attributes");

  // Witness for @_alwaysEmitIntoClient original function must be emitted,
  // otherwise a linker error would occur due to undefined reference to the
  // witness symbol.
  SILLinkage witnessLinkage = original->isAlwaysEmitIntoClient()
                                  ? SILLinkage::PublicNonABI
                                  : SILLinkage::PublicExternal;
  return SILDifferentiabilityWitness::createDeclaration(
      module, witnessLinkage, original, kind, config->parameterIndices,
      config->resultIndices, config->derivativeGenericSignature);
}
} // end namespace autodiff
} // end namespace swift