diff --git a/include/swift/SIL/SILBasicBlock.h b/include/swift/SIL/SILBasicBlock.h index f28aaa971df..8ef33799c07 100644 --- a/include/swift/SIL/SILBasicBlock.h +++ b/include/swift/SIL/SILBasicBlock.h @@ -231,6 +231,52 @@ public: const_reverse_iterator rbegin() const { return InstList.rbegin(); } const_reverse_iterator rend() const { return InstList.rend(); } + llvm::iterator_range getRangeStartingAtInst(SILInstruction *inst) { + assert(inst->getParent() == this); + return {inst->getIterator(), end()}; + } + + llvm::iterator_range getRangeEndingAtInst(SILInstruction *inst) { + assert(inst->getParent() == this); + return {begin(), inst->getIterator()}; + } + + llvm::iterator_range + getReverseRangeStartingAtInst(SILInstruction *inst) { + assert(inst->getParent() == this); + return {inst->getReverseIterator(), rend()}; + } + + llvm::iterator_range + getReverseRangeEndingAtInst(SILInstruction *inst) { + assert(inst->getParent() == this); + return {rbegin(), inst->getReverseIterator()}; + } + + llvm::iterator_range + getRangeStartingAtInst(SILInstruction *inst) const { + assert(inst->getParent() == this); + return {inst->getIterator(), end()}; + } + + llvm::iterator_range + getRangeEndingAtInst(SILInstruction *inst) const { + assert(inst->getParent() == this); + return {begin(), inst->getIterator()}; + } + + llvm::iterator_range + getReverseRangeStartingAtInst(SILInstruction *inst) const { + assert(inst->getParent() == this); + return {inst->getReverseIterator(), rend()}; + } + + llvm::iterator_range + getReverseRangeEndingAtInst(SILInstruction *inst) const { + assert(inst->getParent() == this); + return {rbegin(), inst->getReverseIterator()}; + } + /// Allows deleting instructions while iterating over all instructions of the /// block. /// diff --git a/include/swift/SILOptimizer/Utils/PartitionUtils.h b/include/swift/SILOptimizer/Utils/PartitionUtils.h index c9338ae2828..18797a89381 100644 --- a/include/swift/SILOptimizer/Utils/PartitionUtils.h +++ b/include/swift/SILOptimizer/Utils/PartitionUtils.h @@ -88,6 +88,70 @@ struct DenseMapInfo { namespace swift { +struct TransferringOperand { + using ValueType = llvm::PointerIntPair; + ValueType value; + + TransferringOperand() : value() {} + TransferringOperand(Operand *op, bool isClosureCaptured) + : value(op, isClosureCaptured) {} + explicit TransferringOperand(Operand *op) : value(op, false) {} + TransferringOperand(ValueType newValue) : value(newValue) {} + + operator bool() const { return bool(value.getPointer()); } + + Operand *getOperand() const { return value.getPointer(); } + + bool isClosureCaptured() const { return value.getInt(); } + + SILInstruction *getUser() const { return getOperand()->getUser(); } +}; + +} // namespace swift + +namespace llvm { + +template <> +struct PointerLikeTypeTraits { + using TransferringOperand = swift::TransferringOperand; + + static inline void *getAsVoidPointer(TransferringOperand ptr) { + return PointerLikeTypeTraits< + TransferringOperand::ValueType>::getAsVoidPointer(ptr.value); + } + static inline TransferringOperand getFromVoidPointer(void *ptr) { + return {PointerLikeTypeTraits< + TransferringOperand::ValueType>::getFromVoidPointer(ptr)}; + } + + static constexpr int NumLowBitsAvailable = PointerLikeTypeTraits< + TransferringOperand::ValueType>::NumLowBitsAvailable; +}; + +template <> +struct DenseMapInfo { + using TransferringOperand = swift::TransferringOperand; + using ParentInfo = DenseMapInfo; + + static TransferringOperand getEmptyKey() { + return TransferringOperand(ParentInfo::getEmptyKey()); + } + static TransferringOperand getTombstoneKey() { + return TransferringOperand(ParentInfo::getTombstoneKey()); + } + + static unsigned getHashValue(TransferringOperand operand) { + return ParentInfo::getHashValue(operand.value); + } + static bool isEqual(TransferringOperand LHS, TransferringOperand RHS) { + return ParentInfo::isEqual(LHS.value, RHS.value); + } +}; + +} // namespace llvm + +namespace swift { + /// PartitionOpKind represents the different kinds of PartitionOps that /// SILInstructions can be translated to enum class PartitionOpKind : uint8_t { @@ -142,7 +206,7 @@ private: : opKind(opKind), opArgs({arg1}), source(sourceOperand) { assert(((opKind != PartitionOpKind::Transfer && opKind != PartitionOpKind::UndoTransfer) || - sourceOperand) && + bool(sourceOperand)) && "Transfer needs a sourceInst"); } @@ -304,7 +368,7 @@ private: /// multi map here. The implication of this is that when we are performing /// dataflow we use a union operation to combine CFG elements and just take /// the first instruction that we see. - llvm::SmallDenseMap regionToTransferredOpMap; + llvm::SmallDenseMap regionToTransferredOpMap; public: Partition() : elementToRegionMap({}), canonical(true) {} @@ -358,12 +422,12 @@ public: /// Mark val as transferred. Returns true if we inserted \p /// transferOperand. We return false otherwise. - bool markTransferred(Element val, Operand *transferOperand) { + bool markTransferred(Element val, TransferringOperand transferredOperand) { // First see if our val is tracked. If it is not tracked, insert it and mark // its new region as transferred. if (!isTracked(val)) { elementToRegionMap.insert_or_assign(val, fresh_label); - regionToTransferredOpMap.insert({fresh_label, transferOperand}); + regionToTransferredOpMap.insert({fresh_label, transferredOperand}); fresh_label = Region(fresh_label + 1); canonical = false; return true; @@ -373,7 +437,7 @@ public: auto iter1 = elementToRegionMap.find(val); assert(iter1 != elementToRegionMap.end()); auto iter2 = - regionToTransferredOpMap.try_emplace(iter1->second, transferOperand); + regionToTransferredOpMap.try_emplace(iter1->second, transferredOperand); return iter2.second; } @@ -532,13 +596,30 @@ public: os << "["; for (auto [regionNo, elementNumbers] : multimap.getRange()) { - bool isTransferred = regionToTransferredOpMap.count(regionNo); - os << (isTransferred ? "{" : "("); + auto iter = regionToTransferredOpMap.find(regionNo); + bool isTransferred = iter != regionToTransferredOpMap.end(); + bool isClosureCaptured = + isTransferred ? iter->getSecond().isClosureCaptured() : false; + + if (isTransferred) { + os << '{'; + if (isClosureCaptured) + os << '*'; + } else { + os << '('; + } + int j = 0; for (Element i : elementNumbers) { os << (j++ ? " " : "") << i; } - os << (isTransferred ? "}" : ")"); + if (isTransferred) { + if (isClosureCaptured) + os << '*'; + os << '}'; + } else { + os << ')'; + } } os << "]\n"; } @@ -552,13 +633,13 @@ public: /// Return the instruction that transferred \p val's region or nullptr /// otherwise. - Operand *getTransferred(Element val) const { + std::optional getTransferred(Element val) const { auto iter = elementToRegionMap.find(val); if (iter == elementToRegionMap.end()) - return nullptr; + return std::nullopt; auto iter2 = regionToTransferredOpMap.find(iter->second); if (iter2 == regionToTransferredOpMap.end()) - return nullptr; + return std::nullopt; return iter2->second; } @@ -642,7 +723,7 @@ private: // // TODO: If we just used an array for this, we could just rewrite and // re-sort and not have to deal with potential allocations. - llvm::SmallDenseMap oldMap = + llvm::SmallDenseMap oldMap = std::move(regionToTransferredOpMap); for (auto &[oldReg, op] : oldMap) { auto iter = oldRegionToRelabeledMap.find(oldReg); @@ -676,8 +757,9 @@ private: horizontalUpdate(elementToRegionMap, snd, fstRegion); auto iter = regionToTransferredOpMap.find(sndRegion); if (iter != regionToTransferredOpMap.end()) { - regionToTransferredOpMap.try_emplace(fstRegion, iter->second); + auto operand = iter->second; regionToTransferredOpMap.erase(iter); + regionToTransferredOpMap.try_emplace(fstRegion, operand); } } else { result = sndRegion; @@ -685,8 +767,9 @@ private: horizontalUpdate(elementToRegionMap, fst, sndRegion); auto iter = regionToTransferredOpMap.find(fstRegion); if (iter != regionToTransferredOpMap.end()) { - regionToTransferredOpMap.try_emplace(sndRegion, iter->second); + auto operand = iter->second; regionToTransferredOpMap.erase(iter); + regionToTransferredOpMap.try_emplace(sndRegion, operand); } } @@ -751,8 +834,8 @@ struct PartitionOpEvaluator { /// 3. The operand of the instruction that originally transferred the /// region. Can be used to get the immediate value transferred or the /// transferring instruction. - std::function failureCallback = - nullptr; + std::function + failureCallback = nullptr; /// A list of elements that cannot be transferred. Whenever we transfer, we /// check this list to see if we are transferring the element and then call @@ -771,11 +854,21 @@ struct PartitionOpEvaluator { /// transferred. std::function isActorDerivedCallback = nullptr; + /// Check if the representative value of \p elt is closure captured at \p + /// op. + /// + /// NOTE: We actually just use the user of \p op in our callbacks. The reason + /// why we do not just pass in that SILInstruction is that then we would need + /// to access the instruction in the evaluator which creates a problem when + /// since the operand we pass in is a dummy operand. + std::function isClosureCapturedCallback = + nullptr; + PartitionOpEvaluator(Partition &p) : p(p) {} /// A wrapper around the failure callback that checks if it is nullptr. void handleFailure(const PartitionOp &op, Element elt, - Operand *transferringOp) const { + TransferringOperand transferringOp) const { if (!failureCallback) return; failureCallback(op, elt, transferringOp); @@ -797,6 +890,14 @@ struct PartitionOpEvaluator { return bool(isActorDerivedCallback) && isActorDerivedCallback(elt); } + /// A wraper around isClosureCapturedCallback that returns false if + /// isClosureCapturedCallback is nullptr and otherwise returns + /// isClosureCapturedCallback. + bool isClosureCaptured(Element elt, Operand *op) const { + return bool(isClosureCapturedCallback) && + isClosureCapturedCallback(elt, op); + } + /// Apply \p op to the partition op. void apply(const PartitionOp &op) const { if (emitLog) { @@ -821,8 +922,8 @@ struct PartitionOpEvaluator { "Assign PartitionOp's source argument should be already tracked"); // If we are using a region that was transferred as our assignment source // value... emit an error. - if (auto *transferringInst = p.getTransferred(op.getOpArgs()[1])) - handleFailure(op, op.getOpArgs()[1], transferringInst); + if (auto transferredOperand = p.getTransferred(op.getOpArgs()[1])) + handleFailure(op, op.getOpArgs()[1], *transferredOperand); p.elementToRegionMap.insert_or_assign( op.getOpArgs()[0], p.elementToRegionMap.at(op.getOpArgs()[1])); @@ -863,17 +964,22 @@ struct PartitionOpEvaluator { if (isActorDerived(op.getOpArgs()[0])) return handleTransferNonTransferrable(op, op.getOpArgs()[0]); + // While we are checking for actor derived, also check if our value or any + // value in our region is closure captured and propagate that bit in our + // transferred inst. + bool isClosureCapturedElt = + isClosureCaptured(op.getOpArgs()[0], op.getSourceOp()); + Region elementRegion = p.elementToRegionMap.at(op.getOpArgs()[0]); - if (llvm::any_of(p.elementToRegionMap, - [&](const std::pair &pair) -> bool { - if (pair.second != elementRegion) - return false; - return isActorDerived(pair.first); - })) - return handleTransferNonTransferrable(op, op.getOpArgs()[0]); + for (const auto &pair : p.elementToRegionMap) { + if (pair.second == elementRegion && isActorDerived(pair.first)) + return handleTransferNonTransferrable(op, op.getOpArgs()[0]); + isClosureCapturedElt |= isClosureCaptured(pair.first, op.getSourceOp()); + } // Mark op.getOpArgs()[0] as transferred. - p.markTransferred(op.getOpArgs()[0], op.getSourceOp()); + p.markTransferred(op.getOpArgs()[0], + {op.getSourceOp(), isClosureCapturedElt}); return; } case PartitionOpKind::UndoTransfer: { @@ -894,10 +1000,10 @@ struct PartitionOpEvaluator { "Merge PartitionOp's arguments should already be tracked"); // if attempting to merge a transferred region, handle the failure - if (auto *transferringInst = p.getTransferred(op.getOpArgs()[0])) - handleFailure(op, op.getOpArgs()[0], transferringInst); - if (auto *transferringInst = p.getTransferred(op.getOpArgs()[1])) - handleFailure(op, op.getOpArgs()[1], transferringInst); + if (auto transferringOp = p.getTransferred(op.getOpArgs()[0])) + handleFailure(op, op.getOpArgs()[0], *transferringOp); + if (auto transferringOp = p.getTransferred(op.getOpArgs()[1])) + handleFailure(op, op.getOpArgs()[1], *transferringOp); p.merge(op.getOpArgs()[0], op.getOpArgs()[1]); return; @@ -906,8 +1012,8 @@ struct PartitionOpEvaluator { "Require PartitionOp should be passed 1 argument"); assert(p.elementToRegionMap.count(op.getOpArgs()[0]) && "Require PartitionOp's argument should already be tracked"); - if (auto *transferringInst = p.getTransferred(op.getOpArgs()[0])) - handleFailure(op, op.getOpArgs()[0], transferringInst); + if (auto transferringOp = p.getTransferred(op.getOpArgs()[0])) + handleFailure(op, op.getOpArgs()[0], *transferringOp); return; } diff --git a/lib/SILOptimizer/Mandatory/TransferNonSendable.cpp b/lib/SILOptimizer/Mandatory/TransferNonSendable.cpp index 762c78bc786..937bbf1425f 100644 --- a/lib/SILOptimizer/Mandatory/TransferNonSendable.cpp +++ b/lib/SILOptimizer/Mandatory/TransferNonSendable.cpp @@ -19,6 +19,7 @@ #include "swift/SIL/BasicBlockDatastructures.h" #include "swift/SIL/DynamicCasts.h" #include "swift/SIL/MemAccessUtils.h" +#include "swift/SIL/MemoryLocations.h" #include "swift/SIL/NodeDatastructures.h" #include "swift/SIL/OperandDatastructures.h" #include "swift/SIL/OwnershipUtils.h" @@ -326,6 +327,236 @@ static InFlightDiagnostic diagnose(const SILInstruction *inst, Diag diag, std::forward(args)...); } +//===----------------------------------------------------------------------===// +// MARK: Partial Apply Reachability +//===----------------------------------------------------------------------===// + +namespace { + +/// We need to be able to know if instructions that extract sendable fields from +/// non-sendable addresses are reachable from a partial_apply that captures the +/// non-sendable value or its underlying object by reference. In such a case, we +/// need to require the value to not be transferred when the extraction happens +/// since we could race on extracting the value. +/// +/// The reason why we use a dataflow to do this is that: +/// +/// 1. We do not want to recompute this for each individual instruction that +/// might be reachable from the partial apply. +/// +/// 2. Just computing reachability early is a very easy way to do this. +struct PartialApplyReachabilityDataflow { + PostOrderFunctionInfo *pofi; + llvm::DenseMap valueToBit; + std::vector> valueToGenInsts; + + struct BlockState { + SmallBitVector entry; + SmallBitVector exit; + SmallBitVector gen; + bool needsUpdate = true; + }; + + BasicBlockData blockData; + bool propagatedReachability = false; + + PartialApplyReachabilityDataflow(SILFunction *fn, PostOrderFunctionInfo *pofi) + : pofi(pofi), blockData(fn) {} + + /// Begin tracking an operand of a partial apply. + void add(Operand *op); + + /// Once we have finished adding data to the data, propagate reachability. + void propagateReachability(); + + bool isReachable(SILValue value, SILInstruction *user) const; + bool isReachable(Operand *op) const { + return isReachable(op->get(), op->getUser()); + } + + bool isGenInstruction(SILValue value, SILInstruction *inst) const { + assert(propagatedReachability && "Only valid once propagated reachability"); + auto iter = + std::lower_bound(valueToGenInsts.begin(), valueToGenInsts.end(), + std::make_pair(value, nullptr), + [](const std::pair &p1, + const std::pair &p2) { + return p1 < p2; + }); + return iter != valueToGenInsts.end() && iter->first == value && + iter->second == inst; + } + + void print(llvm::raw_ostream &os) const; + + SWIFT_DEBUG_DUMP { print(llvm::dbgs()); } + +private: + SILValue getRootValue(SILValue value) const { + return getUnderlyingTrackedValue(value); + } + + unsigned getBitForValue(SILValue value) const { + unsigned size = valueToBit.size(); + auto &self = const_cast(*this); + auto iter = self.valueToBit.try_emplace(value, size); + return iter.first->second; + } +}; + +} // namespace + +void PartialApplyReachabilityDataflow::add(Operand *op) { + assert(!propagatedReachability && + "Cannot add more operands once reachability is computed"); + SILValue underlyingValue = getRootValue(op->get()); + LLVM_DEBUG(llvm::dbgs() << "PartialApplyReachability::add.\nValue: " + << underlyingValue << "User: " << *op->getUser()); + + unsigned bit = getBitForValue(underlyingValue); + auto &state = blockData[op->getParentBlock()]; + state.gen.resize(bit + 1); + state.gen.set(bit); + valueToGenInsts.emplace_back(underlyingValue, op->getUser()); +} + +bool PartialApplyReachabilityDataflow::isReachable(SILValue value, + SILInstruction *user) const { + assert( + propagatedReachability && + "Can only check for reachability once reachability has been propagated"); + SILValue baseValue = getRootValue(value); + auto iter = valueToBit.find(baseValue); + // If we aren't tracking this value... just bail. + if (iter == valueToBit.end()) + return false; + unsigned bitNum = iter->second; + auto &state = blockData[user->getParent()]; + + // If we are reachable at entry, then we are done. + if (state.entry.test(bitNum)) { + return true; + } + + // Otherwise, check if we are reachable at exit. If we are not, then we are + // not reachable. + if (!state.exit.test(bitNum)) { + return false; + } + + // We were not reachable at entry but are at our exit... walk the block and + // see if our user is before a gen instruction. + auto genStart = std::lower_bound( + valueToGenInsts.begin(), valueToGenInsts.end(), + std::make_pair(baseValue, nullptr), + [](const std::pair &p1, + const std::pair &p2) { return p1 < p2; }); + if (genStart == valueToGenInsts.end() || genStart->first != baseValue) + return false; + + auto genEnd = genStart; + while (genEnd->first == baseValue) + ++genEnd; + + // Walk forward from the beginning of the block to user. If we do not find a + // gen instruction, then we know the gen occurs after the op. + return llvm::any_of( + user->getParent()->getRangeEndingAtInst(user), [&](SILInstruction &inst) { + auto iter = std::lower_bound( + genStart, genEnd, std::make_pair(baseValue, &inst), + [](const std::pair &p1, + const std::pair &p2) { + return p1 < p2; + }); + return iter != valueToGenInsts.end() && iter->first == baseValue && + iter->second == &inst; + }); +} + +void PartialApplyReachabilityDataflow::propagateReachability() { + assert(!propagatedReachability && "Cannot propagate reachability twice"); + propagatedReachability = true; + + // Now that we have finished initializing, resize all of our bitVectors to the + // final number of bits. + unsigned numBits = valueToBit.size(); + + // If numBits is none, we have nothing to process. + if (numBits == 0) + return; + + for (auto iter : blockData) { + iter.data.entry.resize(numBits); + iter.data.exit.resize(numBits); + iter.data.gen.resize(numBits); + iter.data.needsUpdate = true; + } + + // Freeze our value to gen insts map so we can perform in block checks. + sortUnique(valueToGenInsts); + + // We perform a simple gen-kill dataflow with union. Since we are just + // propagating reachability, there isn't any kill. + bool anyNeedUpdate = true; + SmallBitVector temp(numBits); + blockData[&*blockData.getFunction()->begin()].needsUpdate = true; + while (anyNeedUpdate) { + anyNeedUpdate = false; + + for (auto *block : pofi->getReversePostOrder()) { + auto &state = blockData[block]; + + if (!state.needsUpdate) { + continue; + } + + state.needsUpdate = false; + temp.reset(); + for (auto *predBlock : block->getPredecessorBlocks()) { + auto &predState = blockData[predBlock]; + temp |= predState.exit; + } + + state.entry = temp; + + temp |= state.gen; + + if (temp != state.exit) { + state.exit = temp; + for (auto *succBlock : block->getSuccessorBlocks()) { + anyNeedUpdate = true; + blockData[succBlock].needsUpdate = true; + } + } + } + } + + LLVM_DEBUG(llvm::dbgs() << "Propagating Captures Result!\n"; + print(llvm::dbgs())); +} + +void PartialApplyReachabilityDataflow::print(llvm::raw_ostream &os) const { + // This is only invoked for debugging purposes, so make nicer output. + std::vector> data; + for (auto [value, bitNo] : valueToBit) { + data.emplace_back(bitNo, value); + } + std::sort(data.begin(), data.end()); + + os << "(BitNo, Value):\n"; + for (auto [bitNo, value] : data) { + os << " " << bitNo << ": " << value; + } + + os << "(Block,GenBits):\n"; + for (auto [block, state] : blockData) { + os << " bb" << block.getDebugID() << ".\n" + << " Entry: " << state.entry << '\n' + << " Gen: " << state.gen << '\n' + << " Exit: " << state.exit << '\n'; + } +} + //===----------------------------------------------------------------------===// // MARK: Expr/Type Inference for Diagnostics //===----------------------------------------------------------------------===// @@ -747,6 +978,8 @@ class PartitionOpTranslator { /// of PartitionOps. PartitionOpBuilder builder; + PartialApplyReachabilityDataflow partialApplyReachabilityDataflow; + std::optional tryToTrackValue(SILValue value) const { auto state = getTrackableValue(value); if (state.isNonSendable()) @@ -831,15 +1064,24 @@ class PartitionOpTranslator { for (auto &block : *function) { for (auto &inst : block) { - // See if this instruction is a partial apply whose non-sendable address - // operands we need to mark as captured_uniquely identified. Importantly - // this does not affect the result of the partial apply so there isn't - // any problems with the builtin section earlier. if (auto *pai = dyn_cast(&inst)) { - // If we find an address or a box of a non-Sendable type that is - // passed to a partial_apply, mark the value's representative as being - // uniquely identified and captured. - for (SILValue val : inst.getOperandValues()) { + ApplySite applySite(pai); + for (Operand &op : applySite.getArgumentOperands()) { + // See if this operand is inout_aliasable or is passed as a box. In + // such a case, we are passing by reference so we need to add it to + // the reachability. + if (applySite.getArgumentConvention(op) == + SILArgumentConvention::Indirect_InoutAliasable || + op.get()->getType().is()) + partialApplyReachabilityDataflow.add(&op); + + // See if this instruction is a partial apply whose non-sendable + // address operands we need to mark as captured_uniquely identified. + // + // If we find an address or a box of a non-Sendable type that is + // passed to a partial_apply, mark the value's representative as + // being uniquely identified and captured. + SILValue val = op.get(); if (val->getType().isAddress() && isNonSendableType(val->getType())) { auto trackVal = getTrackableValue(val, true); @@ -847,7 +1089,6 @@ class PartitionOpTranslator { LLVM_DEBUG(trackVal.print(llvm::dbgs())); continue; } - if (auto *pbi = dyn_cast(val)) { if (isNonSendableType( pbi->getType().getSILBoxFieldType(function))) { @@ -860,11 +1101,15 @@ class PartitionOpTranslator { } } } + + // Once we have finished processing all blocks, propagate reachability. + partialApplyReachabilityDataflow.propagateReachability(); } public: - PartitionOpTranslator(SILFunction *function) - : function(function), functionArgPartition(), builder() { + PartitionOpTranslator(SILFunction *function, PostOrderFunctionInfo *pofi) + : function(function), functionArgPartition(), builder(), + partialApplyReachabilityDataflow(function, pofi) { builder.translator = this; gatherFlowInsensitiveInformationBeforeDataflow(); @@ -924,6 +1169,10 @@ public: return {{iter2->first, iter2->second}}; } + bool isClosureCaptured(SILValue value, SILInstruction *inst) const { + return partialApplyReachabilityDataflow.isReachable(value, inst); + } + private: bool valueHasID(SILValue value, bool dumpIfHasNoID = false) { assert(getTrackableValue(value).isNonSendable() && @@ -1239,14 +1488,16 @@ public: auto handleSILOperands = [&](MutableArrayRef operands) { for (auto &op : operands) { - if (auto value = tryToTrackValue(op.get())) + if (auto value = tryToTrackValue(op.get())) { builder.addTransfer(value->getRepresentative(), &op); + } } }; auto handleSILSelf = [&](Operand *self) { - if (auto value = tryToTrackValue(self->get())) + if (auto value = tryToTrackValue(self->get())) { builder.addTransfer(value->getRepresentative(), self); + } }; if (applySite.hasSelfArgument()) { @@ -1501,11 +1752,20 @@ public: case SILInstructionKind::TupleElementAddrInst: case SILInstructionKind::StructElementAddrInst: { auto *svi = cast(inst); - // If we have a sendable field... we can always access it after - // transferring... so do not track this. - if (!isNonSendableType(svi->getType())) - return; - return translateSILLookThrough(svi->getResult(0), svi->getOperand(0)); + + // If our result is non-Sendable, just treat this as a lookthrough. + if (isNonSendableType(svi->getType())) + return translateSILLookThrough(svi->getResult(0), svi->getOperand(0)); + + // Otherwise, we are extracting a sendable field from a non-Sendable base + // type. We need to track this as an assignment so that if we transferred + // the value we emit an error. Since we do not track uses of Sendable + // values this is the best place to emit the error since we do not look + // further to find the actual use site. + // + // TODO: We could do a better job here and attempt to find the actual use + // of the Sendable addr. That would require adding more logic though. + return translateSILRequire(svi->getOperand(0)); } // We identify tuple results with their operand's id. @@ -1846,9 +2106,16 @@ class BlockPartitionState { /// /// NOTE: This method ignored errors that arise. We process separately later /// to discover if an error occured. - bool recomputeExitFromEntry() { + bool recomputeExitFromEntry(PartitionOpTranslator &translator) { Partition workingPartition = entryPartition; PartitionOpEvaluator eval(workingPartition); + eval.isClosureCapturedCallback = [&](Element element, Operand *op) -> bool { + auto iter = translator.getValueForId(element); + if (!iter) + return false; + return translator.isClosureCaptured(iter->getRepresentative(), + op->getUser()); + }; for (const auto &partitionOp : blockPartitionOps) { // By calling apply without providing a `handleFailure` closure, errors // will be suppressed @@ -2111,18 +2378,20 @@ class PartitionAnalysis { SILFunction *function; + PostOrderFunctionInfo *pofi; + bool solved; /// The constructor initializes each block in the function by compiling it to /// PartitionOps, then seeds the solve method by setting `needsUpdate` to true /// for the entry block - PartitionAnalysis(SILFunction *fn) - : translator(fn), + PartitionAnalysis(SILFunction *fn, PostOrderFunctionInfo *pofi) + : translator(fn, pofi), blockStates(fn, [this](SILBasicBlock *block) { return BlockPartitionState(block, translator); }), - function(fn), solved(false) { + function(fn), pofi(pofi), solved(false) { // Initialize the entry block as needing an update, and having a partition // that places all its non-sendable args in a single region blockStates[fn->getEntryBlock()].needsUpdate = true; @@ -2141,9 +2410,10 @@ class PartitionAnalysis { while (anyNeedUpdate) { anyNeedUpdate = false; - for (auto [block, blockState] : blockStates) { + for (auto *block : pofi->getReversePostOrder()) { + auto &blockState = blockStates[block]; - LLVM_DEBUG(llvm::dbgs() << "Block: bb" << block.getDebugID() << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Block: bb" << block->getDebugID() << "\n"); if (!blockState.needsUpdate) { LLVM_DEBUG(llvm::dbgs() << " Doesn't need update! Skipping!\n"); continue; @@ -2163,7 +2433,7 @@ class PartitionAnalysis { // This loop computes the join of the exit partitions of all // predecessors of this block - for (SILBasicBlock *predBlock : block.getPredecessorBlocks()) { + for (SILBasicBlock *predBlock : block->getPredecessorBlocks()) { BlockPartitionState &predState = blockStates[predBlock]; // ignore predecessors that haven't been reached by the analysis yet if (!predState.reached) @@ -2206,8 +2476,8 @@ class PartitionAnalysis { // recompute this block's exit partition from its (updated) entry // partition, and if this changed the exit partition notify all // successor blocks that they need to update as well - if (blockState.recomputeExitFromEntry()) { - for (SILBasicBlock *succBlock : block.getSuccessorBlocks()) { + if (blockState.recomputeExitFromEntry(translator)) { + for (SILBasicBlock *succBlock : block->getSuccessorBlocks()) { anyNeedUpdate = true; blockStates[succBlock].needsUpdate = true; } @@ -2243,7 +2513,21 @@ class PartitionAnalysis { PartitionOpEvaluator eval(workingPartition); eval.failureCallback = /*handleFailure=*/ [&](const PartitionOp &partitionOp, TrackableValueID transferredVal, - Operand *transferringOp) { + TransferringOperand transferringOp) { + // Ignore this if we have a gep like instruction that is returning a + // sendable type and transferringOp was not set with closure + // capture. + if (auto *svi = dyn_cast( + partitionOp.getSourceInst())) { + if (isa(svi) && + !isNonSendableType(svi->getType(), svi->getFunction())) { + bool isCapture = transferringOp.isClosureCaptured(); + if (!isCapture) { + return; + } + } + } + auto rep = translator.getValueForId(transferredVal)->getRepresentative(); LLVM_DEBUG( @@ -2253,9 +2537,9 @@ class PartitionAnalysis { << " Rep: " << *rep << " Require Inst: " << *partitionOp.getSourceInst() << " Transferring Op Num: " - << transferringOp->getOperandNumber() << '\n' - << " Transferring Inst: " << *transferringOp->getUser()); - transferOpToRequireInstMultiMap.insert(transferringOp, + << transferringOp.getOperand()->getOperandNumber() << '\n' + << " Transferring Inst: " << *transferringOp.getUser()); + transferOpToRequireInstMultiMap.insert(transferringOp.getOperand(), partitionOp.getSourceInst()); }; eval.transferredNonTransferrableCallback = @@ -2276,6 +2560,14 @@ class PartitionAnalysis { return false; return iter->isActorDerived(); }; + eval.isClosureCapturedCallback = [&](Element element, + Operand *op) -> bool { + auto iter = translator.getValueForId(element); + if (!iter) + return false; + return translator.isClosureCaptured(iter->getRepresentative(), + op->getUser()); + }; // And then evaluate all of our partition ops on the entry partition. for (auto &partitionOp : blockState.getPartitionOps()) { @@ -2374,8 +2666,9 @@ public: } } - static void performForFunction(SILFunction *function) { - auto analysis = PartitionAnalysis(function); + static void performForFunction(SILFunction *function, + PostOrderFunctionInfo *pofi) { + auto analysis = PartitionAnalysis(function, pofi); analysis.solve(); LLVM_DEBUG(llvm::dbgs() << "SOLVED: "; analysis.print(llvm::dbgs());); analysis.emitDiagnostics(); @@ -2415,7 +2708,8 @@ class TransferNonSendable : public SILFunctionTransform { if (!function->getASTContext().getProtocol(KnownProtocolKind::Sendable)) llvm::report_fatal_error("Sendable protocol not available!"); - PartitionAnalysis::performForFunction(function); + auto *pofi = this->getAnalysis()->get(function); + PartitionAnalysis::performForFunction(function, pofi); } }; diff --git a/test/Concurrency/sendnonsendable_basic.swift b/test/Concurrency/sendnonsendable_basic.swift index f7cca9ec793..e4f7543e2d2 100644 --- a/test/Concurrency/sendnonsendable_basic.swift +++ b/test/Concurrency/sendnonsendable_basic.swift @@ -752,7 +752,7 @@ func varNonSendableNonTrivialFinalClassFieldTest() async { // MARK: StructFieldTests // //////////////////////////// -struct StructFieldTests { // expected-complete-note 12 {{}} +struct StructFieldTests { // expected-complete-note 31 {{}} let letSendableTrivial = 0 let letSendableNonTrivial = SendableKlass() let letNonSendableNonTrivial = NonSendableKlass() @@ -786,31 +786,6 @@ func letNonSendableNonTrivialLetStructFieldTest() async { useValue(test) } -func varSendableTrivialLetStructFieldTest() async { - let test = StructFieldTests() - await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} - // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} - _ = test.varSendableTrivial - useValue(test) // expected-tns-note {{access here could race}} -} - -func varSendableNonTrivialLetStructFieldTest() async { - let test = StructFieldTests() - await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} - // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} - _ = test.varSendableNonTrivial - useValue(test) // expected-tns-note {{access here could race}} -} - -func varNonSendableNonTrivialLetStructFieldTest() async { - let test = StructFieldTests() - await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} - // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} - let z = test.varNonSendableNonTrivial // expected-tns-note {{access here could race}} - _ = z - useValue(test) -} - func letSendableTrivialVarStructFieldTest() async { var test = StructFieldTests() test = StructFieldTests() @@ -839,6 +814,52 @@ func letNonSendableNonTrivialVarStructFieldTest() async { useValue(test) } +// Lets can access sendable let/var even if captured in a closure. +func letNonSendableNonTrivialLetStructFieldClosureTest() async { + let test = StructFieldTests() + let cls = { + print(test) + } + _ = cls + var cls2 = {} + cls2 = { + print(test) + } + _ = cls2 + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + let z = test.letSendableNonTrivial + _ = z + let z2 = test.varSendableNonTrivial + _ = z2 + useValue(test) // expected-tns-note {{access here could race}} +} + +func varSendableTrivialLetStructFieldTest() async { + let test = StructFieldTests() + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + _ = test.varSendableTrivial + useValue(test) // expected-tns-note {{access here could race}} +} + +func varSendableNonTrivialLetStructFieldTest() async { + let test = StructFieldTests() + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + _ = test.varSendableNonTrivial + useValue(test) // expected-tns-note {{access here could race}} +} + +func varNonSendableNonTrivialLetStructFieldTest() async { + let test = StructFieldTests() + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + let z = test.varNonSendableNonTrivial // expected-tns-note {{access here could race}} + _ = z + useValue(test) +} + func varSendableTrivialVarStructFieldTest() async { var test = StructFieldTests() test = StructFieldTests() @@ -867,6 +888,284 @@ func varNonSendableNonTrivialVarStructFieldTest() async { useValue(test) } +// vars cannot access sendable let/var if captured in a closure. +func varNonSendableNonTrivialLetStructFieldClosureTest1() async { + var test = StructFieldTests() + test = StructFieldTests() + let cls = { + test = StructFieldTests() + } + _ = cls + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + let z = test.letSendableNonTrivial // expected-tns-note {{access here could race}} + _ = z + useValue(test) +} + +func varNonSendableNonTrivialLetStructFieldClosureTest2() async { + var test = StructFieldTests() + test = StructFieldTests() + let cls = { + test = StructFieldTests() + } + _ = cls + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + let z = test.varSendableNonTrivial // expected-tns-note {{access here could race}} + _ = z + useValue(test) +} + +func varNonSendableNonTrivialLetStructFieldClosureTest3() async { + var test = StructFieldTests() + test = StructFieldTests() + let cls = { + test = StructFieldTests() + } + _ = cls + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + test.varSendableNonTrivial = SendableKlass() // expected-tns-note {{access here could race}} + useValue(test) +} + +// vars cannot access sendable let/var if captured in a closure. +func varNonSendableNonTrivialLetStructFieldClosureTest4() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + cls = { + test = StructFieldTests() + } + _ = cls + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + let z = test.letSendableNonTrivial // expected-tns-note {{access here could race}} + _ = z + useValue(test) +} + +func varNonSendableNonTrivialLetStructFieldClosureTest5() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + cls = { + test = StructFieldTests() + } + _ = cls + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + let z = test.varSendableNonTrivial // expected-tns-note {{access here could race}} + _ = z + useValue(test) +} + +func varNonSendableNonTrivialLetStructFieldClosureTest6() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + cls = { + test = StructFieldTests() + } + _ = cls + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + test.varSendableNonTrivial = SendableKlass() // expected-tns-note {{access here could race}} + useValue(test) +} + +func varNonSendableNonTrivialLetStructFieldClosureTest7() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + cls = { + test.varSendableNonTrivial = SendableKlass() + } + _ = cls + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + test.varSendableNonTrivial = SendableKlass() // expected-tns-note {{access here could race}} + useValue(test) +} + +func varNonSendableNonTrivialLetStructFieldClosureTest8() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + cls = { + useInOut(&test) + } + _ = cls + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + test.varSendableNonTrivial = SendableKlass() // expected-tns-note {{access here could race}} + useValue(test) +} + +func varNonSendableNonTrivialLetStructFieldClosureTest9() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + cls = { + useInOut(&test.varSendableNonTrivial) + } + _ = cls + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + test.varSendableNonTrivial = SendableKlass() // expected-tns-note {{access here could race}} + useValue(test) +} + +func varNonSendableNonTrivialLetStructFieldClosureFlowSensitive1() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + + if await booleanFlag { + cls = { + useInOut(&test.varSendableNonTrivial) + } + _ = cls + } else { + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + + test.varSendableNonTrivial = SendableKlass() + } + + useValue(test) // expected-tns-note {{access here could race}} +} + +func varNonSendableNonTrivialLetStructFieldClosureFlowSensitive2() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + + if await booleanFlag { + cls = { + useInOut(&test.varSendableNonTrivial) + } + _ = cls + + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + } + + test.varSendableNonTrivial = SendableKlass() // expected-tns-note {{access here could race}} + useValue(test) +} + +// We do not error when accessing the sendable field in this example since the +// transfer is not reachable from the closure. Instead we emit an error on test. +func varNonSendableNonTrivialLetStructFieldClosureFlowSensitive3() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + + if await booleanFlag { + cls = { + useInOut(&test.varSendableNonTrivial) + } + _ = cls + } else { + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + } + + test.varSendableNonTrivial = SendableKlass() + useValue(test) // expected-tns-note {{access here could race}} +} + +func varNonSendableNonTrivialLetStructFieldClosureFlowSensitive4() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + + for _ in 0..<1024 { + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + test = StructFieldTests() // expected-tns-note {{access here could race}} + cls = { + useInOut(&test.varSendableNonTrivial) + } + _ = cls + } + + test.varSendableNonTrivial = SendableKlass() + useValue(test) +} + +func varNonSendableNonTrivialLetStructFieldClosureFlowSensitive5() async { + var test = StructFieldTests() + test = StructFieldTests() + + // The reason why we error here is that even though we reassign at the end of + // the for loop and currently understand that test has a new value different + // from the value assigned to test at the beginning of the for loop, when we + // merge back through the for loop, we have to merge the regions due to the + // union operation we perform. So the conservatism of the dataflow creates + // this. This is a case where we are going to need to be able to have the + // compiler explain the regions well. + for _ in 0..<1024 { + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-tns-note @-1 {{access here could race}} + // expected-complete-warning @-2 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + test = StructFieldTests() + } + + test.varSendableNonTrivial = SendableKlass() + useValue(test) // expected-tns-note {{access here could race}} +} + +// In this case since we are tracking the transfer from the if statement, we +// do not track the closure. +// +// TODO: Change to track all sets. +func varNonSendableNonTrivialLetStructFieldClosureFlowSensitive6() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + + if await booleanFlag { + cls = { + useInOut(&test.varSendableNonTrivial) + } + _ = cls + await transferToMain(test) + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + } else { + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + } + + test.varSendableNonTrivial = SendableKlass() + useValue(test) // expected-tns-note {{access here could race}} +} + +// In this case since we are tracking the transfer from the else statement, we +// track the closure. +func varNonSendableNonTrivialLetStructFieldClosureFlowSensitive7() async { + var test = StructFieldTests() + test = StructFieldTests() + var cls = {} + + if await booleanFlag { + await transferToMain(test) + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + } else { + cls = { + useInOut(&test.varSendableNonTrivial) + } + _ = cls + await transferToMain(test) // expected-tns-warning {{passing argument of non-sendable type 'StructFieldTests' from nonisolated context to main actor-isolated context at this call site could yield a race with accesses later in this function}} + // expected-complete-warning @-1 {{passing argument of non-sendable type 'StructFieldTests' into main actor-isolated context may introduce data races}} + } + + test.varSendableNonTrivial = SendableKlass() // expected-tns-note {{access here could race}} + useValue(test) +} + //////////////////////////// // MARK: TupleFieldTests // //////////////////////////// diff --git a/unittests/SILOptimizer/PartitionUtilsTest.cpp b/unittests/SILOptimizer/PartitionUtilsTest.cpp index 9458e08653c..ca60d846075 100644 --- a/unittests/SILOptimizer/PartitionUtilsTest.cpp +++ b/unittests/SILOptimizer/PartitionUtilsTest.cpp @@ -548,14 +548,13 @@ TEST(PartitionUtilsTest, TestConsumeAndRequire) { // expected: p: ({0 1 2 6 7 10} (3 4 5) (8 9) (Element(11))) - auto never_called = [](const PartitionOp &, unsigned, Operand *) { + auto never_called = [](const PartitionOp &, unsigned, TransferringOperand) { EXPECT_TRUE(false); }; int times_called = 0; - auto increment_times_called = [&](const PartitionOp &, unsigned, Operand *) { - times_called++; - }; + auto increment_times_called = [&](const PartitionOp &, unsigned, + TransferringOperand) { times_called++; }; { PartitionOpEvaluator eval(p); @@ -623,18 +622,16 @@ TEST(PartitionUtilsTest, TestCopyConstructor) { { bool failure = false; PartitionOpEvaluator eval(p1); - eval.failureCallback = [&](const PartitionOp &, unsigned, Operand *) { - failure = true; - }; + eval.failureCallback = [&](const PartitionOp &, unsigned, + TransferringOperand) { failure = true; }; eval.apply(PartitionOp::Require(Element(0))); EXPECT_TRUE(failure); } { PartitionOpEvaluator eval(p2); - eval.failureCallback = [](const PartitionOp &, unsigned, Operand *) { - EXPECT_TRUE(false); - }; + eval.failureCallback = [](const PartitionOp &, unsigned, + TransferringOperand) { EXPECT_TRUE(false); }; eval.apply(PartitionOp::Require(Element(0))); } } @@ -642,9 +639,8 @@ TEST(PartitionUtilsTest, TestCopyConstructor) { TEST(PartitionUtilsTest, TestUndoTransfer) { Partition p; PartitionOpEvaluator eval(p); - eval.failureCallback = [&](const PartitionOp &, unsigned, Operand *) { - EXPECT_TRUE(false); - }; + eval.failureCallback = [&](const PartitionOp &, unsigned, + TransferringOperand) { EXPECT_TRUE(false); }; // Shouldn't error on this. eval.apply({PartitionOp::AssignFresh(Element(0)),