[region-isolation] Make PartitionOpEvaluator use CRTP instead of std::function callbacks.

Just a fixup requested by reviewers of incoming code that I wanted to do in a
follow on commit.
This commit is contained in:
Michael Gottesman
2023-12-04 12:27:06 -06:00
parent df03cb40ef
commit 398fa8b10f
3 changed files with 276 additions and 191 deletions

View File

@@ -934,115 +934,73 @@ private:
/// A data structure that applies a series of PartitionOps to a single Partition /// A data structure that applies a series of PartitionOps to a single Partition
/// that it modifies. /// that it modifies.
/// ///
/// Apply the passed PartitionOp to this partition, performing its action. A /// Callers use CRTP to modify its behavior. Please see the definition below of
/// `handleFailure` closure can optionally be passed in that will be called if /// a "blank" subclass PartitionOpEvaluatorBaseImpl for a description of the
/// a transferred region is required. The closure is given the PartitionOp /// methods needing to be implemented by other CRTP subclasses.
/// that failed, and the index of the SIL value that was required but template <typename Impl>
/// transferred. Additionally, a list of "nontransferrable" indices can be
/// passed in along with a handleTransferNonTransferrable closure. In the
/// event that a region containing one of the nontransferrable indices is
/// transferred, the closure will be called with the offending transfer.
struct PartitionOpEvaluator { struct PartitionOpEvaluator {
private:
Impl &asImpl() { return *reinterpret_cast<Impl *>(this); }
const Impl &asImpl() const { return *reinterpret_cast<const Impl *>(this); }
public:
using Element = PartitionPrimitives::Element; using Element = PartitionPrimitives::Element;
using Region = PartitionPrimitives::Region; using Region = PartitionPrimitives::Region;
using TransferringOperandSetFactory = using TransferringOperandSetFactory =
Partition::TransferringOperandSetFactory; Partition::TransferringOperandSetFactory;
protected:
TransferringOperandSetFactory &ptrSetFactory; TransferringOperandSetFactory &ptrSetFactory;
Partition &p; Partition &p;
/// If this PartitionOp evaluator should emit log statements. public:
bool emitLog = true;
/// If set to a non-null function, then this callback will be called if we
/// discover a transferred value was used after it was transferred.
///
/// The arguments passed to the closure are:
///
/// 1. The PartitionOp that required the element to be alive.
///
/// 2. The element in the PartitionOp that was asked to be alive.
///
/// 3. The operand of the instruction that originally transferred the
/// region. Can be used to get the immediate value transferred or the
/// transferring instruction.
std::function<void(const PartitionOp &, Element, TransferringOperand)>
failureCallback = nullptr;
/// A list of elements that cannot be transferred. Whenever we transfer, we
/// check this list to see if we are transferring the element and then call
/// transferNonTransferrableCallback. This should consist only of function
/// arguments.
ArrayRef<Element> nonTransferrableElements = {};
/// If set to a non-null function_ref, this is called if we detect a never
/// transferred element that was passed to a transfer instruction.
std::function<void(const PartitionOp &, Element)>
transferredNonTransferrableCallback = nullptr;
/// If set to a non-null function_ref, then this is used to determine if an
/// element is actor derived. If we determine that a region containing such an
/// element is transferred, we emit an error since actor regions cannot be
/// transferred.
std::function<bool(Element)> isActorDerivedCallback = nullptr;
/// Check if the representative value of \p elt is closure captured at \p
/// op.
///
/// NOTE: We actually just use the user of \p op in our callbacks. The reason
/// why we do not just pass in that SILInstruction is that then we would need
/// to access the instruction in the evaluator which creates a problem when
/// since the operand we pass in is a dummy operand.
std::function<bool(Element elt, Operand *op)> isClosureCapturedCallback =
nullptr;
PartitionOpEvaluator(Partition &p, PartitionOpEvaluator(Partition &p,
TransferringOperandSetFactory &ptrSetFactory) TransferringOperandSetFactory &ptrSetFactory)
: ptrSetFactory(ptrSetFactory), p(p) {} : ptrSetFactory(ptrSetFactory), p(p) {}
/// A wrapper around the failure callback that checks if it is nullptr. /// Call shouldEmitVerboseLogging on our CRTP subclass.
bool shouldEmitVerboseLogging() const {
return asImpl().shouldEmitVerboseLogging();
}
/// Call handleFailure on our CRTP subclass.
void handleFailure(const PartitionOp &op, Element elt, void handleFailure(const PartitionOp &op, Element elt,
TransferringOperand transferringOp) const { TransferringOperand transferringOp) const {
if (!failureCallback) return asImpl().handleFailure(op, elt, transferringOp);
return;
failureCallback(op, elt, transferringOp);
} }
/// A wrapper around transferNonTransferrableCallback that only calls it if it /// Call handleTransferNonTransferrable on our CRTP subclass.
/// is not null.
void handleTransferNonTransferrable(const PartitionOp &op, void handleTransferNonTransferrable(const PartitionOp &op,
Element elt) const { Element elt) const {
if (!transferredNonTransferrableCallback) return asImpl().handleTransferNonTransferrable(op, elt);
return;
transferredNonTransferrableCallback(op, elt);
} }
/// A wrapper around isActorDerivedCallback that returns false if /// Call isActorDerived on our CRTP subclass.
/// isActorDerivedCallback is nullptr and otherwise returns
/// isActorDerivedCallback().
bool isActorDerived(Element elt) const { bool isActorDerived(Element elt) const {
return bool(isActorDerivedCallback) && isActorDerivedCallback(elt); return asImpl().isActorDerived(elt);
} }
/// A wraper around isClosureCapturedCallback that returns false if /// Call isClosureCaptured on our CRTP subclass.
/// isClosureCapturedCallback is nullptr and otherwise returns
/// isClosureCapturedCallback.
bool isClosureCaptured(Element elt, Operand *op) const { bool isClosureCaptured(Element elt, Operand *op) const {
return bool(isClosureCapturedCallback) && return asImpl().isClosureCaptured(elt, op);
isClosureCapturedCallback(elt, op); }
/// Call getNonTransferrableElements() on our CRTP subclass.
ArrayRef<Element> getNonTransferrableElements() const {
return asImpl().getNonTransferrableElements();
} }
/// Apply \p op to the partition op. /// Apply \p op to the partition op.
void apply(const PartitionOp &op) const { void apply(const PartitionOp &op) const {
if (emitLog) { if (shouldEmitVerboseLogging()) {
REGIONBASEDISOLATION_VERBOSE_LOG(llvm::dbgs() << "Applying: "; REGIONBASEDISOLATION_VERBOSE_LOG(llvm::dbgs() << "Applying: ";
op.print(llvm::dbgs())); op.print(llvm::dbgs()));
REGIONBASEDISOLATION_VERBOSE_LOG(llvm::dbgs() << " Before: "; REGIONBASEDISOLATION_VERBOSE_LOG(llvm::dbgs() << " Before: ";
p.print(llvm::dbgs())); p.print(llvm::dbgs()));
} }
SWIFT_DEFER { SWIFT_DEFER {
if (emitLog) { if (shouldEmitVerboseLogging()) {
REGIONBASEDISOLATION_VERBOSE_LOG(llvm::dbgs() << " After: "; REGIONBASEDISOLATION_VERBOSE_LOG(llvm::dbgs() << " After: ";
p.print(llvm::dbgs())); p.print(llvm::dbgs()));
} }
@@ -1078,7 +1036,7 @@ struct PartitionOpEvaluator {
// check if any nontransferrables are transferred here, and handle the // check if any nontransferrables are transferred here, and handle the
// failure if so // failure if so
for (Element nonTransferrable : nonTransferrableElements) { for (Element nonTransferrable : getNonTransferrableElements()) {
assert( assert(
p.isTrackingElement(nonTransferrable) && p.isTrackingElement(nonTransferrable) &&
"nontransferrables should be function args and self, and therefore" "nontransferrables should be function args and self, and therefore"
@@ -1167,6 +1125,75 @@ struct PartitionOpEvaluator {
} }
}; };
/// A base implementation that can be used to default initialize CRTP
/// subclasses. Only used to implement base functionality for subclass
/// CRTPs. For true basic evaluation, use PartitionOpEvaluatorBasic below.
template <typename Subclass>
struct PartitionOpEvaluatorBaseImpl : PartitionOpEvaluator<Subclass> {
using Element = PartitionPrimitives::Element;
using Region = PartitionPrimitives::Region;
using TransferringOperandSetFactory =
Partition::TransferringOperandSetFactory;
using Super = PartitionOpEvaluator<Subclass>;
PartitionOpEvaluatorBaseImpl(Partition &workingPartition,
TransferringOperandSetFactory &ptrSetFactory)
: Super(workingPartition, ptrSetFactory) {}
/// Should we emit extra verbose logging statements when evaluating
/// PartitionOps.
bool shouldEmitVerboseLogging() const { return true; }
/// A function called if we discover a transferred value was used after it
/// was transferred.
///
/// The arguments passed to the closure are:
///
/// 1. The PartitionOp that required the element to be alive.
///
/// 2. The element in the PartitionOp that was asked to be alive.
///
/// 3. The operand of the instruction that originally transferred the
/// region. Can be used to get the immediate value transferred or the
/// transferring instruction.
void handleFailure(const PartitionOp &op, Element elt,
TransferringOperand transferringOp) const {}
/// A list of elements that cannot be transferred. Whenever we transfer, we
/// check this list to see if we are transferring the element and then call
/// transferNonTransferrableCallback. This should consist only of function
/// arguments.
ArrayRef<Element> getNonTransferrableElements() const { return {}; }
/// This is called if we detect a never transferred element that was passed to
/// a transfer instruction.
void handleTransferNonTransferrable(const PartitionOp &op,
Element elt) const {}
/// This is used to determine if an element is actor derived. If we determine
/// that a region containing such an element is transferred, we emit an error
/// since actor regions cannot be transferred.
bool isActorDerived(Element elt) const { return false; }
/// Check if the representative value of \p elt is closure captured at \p
/// op.
///
/// NOTE: We actually just use the user of \p op in our callbacks. The reason
/// why we do not just pass in that SILInstruction is that then we would need
/// to access the instruction in the evaluator which creates a problem when
/// since the operand we pass in is a dummy operand.
bool isClosureCaptured(Element elt, Operand *op) const { return false; }
};
/// A subclass of PartitionOpEvaluatorBaseImpl that doesn't have any special
/// behavior.
struct PartitionOpEvaluatorBasic final
: PartitionOpEvaluatorBaseImpl<PartitionOpEvaluatorBasic> {
PartitionOpEvaluatorBasic(Partition &workingPartition,
TransferringOperandSetFactory &ptrSetFactory)
: PartitionOpEvaluatorBaseImpl(workingPartition, ptrSetFactory) {}
};
} // namespace swift } // namespace swift
#endif // SWIFT_PARTITIONUTILS_H #endif // SWIFT_PARTITIONUTILS_H

View File

@@ -2319,16 +2319,28 @@ class BlockPartitionState {
/// to discover if an error occured. /// to discover if an error occured.
bool recomputeExitFromEntry(PartitionOpTranslator &translator) { bool recomputeExitFromEntry(PartitionOpTranslator &translator) {
Partition workingPartition = entryPartition; Partition workingPartition = entryPartition;
PartitionOpEvaluator eval(workingPartition, ptrSetFactory);
eval.isClosureCapturedCallback = [&](Element element, Operand *op) -> bool { struct ComputeEvaluator final
auto iter = translator.getValueForId(element); : PartitionOpEvaluatorBaseImpl<ComputeEvaluator> {
if (!iter) PartitionOpTranslator &translator;
return false;
auto value = iter->getRepresentative().maybeGetValue(); ComputeEvaluator(Partition &workingPartition,
if (!value) TransferringOperandSetFactory &ptrSetFactory,
return false; PartitionOpTranslator &translator)
return translator.isClosureCaptured(value, op->getUser()); : PartitionOpEvaluatorBaseImpl(workingPartition, ptrSetFactory),
translator(translator) {}
bool isClosureCaptured(Element elt, Operand *op) const {
auto iter = translator.getValueForId(elt);
if (!iter)
return false;
auto value = iter->getRepresentative().maybeGetValue();
if (!value)
return false;
return translator.isClosureCaptured(value, op->getUser());
}
}; };
ComputeEvaluator eval(workingPartition, ptrSetFactory, translator);
for (const auto &partitionOp : blockPartitionOps) { for (const auto &partitionOp : blockPartitionOps) {
// By calling apply without providing a `handleFailure` closure, errors // By calling apply without providing a `handleFailure` closure, errors
// will be suppressed // will be suppressed
@@ -2712,6 +2724,89 @@ class PartitionAnalysis {
void emitDiagnostics() { void emitDiagnostics() {
assert(solved && "diagnose should not be called before solve"); assert(solved && "diagnose should not be called before solve");
struct DiagnosticEvaluator final
: PartitionOpEvaluatorBaseImpl<DiagnosticEvaluator> {
PartitionOpTranslator &translator;
SmallFrozenMultiMap<Operand *, SILInstruction *, 8>
&transferOpToRequireInstMultiMap;
DiagnosticEvaluator(Partition &workingPartition,
TransferringOperandSetFactory &ptrSetFactory,
PartitionOpTranslator &translator,
SmallFrozenMultiMap<Operand *, SILInstruction *, 8>
&transferOpToRequireInstMultiMap)
: PartitionOpEvaluatorBaseImpl(workingPartition, ptrSetFactory),
translator(translator),
transferOpToRequireInstMultiMap(transferOpToRequireInstMultiMap) {}
void handleFailure(const PartitionOp &partitionOp,
TrackableValueID transferredVal,
TransferringOperand transferringOp) const {
// Ignore this if we have a gep like instruction that is returning a
// sendable type and transferringOp was not set with closure
// capture.
if (auto *svi =
dyn_cast<SingleValueInstruction>(partitionOp.getSourceInst())) {
if (isa<TupleElementAddrInst, StructElementAddrInst>(svi) &&
!isNonSendableType(svi->getType(), svi->getFunction())) {
bool isCapture = transferringOp.isClosureCaptured();
if (!isCapture) {
return;
}
}
}
auto rep = translator.getValueForId(transferredVal)
->getRepresentative()
.getValue();
LLVM_DEBUG(llvm::dbgs()
<< " Emitting Use After Transfer Error!\n"
<< " ID: %%" << transferredVal << "\n"
<< " Rep: " << *rep
<< " Require Inst: " << *partitionOp.getSourceInst()
<< " Transferring Op Num: "
<< transferringOp.getOperand()->getOperandNumber() << '\n'
<< " Transferring Inst: "
<< *transferringOp.getUser());
transferOpToRequireInstMultiMap.insert(transferringOp.getOperand(),
partitionOp.getSourceInst());
}
ArrayRef<Element> getNonTransferrableElements() const {
return translator.getNeverTransferredValues();
}
void
handleTransferNonTransferrable(const PartitionOp &partitionOp,
TrackableValueID transferredVal) const {
LLVM_DEBUG(llvm::dbgs()
<< " Emitting TransferNonTransferrable Error!\n"
<< " ID: %%" << transferredVal << "\n"
<< " Rep: "
<< *translator.getValueForId(transferredVal)
->getRepresentative()
.getValue());
diagnose(partitionOp, diag::regionbasedisolation_selforargtransferred);
}
bool isActorDerived(Element element) const {
auto iter = translator.getValueForId(element);
if (!iter)
return false;
return iter->isActorDerived();
}
bool isClosureCaptured(Element element, Operand *op) const {
auto iter = translator.getValueForId(element);
if (!iter)
return false;
auto value = iter->getRepresentative().maybeGetValue();
if (!value)
return false;
return translator.isClosureCaptured(value, op->getUser());
}
};
LLVM_DEBUG(llvm::dbgs() << "Emitting diagnostics for function " LLVM_DEBUG(llvm::dbgs() << "Emitting diagnostics for function "
<< function->getName() << "\n"); << function->getName() << "\n");
@@ -2726,68 +2821,8 @@ class PartitionAnalysis {
// Grab its entry partition and setup an evaluator for the partition that // Grab its entry partition and setup an evaluator for the partition that
// has callbacks that emit diagnsotics... // has callbacks that emit diagnsotics...
Partition workingPartition = blockState.getEntryPartition(); Partition workingPartition = blockState.getEntryPartition();
PartitionOpEvaluator eval(workingPartition, ptrSetFactory); DiagnosticEvaluator eval(workingPartition, ptrSetFactory, translator,
eval.failureCallback = /*handleFailure=*/ transferOpToRequireInstMultiMap);
[&](const PartitionOp &partitionOp, TrackableValueID transferredVal,
TransferringOperand transferringOp) {
// Ignore this if we have a gep like instruction that is returning a
// sendable type and transferringOp was not set with closure
// capture.
if (auto *svi = dyn_cast<SingleValueInstruction>(
partitionOp.getSourceInst())) {
if (isa<TupleElementAddrInst, StructElementAddrInst>(svi) &&
!isNonSendableType(svi->getType(), svi->getFunction())) {
bool isCapture = transferringOp.isClosureCaptured();
if (!isCapture) {
return;
}
}
}
auto rep = translator.getValueForId(transferredVal)
->getRepresentative()
.getValue();
LLVM_DEBUG(
llvm::dbgs()
<< " Emitting Use After Transfer Error!\n"
<< " ID: %%" << transferredVal << "\n"
<< " Rep: " << *rep
<< " Require Inst: " << *partitionOp.getSourceInst()
<< " Transferring Op Num: "
<< transferringOp.getOperand()->getOperandNumber() << '\n'
<< " Transferring Inst: " << *transferringOp.getUser());
transferOpToRequireInstMultiMap.insert(transferringOp.getOperand(),
partitionOp.getSourceInst());
};
eval.transferredNonTransferrableCallback =
[&](const PartitionOp &partitionOp, TrackableValueID transferredVal) {
LLVM_DEBUG(llvm::dbgs()
<< " Emitting TransferNonTransferrable Error!\n"
<< " ID: %%" << transferredVal << "\n"
<< " Rep: "
<< *translator.getValueForId(transferredVal)
->getRepresentative()
.getValue());
diagnose(partitionOp,
diag::regionbasedisolation_selforargtransferred);
};
eval.nonTransferrableElements = translator.getNeverTransferredValues();
eval.isActorDerivedCallback = [&](Element element) -> bool {
auto iter = translator.getValueForId(element);
if (!iter)
return false;
return iter->isActorDerived();
};
eval.isClosureCapturedCallback = [&](Element element,
Operand *op) -> bool {
auto iter = translator.getValueForId(element);
if (!iter)
return false;
auto value = iter->getRepresentative().maybeGetValue();
if (!value)
return false;
return translator.isClosureCaptured(value, op->getUser());
};
// And then evaluate all of our partition ops on the entry partition. // And then evaluate all of our partition ops on the entry partition.
for (auto &partitionOp : blockState.getPartitionOps()) { for (auto &partitionOp : blockState.getPartitionOps()) {

View File

@@ -67,7 +67,7 @@ TEST(PartitionUtilsTest, TestMergeAndJoin) {
Partition::TransferringOperandSetFactory factory(allocator); Partition::TransferringOperandSetFactory factory(allocator);
{ {
PartitionOpEvaluator eval(p1, factory); PartitionOpEvaluatorBasic eval(p1, factory);
eval.apply({PartitionOp::AssignFresh(Element(0)), eval.apply({PartitionOp::AssignFresh(Element(0)),
PartitionOp::AssignFresh(Element(1)), PartitionOp::AssignFresh(Element(1)),
PartitionOp::AssignFresh(Element(2)), PartitionOp::AssignFresh(Element(2)),
@@ -75,7 +75,7 @@ TEST(PartitionUtilsTest, TestMergeAndJoin) {
} }
{ {
PartitionOpEvaluator eval(p2, factory); PartitionOpEvaluatorBasic eval(p2, factory);
eval.apply({PartitionOp::AssignFresh(Element(5)), eval.apply({PartitionOp::AssignFresh(Element(5)),
PartitionOp::AssignFresh(Element(6)), PartitionOp::AssignFresh(Element(6)),
PartitionOp::AssignFresh(Element(7)), PartitionOp::AssignFresh(Element(7)),
@@ -83,7 +83,7 @@ TEST(PartitionUtilsTest, TestMergeAndJoin) {
} }
{ {
PartitionOpEvaluator eval(p3, factory); PartitionOpEvaluatorBasic eval(p3, factory);
eval.apply({PartitionOp::AssignFresh(Element(2)), eval.apply({PartitionOp::AssignFresh(Element(2)),
PartitionOp::AssignFresh(Element(3)), PartitionOp::AssignFresh(Element(3)),
PartitionOp::AssignFresh(Element(4)), PartitionOp::AssignFresh(Element(4)),
@@ -95,7 +95,7 @@ TEST(PartitionUtilsTest, TestMergeAndJoin) {
EXPECT_FALSE(Partition::equals(p1, p3)); EXPECT_FALSE(Partition::equals(p1, p3));
{ {
PartitionOpEvaluator eval(p1, factory); PartitionOpEvaluatorBasic eval(p1, factory);
eval.apply({PartitionOp::AssignFresh(Element(4)), eval.apply({PartitionOp::AssignFresh(Element(4)),
PartitionOp::AssignFresh(Element(5)), PartitionOp::AssignFresh(Element(5)),
PartitionOp::AssignFresh(Element(6)), PartitionOp::AssignFresh(Element(6)),
@@ -104,7 +104,7 @@ TEST(PartitionUtilsTest, TestMergeAndJoin) {
} }
{ {
PartitionOpEvaluator eval(p2, factory); PartitionOpEvaluatorBasic eval(p2, factory);
eval.apply({PartitionOp::AssignFresh(Element(1)), eval.apply({PartitionOp::AssignFresh(Element(1)),
PartitionOp::AssignFresh(Element(2)), PartitionOp::AssignFresh(Element(2)),
PartitionOp::AssignFresh(Element(3)), PartitionOp::AssignFresh(Element(3)),
@@ -113,7 +113,7 @@ TEST(PartitionUtilsTest, TestMergeAndJoin) {
} }
{ {
PartitionOpEvaluator eval(p3, factory); PartitionOpEvaluatorBasic eval(p3, factory);
eval.apply({PartitionOp::AssignFresh(Element(6)), eval.apply({PartitionOp::AssignFresh(Element(6)),
PartitionOp::AssignFresh(Element(7)), PartitionOp::AssignFresh(Element(7)),
PartitionOp::AssignFresh(Element(0)), PartitionOp::AssignFresh(Element(0)),
@@ -132,12 +132,12 @@ TEST(PartitionUtilsTest, TestMergeAndJoin) {
auto apply_to_p1_and_p3 = [&](PartitionOp op) { auto apply_to_p1_and_p3 = [&](PartitionOp op) {
{ {
PartitionOpEvaluator eval(p1, factory); PartitionOpEvaluatorBasic eval(p1, factory);
eval.apply(op); eval.apply(op);
} }
{ {
PartitionOpEvaluator eval(p3, factory); PartitionOpEvaluatorBasic eval(p3, factory);
eval.apply(op); eval.apply(op);
} }
expect_join_eq(); expect_join_eq();
@@ -145,12 +145,12 @@ TEST(PartitionUtilsTest, TestMergeAndJoin) {
auto apply_to_p2_and_p3 = [&](PartitionOp op) { auto apply_to_p2_and_p3 = [&](PartitionOp op) {
{ {
PartitionOpEvaluator eval(p2, factory); PartitionOpEvaluatorBasic eval(p2, factory);
eval.apply(op); eval.apply(op);
} }
{ {
PartitionOpEvaluator eval(p3, factory); PartitionOpEvaluatorBasic eval(p3, factory);
eval.apply(op); eval.apply(op);
} }
expect_join_eq(); expect_join_eq();
@@ -183,7 +183,7 @@ TEST(PartitionUtilsTest, Join1) {
Partition p1 = Partition::separateRegions(llvm::makeArrayRef(data1)); Partition p1 = Partition::separateRegions(llvm::makeArrayRef(data1));
{ {
PartitionOpEvaluator eval(p1, factory); PartitionOpEvaluatorBasic eval(p1, factory);
eval.apply({PartitionOp::Assign(Element(0), Element(0)), eval.apply({PartitionOp::Assign(Element(0), Element(0)),
PartitionOp::Assign(Element(1), Element(0)), PartitionOp::Assign(Element(1), Element(0)),
PartitionOp::Assign(Element(2), Element(2)), PartitionOp::Assign(Element(2), Element(2)),
@@ -194,7 +194,7 @@ TEST(PartitionUtilsTest, Join1) {
Partition p2 = Partition::separateRegions(llvm::makeArrayRef(data1)); Partition p2 = Partition::separateRegions(llvm::makeArrayRef(data1));
{ {
PartitionOpEvaluator eval(p2, factory); PartitionOpEvaluatorBasic eval(p2, factory);
eval.apply({PartitionOp::Assign(Element(0), Element(0)), eval.apply({PartitionOp::Assign(Element(0), Element(0)),
PartitionOp::Assign(Element(1), Element(0)), PartitionOp::Assign(Element(1), Element(0)),
PartitionOp::Assign(Element(2), Element(2)), PartitionOp::Assign(Element(2), Element(2)),
@@ -222,7 +222,7 @@ TEST(PartitionUtilsTest, Join2) {
Partition p1 = Partition::separateRegions(llvm::makeArrayRef(data1)); Partition p1 = Partition::separateRegions(llvm::makeArrayRef(data1));
{ {
PartitionOpEvaluator eval(p1, factory); PartitionOpEvaluatorBasic eval(p1, factory);
eval.apply({PartitionOp::Assign(Element(0), Element(0)), eval.apply({PartitionOp::Assign(Element(0), Element(0)),
PartitionOp::Assign(Element(1), Element(0)), PartitionOp::Assign(Element(1), Element(0)),
PartitionOp::Assign(Element(2), Element(2)), PartitionOp::Assign(Element(2), Element(2)),
@@ -235,7 +235,7 @@ TEST(PartitionUtilsTest, Join2) {
Element(7), Element(8), Element(9)}; Element(7), Element(8), Element(9)};
Partition p2 = Partition::separateRegions(llvm::makeArrayRef(data2)); Partition p2 = Partition::separateRegions(llvm::makeArrayRef(data2));
{ {
PartitionOpEvaluator eval(p2, factory); PartitionOpEvaluatorBasic eval(p2, factory);
eval.apply({PartitionOp::Assign(Element(4), Element(4)), eval.apply({PartitionOp::Assign(Element(4), Element(4)),
PartitionOp::Assign(Element(5), Element(5)), PartitionOp::Assign(Element(5), Element(5)),
PartitionOp::Assign(Element(6), Element(4)), PartitionOp::Assign(Element(6), Element(4)),
@@ -267,7 +267,7 @@ TEST(PartitionUtilsTest, Join2Reversed) {
Partition p1 = Partition::separateRegions(llvm::makeArrayRef(data1)); Partition p1 = Partition::separateRegions(llvm::makeArrayRef(data1));
{ {
PartitionOpEvaluator eval(p1, factory); PartitionOpEvaluatorBasic eval(p1, factory);
eval.apply({PartitionOp::Assign(Element(0), Element(0)), eval.apply({PartitionOp::Assign(Element(0), Element(0)),
PartitionOp::Assign(Element(1), Element(0)), PartitionOp::Assign(Element(1), Element(0)),
PartitionOp::Assign(Element(2), Element(2)), PartitionOp::Assign(Element(2), Element(2)),
@@ -280,7 +280,7 @@ TEST(PartitionUtilsTest, Join2Reversed) {
Element(7), Element(8), Element(9)}; Element(7), Element(8), Element(9)};
Partition p2 = Partition::separateRegions(llvm::makeArrayRef(data2)); Partition p2 = Partition::separateRegions(llvm::makeArrayRef(data2));
{ {
PartitionOpEvaluator eval(p2, factory); PartitionOpEvaluatorBasic eval(p2, factory);
eval.apply({PartitionOp::Assign(Element(4), Element(4)), eval.apply({PartitionOp::Assign(Element(4), Element(4)),
PartitionOp::Assign(Element(5), Element(5)), PartitionOp::Assign(Element(5), Element(5)),
PartitionOp::Assign(Element(6), Element(4)), PartitionOp::Assign(Element(6), Element(4)),
@@ -316,7 +316,7 @@ TEST(PartitionUtilsTest, JoinLarge) {
Element(25), Element(26), Element(27), Element(28), Element(29)}; Element(25), Element(26), Element(27), Element(28), Element(29)};
Partition p1 = Partition::separateRegions(llvm::makeArrayRef(data1)); Partition p1 = Partition::separateRegions(llvm::makeArrayRef(data1));
{ {
PartitionOpEvaluator eval(p1, factory); PartitionOpEvaluatorBasic eval(p1, factory);
eval.apply({PartitionOp::Assign(Element(0), Element(29)), eval.apply({PartitionOp::Assign(Element(0), Element(29)),
PartitionOp::Assign(Element(1), Element(17)), PartitionOp::Assign(Element(1), Element(17)),
PartitionOp::Assign(Element(2), Element(0)), PartitionOp::Assign(Element(2), Element(0)),
@@ -358,7 +358,7 @@ TEST(PartitionUtilsTest, JoinLarge) {
Element(40), Element(41), Element(42), Element(43), Element(44)}; Element(40), Element(41), Element(42), Element(43), Element(44)};
Partition p2 = Partition::separateRegions(llvm::makeArrayRef(data2)); Partition p2 = Partition::separateRegions(llvm::makeArrayRef(data2));
{ {
PartitionOpEvaluator eval(p2, factory); PartitionOpEvaluatorBasic eval(p2, factory);
eval.apply({PartitionOp::Assign(Element(15), Element(31)), eval.apply({PartitionOp::Assign(Element(15), Element(31)),
PartitionOp::Assign(Element(16), Element(34)), PartitionOp::Assign(Element(16), Element(34)),
PartitionOp::Assign(Element(17), Element(35)), PartitionOp::Assign(Element(17), Element(35)),
@@ -448,19 +448,19 @@ TEST(PartitionUtilsTest, TestAssign) {
Partition p2; Partition p2;
Partition p3; Partition p3;
PartitionOpEvaluator evalP1(p1, factory); PartitionOpEvaluatorBasic evalP1(p1, factory);
evalP1.apply({PartitionOp::AssignFresh(Element(0)), evalP1.apply({PartitionOp::AssignFresh(Element(0)),
PartitionOp::AssignFresh(Element(1)), PartitionOp::AssignFresh(Element(1)),
PartitionOp::AssignFresh(Element(2)), PartitionOp::AssignFresh(Element(2)),
PartitionOp::AssignFresh(Element(3))}); PartitionOp::AssignFresh(Element(3))});
PartitionOpEvaluator evalP2(p2, factory); PartitionOpEvaluatorBasic evalP2(p2, factory);
evalP2.apply({PartitionOp::AssignFresh(Element(0)), evalP2.apply({PartitionOp::AssignFresh(Element(0)),
PartitionOp::AssignFresh(Element(1)), PartitionOp::AssignFresh(Element(1)),
PartitionOp::AssignFresh(Element(2)), PartitionOp::AssignFresh(Element(2)),
PartitionOp::AssignFresh(Element(3))}); PartitionOp::AssignFresh(Element(3))});
PartitionOpEvaluator evalP3(p3, factory); PartitionOpEvaluatorBasic evalP3(p3, factory);
evalP3.apply({PartitionOp::AssignFresh(Element(0)), evalP3.apply({PartitionOp::AssignFresh(Element(0)),
PartitionOp::AssignFresh(Element(1)), PartitionOp::AssignFresh(Element(1)),
PartitionOp::AssignFresh(Element(2)), PartitionOp::AssignFresh(Element(2)),
@@ -528,6 +528,28 @@ TEST(PartitionUtilsTest, TestAssign) {
EXPECT_TRUE(Partition::equals(p1, p3)); EXPECT_TRUE(Partition::equals(p1, p3));
} }
namespace {
struct PartitionOpEvaluatorWithFailureCallback final
: PartitionOpEvaluatorBaseImpl<PartitionOpEvaluatorWithFailureCallback> {
using FailureCallbackTy =
std::function<void(const PartitionOp &, unsigned, TransferringOperand)>;
FailureCallbackTy failureCallback;
PartitionOpEvaluatorWithFailureCallback(
Partition &workingPartition, TransferringOperandSetFactory &ptrSetFactory,
FailureCallbackTy failureCallback)
: PartitionOpEvaluatorBaseImpl(workingPartition, ptrSetFactory),
failureCallback(failureCallback) {}
void handleFailure(const PartitionOp &op, Element elt,
TransferringOperand transferringOp) const {
failureCallback(op, elt, transferringOp);
}
};
} // namespace
// This test tests that consumption consumes entire regions as expected // This test tests that consumption consumes entire regions as expected
TEST(PartitionUtilsTest, TestConsumeAndRequire) { TEST(PartitionUtilsTest, TestConsumeAndRequire) {
llvm::BumpPtrAllocator allocator; llvm::BumpPtrAllocator allocator;
@@ -536,7 +558,7 @@ TEST(PartitionUtilsTest, TestConsumeAndRequire) {
Partition p; Partition p;
{ {
PartitionOpEvaluator eval(p, factory); PartitionOpEvaluatorBasic eval(p, factory);
eval.apply({PartitionOp::AssignFresh(Element(0)), eval.apply({PartitionOp::AssignFresh(Element(0)),
PartitionOp::AssignFresh(Element(1)), PartitionOp::AssignFresh(Element(1)),
PartitionOp::AssignFresh(Element(2)), PartitionOp::AssignFresh(Element(2)),
@@ -578,48 +600,46 @@ TEST(PartitionUtilsTest, TestConsumeAndRequire) {
TransferringOperand) { times_called++; }; TransferringOperand) { times_called++; };
{ {
PartitionOpEvaluator eval(p, factory); PartitionOpEvaluatorWithFailureCallback eval(p, factory,
eval.failureCallback = increment_times_called; increment_times_called);
eval.apply({PartitionOp::Require(Element(0)), eval.apply({PartitionOp::Require(Element(0)),
PartitionOp::Require(Element(1)), PartitionOp::Require(Element(1)),
PartitionOp::Require(Element(2))}); PartitionOp::Require(Element(2))});
} }
EXPECT_EQ(times_called, 3);
{ {
PartitionOpEvaluator eval(p, factory); PartitionOpEvaluatorWithFailureCallback eval(p, factory, never_called);
eval.failureCallback = never_called;
eval.apply({PartitionOp::Require(Element(3)), eval.apply({PartitionOp::Require(Element(3)),
PartitionOp::Require(Element(4)), PartitionOp::Require(Element(4)),
PartitionOp::Require(Element(5))}); PartitionOp::Require(Element(5))});
} }
{ {
PartitionOpEvaluator eval(p, factory); PartitionOpEvaluatorWithFailureCallback eval(p, factory,
eval.failureCallback = increment_times_called; increment_times_called);
eval.apply( eval.apply(
{PartitionOp::Require(Element(6)), PartitionOp::Require(Element(7))}); {PartitionOp::Require(Element(6)), PartitionOp::Require(Element(7))});
} }
{ {
PartitionOpEvaluator eval(p, factory); PartitionOpEvaluatorWithFailureCallback eval(p, factory, never_called);
eval.failureCallback = never_called;
eval.apply( eval.apply(
{PartitionOp::Require(Element(8)), PartitionOp::Require(Element(9))}); {PartitionOp::Require(Element(8)), PartitionOp::Require(Element(9))});
} }
{ {
PartitionOpEvaluator eval(p, factory); PartitionOpEvaluatorWithFailureCallback eval(p, factory,
eval.failureCallback = increment_times_called; increment_times_called);
eval.apply(PartitionOp::Require(Element(10))); eval.apply(PartitionOp::Require(Element(10)));
} }
{ {
PartitionOpEvaluator eval(p, factory); PartitionOpEvaluatorWithFailureCallback eval(p, factory, never_called);
eval.failureCallback = never_called;
eval.apply(PartitionOp::Require(Element(11))); eval.apply(PartitionOp::Require(Element(11)));
} }
EXPECT_TRUE(times_called == 6); EXPECT_EQ(times_called, 6);
} }
// This test tests that the copy constructor is usable to create fresh // This test tests that the copy constructor is usable to create fresh
@@ -630,7 +650,7 @@ TEST(PartitionUtilsTest, TestCopyConstructor) {
Partition p1; Partition p1;
{ {
PartitionOpEvaluator eval(p1, factory); PartitionOpEvaluatorBasic eval(p1, factory);
eval.apply(PartitionOp::AssignFresh(Element(0))); eval.apply(PartitionOp::AssignFresh(Element(0)));
} }
@@ -639,23 +659,25 @@ TEST(PartitionUtilsTest, TestCopyConstructor) {
// Change p1 again. // Change p1 again.
{ {
PartitionOpEvaluator eval(p1, factory); PartitionOpEvaluatorBasic eval(p1, factory);
eval.apply(PartitionOp::Transfer(Element(0), transferSingletons[0])); eval.apply(PartitionOp::Transfer(Element(0), transferSingletons[0]));
} }
{ {
bool failure = false; bool failure = false;
PartitionOpEvaluator eval(p1, factory); PartitionOpEvaluatorWithFailureCallback eval(
eval.failureCallback = [&](const PartitionOp &, unsigned, p1, factory, [&](const PartitionOp &, unsigned, TransferringOperand) {
TransferringOperand) { failure = true; }; failure = true;
});
eval.apply(PartitionOp::Require(Element(0))); eval.apply(PartitionOp::Require(Element(0)));
EXPECT_TRUE(failure); EXPECT_TRUE(failure);
} }
{ {
PartitionOpEvaluator eval(p2, factory); PartitionOpEvaluatorWithFailureCallback eval(
eval.failureCallback = [](const PartitionOp &, unsigned, p2, factory, [](const PartitionOp &, unsigned, TransferringOperand) {
TransferringOperand) { EXPECT_TRUE(false); }; EXPECT_TRUE(false);
});
eval.apply(PartitionOp::Require(Element(0))); eval.apply(PartitionOp::Require(Element(0)));
} }
} }
@@ -665,9 +687,10 @@ TEST(PartitionUtilsTest, TestUndoTransfer) {
Partition::TransferringOperandSetFactory factory(allocator); Partition::TransferringOperandSetFactory factory(allocator);
Partition p; Partition p;
PartitionOpEvaluator eval(p, factory); PartitionOpEvaluatorWithFailureCallback eval(
eval.failureCallback = [&](const PartitionOp &, unsigned, p, factory, [&](const PartitionOp &, unsigned, TransferringOperand) {
TransferringOperand) { EXPECT_TRUE(false); }; EXPECT_TRUE(false);
});
// Shouldn't error on this. // Shouldn't error on this.
eval.apply({PartitionOp::AssignFresh(Element(0)), eval.apply({PartitionOp::AssignFresh(Element(0)),