mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[AutoDiff upstream] Add common SIL differentiation utilities.
This commit is contained in:
@@ -394,6 +394,8 @@ public:
|
||||
void cacheVisibleDecls(SmallVectorImpl<ValueDecl *> &&globals) const;
|
||||
const SmallVectorImpl<ValueDecl *> &getCachedVisibleDecls() const;
|
||||
|
||||
void addVisibleDecl(ValueDecl *decl);
|
||||
|
||||
virtual void lookupValue(DeclName name, NLKind lookupKind,
|
||||
SmallVectorImpl<ValueDecl*> &result) const override;
|
||||
|
||||
|
||||
@@ -510,6 +510,17 @@ public:
|
||||
return getArguments().slice(getNumIndirectSILResults());
|
||||
}
|
||||
|
||||
InoutArgumentRange getInoutArguments() const {
|
||||
switch (getKind()) {
|
||||
case FullApplySiteKind::ApplyInst:
|
||||
return cast<ApplyInst>(getInstruction())->getInoutArguments();
|
||||
case FullApplySiteKind::TryApplyInst:
|
||||
return cast<TryApplyInst>(getInstruction())->getInoutArguments();
|
||||
case FullApplySiteKind::BeginApplyInst:
|
||||
return cast<BeginApplyInst>(getInstruction())->getInoutArguments();
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if \p op is the callee operand of this apply site
|
||||
/// and not an argument operand.
|
||||
bool isCalleeOperand(const Operand &op) const {
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "swift/SIL/SILFunction.h"
|
||||
#include "swift/SIL/SILModule.h"
|
||||
#include "swift/SIL/TypeSubstCloner.h"
|
||||
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
|
||||
|
||||
namespace swift {
|
||||
|
||||
@@ -34,12 +35,77 @@ namespace autodiff {
|
||||
/// This is being used to print short debug messages within the AD pass.
|
||||
raw_ostream &getADDebugStream();
|
||||
|
||||
/// Returns true if this is an full apply site whose callee has
|
||||
/// `array.uninitialized_intrinsic` semantics.
|
||||
bool isArrayLiteralIntrinsic(FullApplySite applySite);
|
||||
|
||||
/// If the given value `v` corresponds to an `ApplyInst` with
|
||||
/// `array.uninitialized_intrinsic` semantics, returns the corresponding
|
||||
/// `ApplyInst`. Otherwise, returns `nullptr`.
|
||||
ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v);
|
||||
|
||||
/// Given an element address from an `array.uninitialized_intrinsic` `apply`
|
||||
/// instruction, returns the `apply` instruction. The element address is either
|
||||
/// a `pointer_to_address` or `index_addr` instruction to the `RawPointer`
|
||||
/// result of the instrinsic:
|
||||
///
|
||||
/// %result = apply %array.uninitialized_intrinsic : $(Array<T>, RawPointer)
|
||||
/// (%array, %ptr) = destructure_tuple %result
|
||||
/// %elt0 = pointer_to_address %ptr to $*T // element address
|
||||
/// %index_1 = integer_literal $Builtin.Word, 1
|
||||
/// %elt1 = index_addr %elt0, %index_1 // element address
|
||||
/// ...
|
||||
ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v);
|
||||
|
||||
/// Given a value, finds its single `destructure_tuple` user if the value is
|
||||
/// tuple-typed and such a user exists.
|
||||
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value);
|
||||
|
||||
/// Given a full apply site, apply the given callback to each of its
|
||||
/// "direct results".
|
||||
///
|
||||
/// - `apply`
|
||||
/// Special case because `apply` returns a single (possibly tuple-typed) result
|
||||
/// instead of multiple results. If the `apply` has a single
|
||||
/// `destructure_tuple` user, treat the `destructure_tuple` results as the
|
||||
/// `apply` direct results.
|
||||
///
|
||||
/// - `begin_apply`
|
||||
/// Apply callback to each `begin_apply` direct result.
|
||||
///
|
||||
/// - `try_apply`
|
||||
/// Apply callback to each `try_apply` successor basic block argument.
|
||||
void forEachApplyDirectResult(
|
||||
FullApplySite applySite, llvm::function_ref<void(SILValue)> resultCallback);
|
||||
|
||||
/// Given a function, gathers all of its formal results (both direct and
|
||||
/// indirect) in an order defined by its result type. Note that "formal results"
|
||||
/// refer to result values in the body of the function, not at call sites.
|
||||
void collectAllFormalResultsInTypeOrder(SILFunction &function,
|
||||
SmallVectorImpl<SILValue> &results);
|
||||
|
||||
/// Given a function, gathers all of its direct results in an order defined by
|
||||
/// its result type. Note that "formal results" refer to result values in the
|
||||
/// body of the function, not at call sites.
|
||||
void collectAllDirectResultsInTypeOrder(SILFunction &function,
|
||||
SmallVectorImpl<SILValue> &results);
|
||||
|
||||
/// Given a function call site, gathers all of its actual results (both direct
|
||||
/// and indirect) in an order defined by its result type.
|
||||
void collectAllActualResultsInTypeOrder(
|
||||
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
|
||||
SmallVectorImpl<SILValue> &results);
|
||||
|
||||
/// For an `apply` instruction with active results, compute:
|
||||
/// - The results of the `apply` instruction, in type order.
|
||||
/// - The set of minimal parameter and result indices for differentiating the
|
||||
/// `apply` instruction.
|
||||
void collectMinimalIndicesForFunctionCall(
|
||||
ApplyInst *ai, SILAutoDiffIndices parentIndices,
|
||||
const DifferentiableActivityInfo &activityInfo,
|
||||
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> ¶mIndices,
|
||||
SmallVectorImpl<unsigned> &resultIndices);
|
||||
|
||||
/// Returns the underlying instruction for the given SILValue, if it exists,
|
||||
/// peering through function conversion instructions.
|
||||
template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
|
||||
@@ -58,6 +124,10 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Code emission utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Given a range of elements, joins these into a single value. If there's
|
||||
/// exactly one element, returns that element. Otherwise, creates a tuple using
|
||||
/// a `tuple` instruction.
|
||||
@@ -156,6 +226,59 @@ inline void createEntryArguments(SILFunction *f) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper class for visiting basic blocks in post-order post-dominance order,
|
||||
/// based on a worklist algorithm.
|
||||
class PostOrderPostDominanceOrder {
|
||||
SmallVector<DominanceInfoNode *, 16> buffer;
|
||||
PostOrderFunctionInfo *postOrderInfo;
|
||||
size_t srcIdx = 0;
|
||||
|
||||
public:
|
||||
/// Constructor.
|
||||
/// \p root The root of the post-dominator tree.
|
||||
/// \p postOrderInfo The post-order info of the function.
|
||||
/// \p capacity Should be the number of basic blocks in the dominator tree to
|
||||
/// reduce memory allocation.
|
||||
PostOrderPostDominanceOrder(DominanceInfoNode *root,
|
||||
PostOrderFunctionInfo *postOrderInfo,
|
||||
int capacity = 0)
|
||||
: postOrderInfo(postOrderInfo) {
|
||||
buffer.reserve(capacity);
|
||||
buffer.push_back(root);
|
||||
}
|
||||
|
||||
/// Get the next block from the worklist.
|
||||
DominanceInfoNode *getNext() {
|
||||
if (srcIdx == buffer.size())
|
||||
return nullptr;
|
||||
return buffer[srcIdx++];
|
||||
}
|
||||
|
||||
/// Pushes the dominator children of a block onto the worklist in post-order.
|
||||
void pushChildren(DominanceInfoNode *node) {
|
||||
pushChildrenIf(node, [](SILBasicBlock *) { return true; });
|
||||
}
|
||||
|
||||
/// Conditionally pushes the dominator children of a block onto the worklist
|
||||
/// in post-order.
|
||||
template <typename Pred>
|
||||
void pushChildrenIf(DominanceInfoNode *node, Pred pred) {
|
||||
SmallVector<DominanceInfoNode *, 4> children;
|
||||
for (auto *child : *node)
|
||||
children.push_back(child);
|
||||
llvm::sort(children.begin(), children.end(),
|
||||
[&](DominanceInfoNode *n1, DominanceInfoNode *n2) {
|
||||
return postOrderInfo->getPONumber(n1->getBlock()) <
|
||||
postOrderInfo->getPONumber(n2->getBlock());
|
||||
});
|
||||
for (auto *child : children) {
|
||||
SILBasicBlock *childBB = child->getBlock();
|
||||
if (pred(childBB))
|
||||
buffer.push_back(child);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Cloner that remaps types using the target function's generic environment.
|
||||
class BasicTypeSubstCloner final
|
||||
: public TypeSubstCloner<BasicTypeSubstCloner, SILOptFunctionBuilder> {
|
||||
|
||||
@@ -2232,6 +2232,11 @@ SourceFile::getCachedVisibleDecls() const {
|
||||
return getCache().AllVisibleValues;
|
||||
}
|
||||
|
||||
void SourceFile::addVisibleDecl(ValueDecl *decl) {
|
||||
Decls->push_back(decl);
|
||||
getCache().AllVisibleValues.push_back(decl);
|
||||
}
|
||||
|
||||
static void performAutoImport(
|
||||
SourceFile &SF,
|
||||
SourceFile::ImplicitModuleImportKind implicitModuleImportKind) {
|
||||
|
||||
@@ -24,9 +24,126 @@ namespace autodiff {
|
||||
raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Code emission utilities
|
||||
// Helpers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool isArrayLiteralIntrinsic(FullApplySite applySite) {
|
||||
return doesApplyCalleeHaveSemantics(applySite.getCalleeOrigin(),
|
||||
"array.uninitialized_intrinsic");
|
||||
}
|
||||
|
||||
ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) {
|
||||
if (auto *ai = dyn_cast<ApplyInst>(v))
|
||||
if (isArrayLiteralIntrinsic(ai))
|
||||
return ai;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) {
|
||||
// Find the `pointer_to_address` result, peering through `index_addr`.
|
||||
auto *ptai = dyn_cast<PointerToAddressInst>(v);
|
||||
if (auto *iai = dyn_cast<IndexAddrInst>(v))
|
||||
ptai = dyn_cast<PointerToAddressInst>(iai->getOperand(0));
|
||||
if (!ptai)
|
||||
return nullptr;
|
||||
// Return the `array.uninitialized_intrinsic` application, if it exists.
|
||||
if (auto *dti = dyn_cast<DestructureTupleInst>(
|
||||
ptai->getOperand()->getDefiningInstruction())) {
|
||||
if (auto *ai = getAllocateUninitializedArrayIntrinsic(dti->getOperand()))
|
||||
return ai;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
|
||||
bool foundDestructureTupleUser = false;
|
||||
if (!value->getType().is<TupleType>())
|
||||
return nullptr;
|
||||
DestructureTupleInst *result = nullptr;
|
||||
for (auto *use : value->getUses()) {
|
||||
if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser())) {
|
||||
assert(!foundDestructureTupleUser &&
|
||||
"There should only be one `destructure_tuple` user of a tuple");
|
||||
foundDestructureTupleUser = true;
|
||||
result = dti;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void forEachApplyDirectResult(
|
||||
FullApplySite applySite,
|
||||
llvm::function_ref<void(SILValue)> resultCallback) {
|
||||
switch (applySite.getKind()) {
|
||||
case FullApplySiteKind::ApplyInst: {
|
||||
auto *ai = cast<ApplyInst>(applySite.getInstruction());
|
||||
if (!ai->getType().is<TupleType>()) {
|
||||
resultCallback(ai);
|
||||
return;
|
||||
}
|
||||
if (auto *dti = getSingleDestructureTupleUser(ai))
|
||||
for (auto directResult : dti->getResults())
|
||||
resultCallback(directResult);
|
||||
break;
|
||||
}
|
||||
case FullApplySiteKind::BeginApplyInst: {
|
||||
auto *bai = cast<BeginApplyInst>(applySite.getInstruction());
|
||||
for (auto directResult : bai->getResults())
|
||||
resultCallback(directResult);
|
||||
break;
|
||||
}
|
||||
case FullApplySiteKind::TryApplyInst: {
|
||||
auto *tai = cast<TryApplyInst>(applySite.getInstruction());
|
||||
for (auto *succBB : tai->getSuccessorBlocks())
|
||||
for (auto *arg : succBB->getArguments())
|
||||
resultCallback(arg);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void collectAllFormalResultsInTypeOrder(SILFunction &function,
|
||||
SmallVectorImpl<SILValue> &results) {
|
||||
SILFunctionConventions convs(function.getLoweredFunctionType(),
|
||||
function.getModule());
|
||||
auto indResults = function.getIndirectResults();
|
||||
auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
|
||||
auto retVal = retInst->getOperand();
|
||||
SmallVector<SILValue, 8> dirResults;
|
||||
if (auto *tupleInst =
|
||||
dyn_cast_or_null<TupleInst>(retVal->getDefiningInstruction()))
|
||||
dirResults.append(tupleInst->getElements().begin(),
|
||||
tupleInst->getElements().end());
|
||||
else
|
||||
dirResults.push_back(retVal);
|
||||
unsigned indResIdx = 0, dirResIdx = 0;
|
||||
for (auto &resInfo : convs.getResults())
|
||||
results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++]
|
||||
: indResults[indResIdx++]);
|
||||
// Treat `inout` parameters as semantic results.
|
||||
// Append `inout` parameters after formal results.
|
||||
for (auto i : range(convs.getNumParameters())) {
|
||||
auto paramInfo = convs.getParameters()[i];
|
||||
if (!paramInfo.isIndirectMutating())
|
||||
continue;
|
||||
auto *argument = function.getArgumentsWithoutIndirectResults()[i];
|
||||
results.push_back(argument);
|
||||
}
|
||||
}
|
||||
|
||||
void collectAllDirectResultsInTypeOrder(SILFunction &function,
|
||||
SmallVectorImpl<SILValue> &results) {
|
||||
SILFunctionConventions convs(function.getLoweredFunctionType(),
|
||||
function.getModule());
|
||||
auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
|
||||
auto retVal = retInst->getOperand();
|
||||
if (auto *tupleInst = dyn_cast<TupleInst>(retVal))
|
||||
results.append(tupleInst->getElements().begin(),
|
||||
tupleInst->getElements().end());
|
||||
else
|
||||
results.push_back(retVal);
|
||||
}
|
||||
|
||||
void collectAllActualResultsInTypeOrder(
|
||||
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
|
||||
SmallVectorImpl<SILValue> &results) {
|
||||
@@ -39,6 +156,73 @@ void collectAllActualResultsInTypeOrder(
|
||||
}
|
||||
}
|
||||
|
||||
void collectMinimalIndicesForFunctionCall(
|
||||
ApplyInst *ai, SILAutoDiffIndices parentIndices,
|
||||
const DifferentiableActivityInfo &activityInfo,
|
||||
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> ¶mIndices,
|
||||
SmallVectorImpl<unsigned> &resultIndices) {
|
||||
auto calleeFnTy = ai->getSubstCalleeType();
|
||||
auto calleeConvs = ai->getSubstCalleeConv();
|
||||
// Parameter indices are indices (in the callee type signature) of parameter
|
||||
// arguments that are varied or are arguments.
|
||||
// Record all parameter indices in type order.
|
||||
unsigned currentParamIdx = 0;
|
||||
for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) {
|
||||
if (activityInfo.isActive(applyArg, parentIndices))
|
||||
paramIndices.push_back(currentParamIdx);
|
||||
++currentParamIdx;
|
||||
}
|
||||
// Result indices are indices (in the callee type signature) of results that
|
||||
// are useful.
|
||||
SmallVector<SILValue, 8> directResults;
|
||||
forEachApplyDirectResult(ai, [&](SILValue directResult) {
|
||||
directResults.push_back(directResult);
|
||||
});
|
||||
auto indirectResults = ai->getIndirectSILResults();
|
||||
// Record all results and result indices in type order.
|
||||
results.reserve(calleeFnTy->getNumResults());
|
||||
unsigned dirResIdx = 0;
|
||||
unsigned indResIdx = calleeConvs.getSILArgIndexOfFirstIndirectResult();
|
||||
for (auto &resAndIdx : enumerate(calleeConvs.getResults())) {
|
||||
auto &res = resAndIdx.value();
|
||||
unsigned idx = resAndIdx.index();
|
||||
if (res.isFormalDirect()) {
|
||||
results.push_back(directResults[dirResIdx]);
|
||||
if (auto dirRes = directResults[dirResIdx])
|
||||
if (dirRes && activityInfo.isActive(dirRes, parentIndices))
|
||||
resultIndices.push_back(idx);
|
||||
++dirResIdx;
|
||||
} else {
|
||||
results.push_back(indirectResults[indResIdx]);
|
||||
if (activityInfo.isActive(indirectResults[indResIdx], parentIndices))
|
||||
resultIndices.push_back(idx);
|
||||
++indResIdx;
|
||||
}
|
||||
}
|
||||
// Record all `inout` parameters as results.
|
||||
auto inoutParamResultIndex = calleeFnTy->getNumResults();
|
||||
for (auto ¶mAndIdx : enumerate(calleeConvs.getParameters())) {
|
||||
auto ¶m = paramAndIdx.value();
|
||||
if (!param.isIndirectMutating())
|
||||
continue;
|
||||
unsigned idx = paramAndIdx.index();
|
||||
auto inoutArg = ai->getArgument(idx);
|
||||
results.push_back(inoutArg);
|
||||
resultIndices.push_back(inoutParamResultIndex++);
|
||||
}
|
||||
// Make sure the function call has active results.
|
||||
auto numResults = calleeFnTy->getNumResults() +
|
||||
calleeFnTy->getNumIndirectMutatingParameters();
|
||||
assert(results.size() == numResults);
|
||||
assert(llvm::any_of(results, [&](SILValue result) {
|
||||
return activityInfo.isActive(result, parentIndices);
|
||||
}));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Code emission utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
|
||||
SILLocation loc) {
|
||||
if (elements.size() == 1)
|
||||
|
||||
Reference in New Issue
Block a user