[AutoDiff upstream] Add common SIL differentiation utilities.

This commit is contained in:
Dan Zheng
2020-04-05 19:17:17 -07:00
parent bb6d4ebd9f
commit 8081482b57
5 changed files with 326 additions and 1 deletions

View File

@@ -394,6 +394,8 @@ public:
void cacheVisibleDecls(SmallVectorImpl<ValueDecl *> &&globals) const;
const SmallVectorImpl<ValueDecl *> &getCachedVisibleDecls() const;
void addVisibleDecl(ValueDecl *decl);
virtual void lookupValue(DeclName name, NLKind lookupKind,
SmallVectorImpl<ValueDecl*> &result) const override;

View File

@@ -510,6 +510,17 @@ public:
return getArguments().slice(getNumIndirectSILResults());
}
InoutArgumentRange getInoutArguments() const {
switch (getKind()) {
case FullApplySiteKind::ApplyInst:
return cast<ApplyInst>(getInstruction())->getInoutArguments();
case FullApplySiteKind::TryApplyInst:
return cast<TryApplyInst>(getInstruction())->getInoutArguments();
case FullApplySiteKind::BeginApplyInst:
return cast<BeginApplyInst>(getInstruction())->getInoutArguments();
}
}
/// Returns true if \p op is the callee operand of this apply site
/// and not an argument operand.
bool isCalleeOperand(const Operand &op) const {

View File

@@ -21,6 +21,7 @@
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/SILModule.h"
#include "swift/SIL/TypeSubstCloner.h"
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
namespace swift {
@@ -34,12 +35,77 @@ namespace autodiff {
/// This is being used to print short debug messages within the AD pass.
raw_ostream &getADDebugStream();
/// Returns true if this is an full apply site whose callee has
/// `array.uninitialized_intrinsic` semantics.
bool isArrayLiteralIntrinsic(FullApplySite applySite);
/// If the given value `v` corresponds to an `ApplyInst` with
/// `array.uninitialized_intrinsic` semantics, returns the corresponding
/// `ApplyInst`. Otherwise, returns `nullptr`.
ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v);
/// Given an element address from an `array.uninitialized_intrinsic` `apply`
/// instruction, returns the `apply` instruction. The element address is either
/// a `pointer_to_address` or `index_addr` instruction to the `RawPointer`
/// result of the instrinsic:
///
/// %result = apply %array.uninitialized_intrinsic : $(Array<T>, RawPointer)
/// (%array, %ptr) = destructure_tuple %result
/// %elt0 = pointer_to_address %ptr to $*T // element address
/// %index_1 = integer_literal $Builtin.Word, 1
/// %elt1 = index_addr %elt0, %index_1 // element address
/// ...
ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v);
/// Given a value, finds its single `destructure_tuple` user if the value is
/// tuple-typed and such a user exists.
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value);
/// Given a full apply site, apply the given callback to each of its
/// "direct results".
///
/// - `apply`
/// Special case because `apply` returns a single (possibly tuple-typed) result
/// instead of multiple results. If the `apply` has a single
/// `destructure_tuple` user, treat the `destructure_tuple` results as the
/// `apply` direct results.
///
/// - `begin_apply`
/// Apply callback to each `begin_apply` direct result.
///
/// - `try_apply`
/// Apply callback to each `try_apply` successor basic block argument.
void forEachApplyDirectResult(
FullApplySite applySite, llvm::function_ref<void(SILValue)> resultCallback);
/// Given a function, gathers all of its formal results (both direct and
/// indirect) in an order defined by its result type. Note that "formal results"
/// refer to result values in the body of the function, not at call sites.
void collectAllFormalResultsInTypeOrder(SILFunction &function,
SmallVectorImpl<SILValue> &results);
/// Given a function, gathers all of its direct results in an order defined by
/// its result type. Note that "formal results" refer to result values in the
/// body of the function, not at call sites.
void collectAllDirectResultsInTypeOrder(SILFunction &function,
SmallVectorImpl<SILValue> &results);
/// Given a function call site, gathers all of its actual results (both direct
/// and indirect) in an order defined by its result type.
void collectAllActualResultsInTypeOrder(
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
SmallVectorImpl<SILValue> &results);
/// For an `apply` instruction with active results, compute:
/// - The results of the `apply` instruction, in type order.
/// - The set of minimal parameter and result indices for differentiating the
/// `apply` instruction.
void collectMinimalIndicesForFunctionCall(
ApplyInst *ai, SILAutoDiffIndices parentIndices,
const DifferentiableActivityInfo &activityInfo,
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
SmallVectorImpl<unsigned> &resultIndices);
/// Returns the underlying instruction for the given SILValue, if it exists,
/// peering through function conversion instructions.
template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
@@ -58,6 +124,10 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// Code emission utilities
//===----------------------------------------------------------------------===//
/// Given a range of elements, joins these into a single value. If there's
/// exactly one element, returns that element. Otherwise, creates a tuple using
/// a `tuple` instruction.
@@ -156,6 +226,59 @@ inline void createEntryArguments(SILFunction *f) {
}
}
/// Helper class for visiting basic blocks in post-order post-dominance order,
/// based on a worklist algorithm.
class PostOrderPostDominanceOrder {
SmallVector<DominanceInfoNode *, 16> buffer;
PostOrderFunctionInfo *postOrderInfo;
size_t srcIdx = 0;
public:
/// Constructor.
/// \p root The root of the post-dominator tree.
/// \p postOrderInfo The post-order info of the function.
/// \p capacity Should be the number of basic blocks in the dominator tree to
/// reduce memory allocation.
PostOrderPostDominanceOrder(DominanceInfoNode *root,
PostOrderFunctionInfo *postOrderInfo,
int capacity = 0)
: postOrderInfo(postOrderInfo) {
buffer.reserve(capacity);
buffer.push_back(root);
}
/// Get the next block from the worklist.
DominanceInfoNode *getNext() {
if (srcIdx == buffer.size())
return nullptr;
return buffer[srcIdx++];
}
/// Pushes the dominator children of a block onto the worklist in post-order.
void pushChildren(DominanceInfoNode *node) {
pushChildrenIf(node, [](SILBasicBlock *) { return true; });
}
/// Conditionally pushes the dominator children of a block onto the worklist
/// in post-order.
template <typename Pred>
void pushChildrenIf(DominanceInfoNode *node, Pred pred) {
SmallVector<DominanceInfoNode *, 4> children;
for (auto *child : *node)
children.push_back(child);
llvm::sort(children.begin(), children.end(),
[&](DominanceInfoNode *n1, DominanceInfoNode *n2) {
return postOrderInfo->getPONumber(n1->getBlock()) <
postOrderInfo->getPONumber(n2->getBlock());
});
for (auto *child : children) {
SILBasicBlock *childBB = child->getBlock();
if (pred(childBB))
buffer.push_back(child);
}
}
};
/// Cloner that remaps types using the target function's generic environment.
class BasicTypeSubstCloner final
: public TypeSubstCloner<BasicTypeSubstCloner, SILOptFunctionBuilder> {