Files
swift-mirror/lib/SILOptimizer/Differentiation/Common.cpp
Anton Korobeynikov 7668666ad2 Support differentiation of wrapped value modify accessors (#78794)
Some fixes for coroutines with normal results and `partial_apply` of coroutines were required.

Fixes #55084
2025-09-10 02:30:26 -07:00

553 lines
22 KiB
C++

//===--- Common.cpp - Automatic differentiation common utils --*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Automatic differentiation common utilities.
//
//===----------------------------------------------------------------------===//
#include "swift/Basic/STLExtras.h"
#define DEBUG_TYPE "differentiation"
#include "swift/SIL/ApplySite.h"
#include "swift/SILOptimizer/Differentiation/Common.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/Basic/Assertions.h"
#include "swift/SILOptimizer/Differentiation/ADContext.h"
namespace swift {
namespace autodiff {
/// Returns the LLVM debug stream after writing the "[AD] " prefix to it, so
/// that differentiation debug output is tagged.
raw_ostream &getADDebugStream() {
  auto &stream = llvm::dbgs();
  stream << "[AD] ";
  return stream;
}
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
/// Returns the `array.uninitialized_intrinsic` application that `v` is an
/// element address of, or `nullptr` if `v` does not match the expected shape.
///
/// The matched chain, walking from `v` towards its definition, is:
///   `pointer_to_address` (optionally behind one `index_addr`)
///     -> `mark_dependence` -> `destructure_tuple` -> intrinsic application.
ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) {
  // Find the `pointer_to_address` result, peering through `index_addr`.
  auto *ptai = dyn_cast<PointerToAddressInst>(v);
  if (auto *iai = dyn_cast<IndexAddrInst>(v))
    ptai = dyn_cast<PointerToAddressInst>(iai->getOperand(0));
  if (!ptai)
    return nullptr;
  // `getDefiningInstruction()` may return null when the value is not produced
  // by an instruction. Plain `dyn_cast` asserts on null, so use the
  // null-tolerant form and treat such values as "no match".
  auto *mdi = dyn_cast_or_null<MarkDependenceInst>(
      ptai->getOperand()->getDefiningInstruction());
  if (!mdi)
    return nullptr;
  // Return the `array.uninitialized_intrinsic` application, if it exists.
  auto *dti = dyn_cast_or_null<DestructureTupleInst>(
      mdi->getValue()->getDefiningInstruction());
  if (!dti)
    return nullptr;
  return ArraySemanticsCall(dti->getOperand(),
                            semantics::ARRAY_UNINITIALIZED_INTRINSIC);
}
/// Returns true if `original` is a getter, setter or `_modify` accessor of a
/// `var` declaration that is either a stored instance property or a property
/// with an attached property wrapper. Returns false in every other case,
/// including when `original` has no declaration context.
bool isSemanticMemberAccessor(SILFunction *original) {
  auto *declContext = original->getDeclContext();
  if (!declContext)
    return false;

  // A missing decl and a non-accessor decl are both rejected here.
  auto *accessorDecl = dyn_cast_or_null<AccessorDecl>(declContext->getAsDecl());
  if (!accessorDecl)
    return false;

  // Only getters, setters and `_modify` accessors are currently supported.
  const auto kind = accessorDecl->getAccessorKind();
  const bool isSupportedKind = kind == AccessorKind::Get ||
                               kind == AccessorKind::Set ||
                               kind == AccessorKind::Modify;
  if (!isSupportedKind)
    return false;

  // The accessed storage must be a `var` declaration.
  auto *storageVar = dyn_cast<VarDecl>(accessorDecl->getStorage());
  if (!storageVar)
    return false;

  // Stored instance properties qualify.
  if (storageVar->hasStorage() && storageVar->isInstanceMember())
    return true;

  // Otherwise only wrapped properties qualify. User-defined accessors can
  // never be supported because they may run custom logic that does not
  // semantically perform a member access.
  return storageVar->hasAttachedPropertyWrapper();
}
/// Returns true if the callee of `applySite` is a function reference whose
/// referenced function is known and is a semantic member accessor.
bool hasSemanticMemberAccessorCallee(ApplySite applySite) {
  auto *calleeRef = dyn_cast<FunctionRefBaseInst>(applySite.getCallee());
  if (!calleeRef)
    return false;
  auto *calleeFn = calleeRef->getReferencedFunctionOrNull();
  return calleeFn && isSemanticMemberAccessor(calleeFn);
}
/// Invokes `resultCallback` on every direct result of `applySite`:
/// - `apply`: the instruction itself when its type is not a tuple; otherwise
///   each result of its single `destructure_tuple` user (nothing is visited
///   if there is no such single user).
/// - `begin_apply`: each result of the instruction.
/// - `try_apply`: each argument of each successor block.
void forEachApplyDirectResult(
    FullApplySite applySite,
    llvm::function_ref<void(SILValue)> resultCallback) {
  auto *inst = applySite.getInstruction();
  switch (applySite.getKind()) {
  case FullApplySiteKind::ApplyInst: {
    auto *apply = cast<ApplyInst>(inst);
    if (apply->getType().is<TupleType>()) {
      // Tuple-typed applies expose their elements through `destructure_tuple`.
      auto *destructure = apply->getSingleUserOfType<DestructureTupleInst>();
      if (destructure)
        for (auto element : destructure->getResults())
          resultCallback(element);
    } else {
      resultCallback(apply);
    }
    return;
  }
  case FullApplySiteKind::BeginApplyInst: {
    auto *beginApply = cast<BeginApplyInst>(inst);
    for (auto result : beginApply->getResults())
      resultCallback(result);
    return;
  }
  case FullApplySiteKind::TryApplyInst: {
    auto *tryApply = cast<TryApplyInst>(inst);
    for (auto *successor : tryApply->getSuccessorBlocks())
      for (auto *blockArg : successor->getArguments())
        resultCallback(blockArg);
    return;
  }
  }
}
/// Appends to `results`, in this order:
///   1. the formal results of `function` in type order (direct results come
///      from the operand of the `return`, indirect results from the
///      function's indirect result arguments);
///   2. the arguments for parameters that are autodiff semantic results;
///   3. the operands of the first `yield` terminator found, if any.
void collectAllFormalResultsInTypeOrder(SILFunction &function,
                                        SmallVectorImpl<SILValue> &results) {
  SILFunctionConventions convs(function.getLoweredFunctionType(),
                               function.getModule());
  auto indirectResults = function.getIndirectResults();
  auto *returnInst =
      cast<ReturnInst>(function.findReturnBB()->getTerminator());
  auto returnedValue = returnInst->getOperand();

  // A returned `tuple` contributes one direct result per element; any other
  // returned value is a single direct result.
  SmallVector<SILValue, 8> directResults;
  auto *returnedTuple =
      dyn_cast_or_null<TupleInst>(returnedValue->getDefiningInstruction());
  if (returnedTuple)
    directResults.append(returnedTuple->getElements().begin(),
                         returnedTuple->getElements().end());
  else
    directResults.push_back(returnedValue);

  // Interleave direct and indirect results following the result type order.
  unsigned nextIndirect = 0, nextDirect = 0;
  for (auto &resultInfo : convs.getResults()) {
    if (resultInfo.isFormalDirect())
      results.push_back(directResults[nextDirect++]);
    else
      results.push_back(indirectResults[nextIndirect++]);
  }

  // Treat semantic result parameters as semantic results and append their
  // arguments after the formal results.
  for (unsigned paramIdx = 0, numParams = convs.getNumParameters();
       paramIdx != numParams; ++paramIdx) {
    auto paramInfo = convs.getParameters()[paramIdx];
    if (paramInfo.isAutoDiffSemanticResult()) {
      auto *argument = function.getArgumentsWithoutIndirectResults()[paramIdx];
      results.push_back(argument);
    }
  }

  // Treat yields as semantic results. Only @yield_once coroutines with simple
  // control flow can be differentiated, so the function is assumed to contain
  // a single `yield` instruction; the first one found is used.
  for (SILBasicBlock &block : function) {
    auto *yieldInst = dyn_cast<YieldInst>(block.getTerminator());
    if (!yieldInst)
      continue;
    for (auto yieldedValue : yieldInst->getOperandValues())
      results.push_back(yieldedValue);
    break;
  }
}
/// Appends the direct results of `function` to `results`.
///
/// If the value returned by the function's `return` instruction is a `tuple`
/// instruction, each tuple element is appended; otherwise the returned value
/// itself is appended as the single direct result.
void collectAllDirectResultsInTypeOrder(SILFunction &function,
                                        SmallVectorImpl<SILValue> &results) {
  auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
  auto retVal = retInst->getOperand();
  if (auto *tupleInst = dyn_cast<TupleInst>(retVal))
    results.append(tupleInst->getElements().begin(),
                   tupleInst->getElements().end());
  else
    results.push_back(retVal);
}
/// Appends the results of the call `fai` to `results`, following the order of
/// the substituted callee's result types. Direct results are taken in sequence
/// from `extractedDirectResults`; indirect results are taken in sequence from
/// the indirect result arguments of `fai`.
void collectAllActualResultsInTypeOrder(
    FullApplySite fai, ArrayRef<SILValue> extractedDirectResults,
    SmallVectorImpl<SILValue> &results) {
  auto calleeConvs = fai.getSubstCalleeConv();
  unsigned nextIndirect = 0;
  unsigned nextDirect = 0;
  for (auto &resultInfo : calleeConvs.getResults()) {
    if (resultInfo.isFormalDirect())
      results.push_back(extractedDirectResults[nextDirect++]);
    else
      results.push_back(fai.getIndirectSILResults()[nextIndirect++]);
  }
}
/// Collects, for the call `ai`, all semantic results and the parameter/result
/// indices that activity analysis reports as active under `parentConfig`.
///
/// - `paramIndices` receives the positions (counted over the arguments
///   without indirect results) of active arguments.
/// - `results` receives every formal result in type order, followed by the
///   arguments of semantic result parameters, followed by yielded values when
///   `ai` is a `begin_apply`.
/// - `resultIndices` receives the indices of active formal results, plus an
///   index for every semantic result parameter and every yield.
void collectMinimalIndicesForFunctionCall(
    FullApplySite ai, const AutoDiffConfig &parentConfig,
    const DifferentiableActivityInfo &activityInfo,
    SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
    SmallVectorImpl<unsigned> &resultIndices) {
  auto calleeFnTy = ai.getSubstCalleeType();
  auto calleeConvs = ai.getSubstCalleeConv();

  // Record, in type order, the index of every active argument.
  unsigned argIdx = 0;
  for (auto applyArg : ai.getArgumentsWithoutIndirectResults()) {
    if (activityInfo.isActive(applyArg, parentConfig))
      paramIndices.push_back(argIdx);
    ++argIdx;
  }

  // Gather the direct and indirect results of the call.
  SmallVector<SILValue, 8> directResults;
  forEachApplyDirectResult(
      ai, [&](SILValue value) { directResults.push_back(value); });
  auto indirectResults = ai.getIndirectSILResults();

  // Record all formal results in type order, remembering the indices of the
  // active ones.
  results.reserve(calleeFnTy->getNumResults());
  unsigned nextDirect = 0;
  unsigned nextIndirect = calleeConvs.getSILArgIndexOfFirstIndirectResult();
  for (const auto &resAndIdx : enumerate(calleeConvs.getResults())) {
    const unsigned resultIdx = resAndIdx.index();
    if (resAndIdx.value().isFormalDirect()) {
      SILValue directResult = directResults[nextDirect++];
      results.push_back(directResult);
      // A null direct result is recorded but never marked active.
      if (directResult && activityInfo.isActive(directResult, parentConfig))
        resultIndices.push_back(resultIdx);
      continue;
    }
    SILValue indirectResult = indirectResults[nextIndirect++];
    results.push_back(indirectResult);
    if (activityInfo.isActive(indirectResult, parentConfig))
      resultIndices.push_back(resultIdx);
  }

  // Record all semantic result parameters as results. Their result indices
  // continue right after the formal results.
  auto nextSemanticResultIndex = calleeFnTy->getNumResults();
  const unsigned numIndirectFormalResults =
      calleeFnTy->getNumIndirectFormalResults();
  for (const auto &paramAndIdx : enumerate(calleeConvs.getParameters())) {
    if (!paramAndIdx.value().isAutoDiffSemanticResult())
      continue;
    // Apply arguments start with the indirect results, so skip past them.
    unsigned applyArgIdx = paramAndIdx.index() + numIndirectFormalResults;
    results.push_back(ai.getArgument(applyArgIdx));
    resultIndices.push_back(nextSemanticResultIndex++);
  }

  // Record all yields. While we do not have a way to represent direct yields
  // (_read accessors) we run activity analysis for them. These will be
  // diagnosed later.
  if (auto *beginApply = dyn_cast<BeginApplyInst>(*ai)) {
    for (auto yieldIdx : range(calleeConvs.getYields().size())) {
      results.push_back(beginApply->getYieldedValues()[yieldIdx]);
      resultIndices.push_back(nextSemanticResultIndex++);
    }
  }

  // Make sure the function call has active results.
#ifndef NDEBUG
  assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults());
  assert(llvm::any_of(results, [&](SILValue result) {
    return activityInfo.isActive(result, parentConfig);
  }));
#endif
}
/// Finds a debug location and debug variable for `originalValue`.
///
/// For an `alloc_stack`, its own variable info (if any) is paired with its
/// debug location. Otherwise the first `debug_value` user of the value decides
/// the answer: its variable info (if any) is paired with its debug location.
/// Returns `std::nullopt` when no variable info is available.
std::optional<std::pair<SILDebugLocation, SILDebugVariable>>
findDebugLocationAndVariable(SILValue originalValue) {
  if (auto *allocStack = dyn_cast<AllocStackInst>(originalValue)) {
    auto pairWithAllocLoc = [&](SILDebugVariable debugVar) {
      return std::make_pair(allocStack->getDebugLocation(), debugVar);
    };
    return swift::transform(allocStack->getVarInfo(false), pairWithAllocLoc);
  }

  for (auto *use : originalValue->getUses()) {
    auto *debugValue = dyn_cast<DebugValueInst>(use->getUser());
    if (!debugValue)
      continue;
    auto pairWithDebugValueLoc = [&](SILDebugVariable debugVar) {
      // The debug info moves from a `debug_value` (which describes how to get
      // the value) to an `alloc_stack` (which describes the location), so a
      // leading `op_deref` has to be dropped.
      if (debugVar.DIExpr.startsWithDeref())
        debugVar.DIExpr.eraseElement(debugVar.DIExpr.element_begin());
      return std::make_pair(debugValue->getDebugLocation(), debugVar);
    };
    // The first `debug_value` user is final, even if it has no variable info.
    return swift::transform(debugValue->getVarInfo(false),
                            pairWithDebugValueLoc);
  }
  return std::nullopt;
}
//===----------------------------------------------------------------------===//
// Diagnostic utilities
//===----------------------------------------------------------------------===//
/// Returns the location of `v` if it is non-null and has a valid source
/// location; otherwise returns the location of the function containing `v`.
SILLocation getValidLocation(SILValue v) {
  auto valueLoc = v.getLoc();
  if (!valueLoc.isNull() && !valueLoc.getSourceLoc().isInvalid())
    return valueLoc;
  return v->getFunction()->getLocation();
}
/// Returns the location of `inst` if it is non-null and has a valid source
/// location; otherwise returns the location of the function containing
/// `inst`.
SILLocation getValidLocation(SILInstruction *inst) {
  auto instLoc = inst->getLoc();
  if (!instLoc.isNull() && !instLoc.getSourceLoc().isInvalid())
    return instLoc;
  return inst->getFunction()->getLocation();
}
//===----------------------------------------------------------------------===//
// Tangent property lookup utilities
//===----------------------------------------------------------------------===//
/// Returns the tangent stored property corresponding to `originalField` for
/// the given `baseType`, as computed by `TangentStoredPropertyRequest`.
///
/// If the request reports an error, a non-differentiability diagnostic
/// matching the error kind is emitted at `loc` on behalf of `invoker`, and
/// `nullptr` is returned.
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
                                  CanType baseType, SILLocation loc,
                                  DifferentiationInvoker invoker) {
  auto &astCtx = context.getASTContext();
  auto tangentInfo = evaluateOrDefault(
      astCtx.evaluator, TangentStoredPropertyRequest{originalField, baseType},
      TangentPropertyInfo(nullptr));
  // Success: hand back the tangent property.
  if (tangentInfo)
    return tangentInfo.tangentProperty;

  // Failure: gather what the diagnostics need, then report by error kind.
  assert(tangentInfo.error);
  auto *parentDC = originalField->getDeclContext();
  assert(parentDC->isTypeContext());
  auto parentDeclName = parentDC->getSelfNominalTypeDecl()->getNameStr();
  auto fieldName = originalField->getNameStr();
  auto sourceLoc = loc.getSourceLoc();

  using ErrorKind = TangentPropertyInfo::Error::Kind;
  switch (tangentInfo.error->kind) {
  case ErrorKind::NoDerivativeOriginalProperty:
    llvm_unreachable(
        "`@noDerivative` stored property accesses should not be "
        "differentiated; activity analysis should not mark as varied");
  case ErrorKind::NominalParentNotDifferentiable:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker,
        diag::autodiff_stored_property_parent_not_differentiable,
        parentDeclName, fieldName);
    return nullptr;
  case ErrorKind::OriginalPropertyNotDifferentiable:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_stored_property_not_differentiable,
        parentDeclName, fieldName, originalField->getInterfaceType());
    return nullptr;
  case ErrorKind::ParentTangentVectorNotStruct:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_stored_property_tangent_not_struct,
        parentDeclName, fieldName);
    return nullptr;
  case ErrorKind::TangentPropertyNotFound:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker,
        diag::autodiff_stored_property_no_corresponding_tangent,
        parentDeclName, fieldName);
    return nullptr;
  case ErrorKind::TangentPropertyWrongType:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_tangent_property_wrong_type,
        parentDeclName, fieldName, tangentInfo.error->getType());
    return nullptr;
  case ErrorKind::TangentPropertyNotStored:
    context.emitNondifferentiabilityError(
        sourceLoc, invoker, diag::autodiff_tangent_property_not_stored,
        parentDeclName, fieldName);
    return nullptr;
  }
  return nullptr;
}
/// Convenience overload: resolves the field accessed by `projectionInst`
/// (which must be a `struct_extract`, `struct_element_addr` or
/// `ref_element_addr`) and forwards to the field-based overload, using a
/// valid location derived from the instruction.
VarDecl *getTangentStoredProperty(ADContext &context,
                                  SingleValueInstruction *projectionInst,
                                  CanType baseType,
                                  DifferentiationInvoker invoker) {
  assert(isa<StructExtractInst>(projectionInst) ||
         isa<StructElementAddrInst>(projectionInst) ||
         isa<RefElementAddrInst>(projectionInst));
  Projection projection(projectionInst);
  auto diagnosticLoc = getValidLocation(projectionInst);
  auto aggregateType = projectionInst->getOperand(0)->getType();
  auto *originalField = projection.getVarDecl(aggregateType);
  return getTangentStoredProperty(context, originalField, baseType,
                                  diagnosticLoc, invoker);
}
//===----------------------------------------------------------------------===//
// Code emission utilities
//===----------------------------------------------------------------------===//
/// Returns the sole element when `elements` has exactly one entry; otherwise
/// creates and returns a `tuple` of all `elements` at `loc`.
SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
                      SILLocation loc) {
  const bool needsTuple = elements.size() != 1;
  if (needsTuple)
    return builder.createTuple(loc, elements);
  return elements.front();
}
/// Appends the elements of `value` to `results`.
///
/// A non-tuple value is appended as is. A tuple value is split with a single
/// `destructure_tuple` when the builder has ownership, and with one
/// `tuple_extract` per element otherwise.
void extractAllElements(SILValue value, SILBuilder &builder,
                        SmallVectorImpl<SILValue> &results) {
  auto tupleType = value->getType().getAs<TupleType>();
  if (!tupleType) {
    results.push_back(value);
    return;
  }

  if (!builder.hasOwnership()) {
    for (auto elementIdx : range(tupleType->getNumElements())) {
      auto *extract =
          builder.createTupleExtract(value.getLoc(), value, elementIdx);
      results.push_back(extract);
    }
    return;
  }

  auto *destructure = builder.createDestructureTuple(value.getLoc(), value);
  results.append(destructure->getResults().begin(),
                 destructure->getResults().end());
}
/// Emits a call to the `Sizeof` builtin for `type` at `loc` and returns its
/// result, which has the builtin word type. A thin metatype value of `type`
/// is created and passed as the builtin's argument.
SILValue emitMemoryLayoutSize(
    SILBuilder &builder, SILLocation loc, CanType type) {
  auto &astCtx = builder.getASTContext();
  auto sizeofId =
      astCtx.getIdentifier(getBuiltinName(BuiltinValueKind::Sizeof));
  auto *sizeofDecl = cast<FuncDecl>(getBuiltinValueDecl(astCtx, sizeofId));

  // Build the thin metatype operand.
  auto thinMetatype = CanMetatypeType::get(type, MetatypeRepresentation::Thin);
  auto metatypeTy = SILType::getPrimitiveObjectType(thinMetatype);
  auto metatypeVal = builder.createMetatype(loc, metatypeTy);

  // Substitute `type` for the builtin's generic parameter.
  auto substitutions = SubstitutionMap::get(
      sizeofDecl->getGenericSignature(), ArrayRef<Type>{type},
      LookUpConformanceInModule());
  return builder.createBuiltin(loc, sizeofId,
                               SILType::getBuiltinWordType(astCtx),
                               substitutions, {metatypeVal});
}
/// Emits the `AutoDiffProjectTopLevelSubcontext` builtin on `context` and
/// converts the resulting raw pointer into a strict address of
/// `subcontextType`, which is returned.
///
/// `context` must be a guaranteed value of native object type.
SILValue emitProjectTopLevelSubcontext(
    SILBuilder &builder, SILLocation loc, SILValue context,
    SILType subcontextType) {
  assert(context->getOwnershipKind() == OwnershipKind::Guaranteed);
  auto &astCtx = builder.getASTContext();
  auto builtinId = astCtx.getIdentifier(
      getBuiltinName(BuiltinValueKind::AutoDiffProjectTopLevelSubcontext));
  assert(context->getType() == SILType::getNativeObjectType(astCtx));

  auto rawPointerTy = SILType::getRawPointerType(astCtx);
  auto *rawSubcontext = builder.createBuiltin(loc, builtinId, rawPointerTy,
                                              SubstitutionMap(), {context});
  auto subcontextAddrTy = subcontextType.getAddressType();
  return builder.createPointerToAddress(loc, rawSubcontext, subcontextAddrTy,
                                        /*isStrict*/ true);
}
//===----------------------------------------------------------------------===//
// Utilities for looking up derivatives of functions
//===----------------------------------------------------------------------===//
/// Returns the `AbstractFunctionDecl` that `F` was generated from, or
/// `nullptr` when `F` has no declaration context, the context is not a
/// declaration, or the declaration is not an `AbstractFunctionDecl`.
static AbstractFunctionDecl *findAbstractFunctionDecl(SILFunction *F) {
  if (auto *declContext = F->getDeclContext())
    return dyn_cast_or_null<AbstractFunctionDecl>(declContext->getAsDecl());
  return nullptr;
}
/// Returns the differentiability witness registered in `module` for
/// `original` whose parameter and result index subsets are exactly
/// `parameterIndices` and `resultIndices` (compared by pointer), or `nullptr`
/// if there is none.
SILDifferentiabilityWitness *
getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
                                 IndexSubset *parameterIndices,
                                 IndexSubset *resultIndices) {
  for (auto *witness : module.lookUpDifferentiabilityWitnessesForFunction(
           original->getName())) {
    if (witness->getParameterIndices() != parameterIndices)
      continue;
    if (witness->getResultIndices() != resultIndices)
      continue;
    return witness;
  }
  return nullptr;
}
/// Searches the derivative function configurations of `original` for the one
/// whose lowered parameter indices are a superset of `parameterIndices` and
/// have the fewest indices (the first such configuration wins ties).
///
/// On success, returns the configuration expressed with lowered (SIL)
/// parameter indices and stores the matching AST parameter indices in
/// `minimalASTParameterIndices`. Returns `std::nullopt`, leaving
/// `minimalASTParameterIndices` untouched, when no configuration qualifies.
std::optional<AutoDiffConfig>
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
                                   IndexSubset *parameterIndices,
                                   IndexSubset *&minimalASTParameterIndices) {
  std::optional<AutoDiffConfig> bestConfig;
  auto &astCtx = original->getASTContext();
  for (auto &config : original->getDerivativeFunctionConfigurations()) {
    auto *candidateIndices = autodiff::getLoweredParameterIndices(
        config.parameterIndices,
        original->getInterfaceType()->castTo<AnyFunctionType>());
    // Bring the candidate up to the capacity of the requested indices.
    if (candidateIndices->getCapacity() < parameterIndices->getCapacity())
      candidateIndices = candidateIndices->extendingCapacity(
          astCtx, parameterIndices->getCapacity());

    // The candidate must cover every requested index.
    //
    // NOTE(TF-642): the configuration may come from an un-partial-applied
    // function and have larger capacity than the requested indices, hence the
    // capacity extension of the requested indices. This logic is expected to
    // go away when `partial_apply` supports `@differentiable` callees.
    auto *requestedIndices = parameterIndices->extendingCapacity(
        astCtx, candidateIndices->getCapacity());
    if (!candidateIndices->isSupersetOf(requestedIndices))
      continue;

    // Only a strictly smaller candidate replaces the current best one.
    if (bestConfig && candidateIndices->getNumIndices() >=
                          bestConfig->parameterIndices->getNumIndices())
      continue;

    minimalASTParameterIndices = config.parameterIndices;
    bestConfig =
        AutoDiffConfig(candidateIndices, config.resultIndices,
                       autodiff::getDifferentiabilityWitnessGenericSignature(
                           original->getGenericSignature(),
                           config.derivativeGenericSignature));
  }
  return bestConfig;
}
/// Returns the differentiability witness for `original` that corresponds to
/// the minimal derivative configuration covering `parameterIndices`, creating
/// a witness declaration if none exists in `module` yet.
///
/// Returns `nullptr` when `original` does not come from an
/// `AbstractFunctionDecl` or when no derivative configuration covers
/// `parameterIndices`. When the declaration requires a foreign entry point,
/// the foreign SIL function and its mangled name are used instead.
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
    SILModule &module, SILFunction *original, DifferentiabilityKind kind,
    IndexSubset *parameterIndices, IndexSubset *resultIndices) {
  // Explicit differentiability witnesses only exist on SIL functions that come
  // from AST functions.
  auto *originalAFD = findAbstractFunctionDecl(original);
  if (!originalAFD)
    return nullptr;

  IndexSubset *minimalASTParameterIndices = nullptr;
  auto minimalConfig = findMinimalDerivativeConfiguration(
      originalAFD, parameterIndices, minimalASTParameterIndices);
  if (!minimalConfig)
    return nullptr;

  // Witnesses are keyed by function name; switch to the foreign entry point
  // when the declaration requires one.
  std::string witnessFnName = original->getName().str();
  if (requiresForeignEntryPoint(originalAFD)) {
    auto foreignRef = SILDeclRef(originalAFD).asForeign();
    witnessFnName = foreignRef.mangle();
    original = module.lookUpFunction(foreignRef);
  }

  // Reuse an existing witness when there is one.
  if (auto *existingWitness = module.lookUpDifferentiabilityWitness(
          {witnessFnName, kind, *minimalConfig}))
    return existingWitness;

  assert(original->isExternalDeclaration() &&
         "SILGen should create differentiability witnesses for all function "
         "definitions with explicit differentiable attributes");

  // Witness for @_alwaysEmitIntoClient original function must be emitted,
  // otherwise a linker error would occur due to undefined reference to the
  // witness symbol.
  const auto witnessLinkage = original->markedAsAlwaysEmitIntoClient()
                                  ? SILLinkage::PublicNonABI
                                  : SILLinkage::PublicExternal;
  return SILDifferentiabilityWitness::createDeclaration(
      module, witnessLinkage, original, kind, minimalConfig->parameterIndices,
      minimalConfig->resultIndices, minimalConfig->derivativeGenericSignature);
}
} // end namespace autodiff
} // end namespace swift