diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index 987614ef6fb..a52f13b13e9 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -559,7 +559,13 @@ public: EnumElementDecl *Element, SILType Ty) { return insert(new (F.getModule()) EnumInst(Loc, Operand, Element, Ty)); } - + + /// Create an enum Optional.Some. + EnumInst *createOptionalSome(SILLocation Loc, SILValue Operand, SILType Ty) { + auto *Decl = F.getModule().getASTContext().getOptionalSomeDecl(); + return createEnum(Loc, Operand, Decl, Ty); + } + InitEnumDataAddrInst *createInitEnumDataAddr(SILLocation Loc, SILValue Operand, EnumElementDecl *Element, SILType Ty) { return insert( diff --git a/include/swift/SILAnalysis/RCIdentityAnalysis.h b/include/swift/SILAnalysis/RCIdentityAnalysis.h index bea6f75c096..80ba4736597 100644 --- a/include/swift/SILAnalysis/RCIdentityAnalysis.h +++ b/include/swift/SILAnalysis/RCIdentityAnalysis.h @@ -48,6 +48,10 @@ public: SILValue getRCIdentityRoot(SILValue V); + /// Return all recursive users of V, looking through users which + /// propagate RCIdentity. + void getRCUsers(SILValue V, llvm::SmallVectorImpl &Users); + static bool classof(const SILAnalysis *S) { return S->getKind() == AnalysisKind::RCIdentity; } diff --git a/lib/SILAnalysis/RCIdentityAnalysis.cpp b/lib/SILAnalysis/RCIdentityAnalysis.cpp index c928e022cd5..cdd66ee9ed6 100644 --- a/lib/SILAnalysis/RCIdentityAnalysis.cpp +++ b/lib/SILAnalysis/RCIdentityAnalysis.cpp @@ -52,7 +52,7 @@ static bool isRCIdentityPreservingCast(ValueKind Kind) { } //===----------------------------------------------------------------------===// -// Analysis +// RCIdentityRoot Analysis //===----------------------------------------------------------------------===// /// Returns true if FirstIV is a SILArgument or SILInstruction in a BB that @@ -388,6 +388,56 @@ stripRCIdentityPreservingOps(SILValue V, unsigned RecursionDepth) { return V; } +//===----------------------------------------------------------------------===// +// RCUser Analysis +//===----------------------------------------------------------------------===// + +/// Return all recursive users of V, looking through users which propagate +/// RCIdentity. +/// +/// We only use the instruction analysis here. +void RCIdentityAnalysis::getRCUsers( + SILValue InputValue, llvm::SmallVectorImpl &Users) { + // Add V to the worklist. + llvm::SmallVector Worklist; + Worklist.push_back(InputValue); + + // A set used to ensure we only visit users once. + llvm::SmallPtrSet VisitedInsts; + + // Then until we finish the worklist... + while (!Worklist.empty()) { + // Pop off the top value. + SILValue V = Worklist.pop_back_val(); + + // For each user of V... + for (auto *Op : V.getUses()) { + SILInstruction *User = Op->getUser(); + + // If we have already visited this user, continue. + if (!VisitedInsts.insert(User).second) + continue; + + // Otherwise attempt to strip off one layer of RC identical instructions + // from User. + SILValue StrippedRCID = stripRCIdentityPreservingInsts(User); + + // If StrippedRCID is not V, then we know that User's result is + // conservatively not RCIdentical + // to V and we should add it to our RCUserList and continue. + if (StrippedRCID != V) { + Users.push_back(User); + continue; + } + + // Otherwise, add all of User's uses to our list to continue searching. + for (unsigned i = 0, e = User->getNumTypes(); i != e; ++i) { + Worklist.push_back(SILValue(User, i)); + } + } + } +} + //===----------------------------------------------------------------------===// // Main Entry Point //===----------------------------------------------------------------------===// diff --git a/lib/SILPasses/RemovePin.cpp b/lib/SILPasses/RemovePin.cpp index 33acc046180..023857761ed 100644 --- a/lib/SILPasses/RemovePin.cpp +++ b/lib/SILPasses/RemovePin.cpp @@ -18,6 +18,7 @@ #include "swift/SILAnalysis/ArraySemantic.h" #include "swift/SILAnalysis/ARCAnalysis.h" #include "swift/SILAnalysis/LoopAnalysis.h" +#include "swift/SILAnalysis/RCIdentityAnalysis.h" #include "swift/SILPasses/Passes.h" #include "swift/SILPasses/Transforms.h" #include "swift/SILPasses/Utils/CFG.h" @@ -30,9 +31,12 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Support/CommandLine.h" +STATISTIC(NumPinPairsRemoved, "Num pin pairs removed"); + using namespace swift; /// \brief Can this instruction read the pinned bit of the reference count. @@ -58,6 +62,8 @@ class RemovePinInsts : public SILFunctionTransform { AliasAnalysis *AA; + RCIdentityAnalysis *RCIA; + public: RemovePinInsts() {} @@ -65,6 +71,10 @@ public: void run() override { AA = PM->getAnalysis(); + RCIA = PM->getAnalysis(); + + DEBUG(llvm::dbgs() << "*** Running Pin Removal on " + << getFunction()->getName() << "\n"); bool Changed = false; for (auto &BB : *getFunction()) { @@ -72,30 +82,49 @@ public: // This is only a BB local analysis for now. AvailablePins.clear(); + DEBUG(llvm::dbgs() << "Visiting new BB!\n"); + for (auto InstIt = BB.begin(), End = BB.end(); InstIt != End; ) { auto *CurInst = &*InstIt; ++InstIt; + DEBUG(llvm::dbgs() << " Visiting: " << *CurInst); + // Add StrongPinInst to available pins. if (isa(CurInst)) { + DEBUG(llvm::dbgs() << " Found pin!\n"); AvailablePins.insert(CurInst); continue; } // Try to remove StrongUnpinInst if its input is available. if (auto *Unpin = dyn_cast(CurInst)) { - auto *PinDef = dyn_cast(Unpin->getOperand().getDef()); + DEBUG(llvm::dbgs() << " Found unpin!\n"); + SILValue RCId = RCIA->getRCIdentityRoot(Unpin->getOperand()); + DEBUG(llvm::dbgs() << " RCID Source: " << *RCId.getDef()); + auto *PinDef = dyn_cast(RCId.getDef()); if (PinDef && AvailablePins.count(PinDef)){ + DEBUG(llvm::dbgs() << " Found matching pin: " << *PinDef); SmallVector MarkDependentInsts; if (areSafePinUsers(PinDef, Unpin, MarkDependentInsts)) { + DEBUG(llvm::dbgs() << " Pin users are safe! Removing!\n"); Changed = true; - auto NewDep = PinDef->getOperand(); - for (auto &MD : MarkDependentInsts) - MD->setOperand(1, NewDep); + auto *Enum = SILBuilder(PinDef).createOptionalSome( + PinDef->getLoc(), PinDef->getOperand(), PinDef->getType(0)); + SILValue(PinDef).replaceAllUsesWith(Enum); Unpin->eraseFromParent(); PinDef->eraseFromParent(); + // Remove this pindef from AvailablePins. + AvailablePins.erase(PinDef); + ++NumPinPairsRemoved; + } else { + DEBUG(llvm::dbgs() + << " Pin users are not safe! Can not remove!\n"); } + continue; + } else { + DEBUG(llvm::dbgs() << " Failed to find matching pin!\n"); } // Otherwise, fall through. An unpin, through destruction of an object // can have arbitrary sideeffects. @@ -103,6 +132,8 @@ public: // In all other cases check whether this could be a potentially // releasing instruction. + DEBUG(llvm::dbgs() + << " Checking if this inst invalidates pins.\n"); invalidateAvailablePins(CurInst); } } @@ -112,17 +143,30 @@ public: SILAnalysis::PreserveKind::ProgramFlow); } - /// Pin uses are safe if they either mark a dependence or if it is the unpin we - /// are trying to remove. + /// Pin uses are safe if: + /// + /// 1. The user marks a dependence. + /// 2. The user is the unpin we are trying to remove. + /// 3. The user is an RCIdentical user of our Pin result and only has + /// RCIdentity preserving insts, mark dependence, or the unpin we are + /// trying + /// to remove as users. bool areSafePinUsers(StrongPinInst *Pin, StrongUnpinInst *Unpin, SmallVectorImpl &MarkDeps) { - for (auto *U : Pin->getUses()) { - if (auto *MD = dyn_cast(U->getUser())) + // Grab all uses looking past RCIdentical uses from RCIdentityAnalysis. + llvm::SmallVector Users; + RCIA->getRCUsers(SILValue(Pin), Users); + + for (auto *U : Users) { + if (auto *MD = dyn_cast(U)) { MarkDeps.push_back(MD); - else if (dyn_cast(U->getUser()) == Unpin) continue; - else - return false; + } + + if (dyn_cast(U) == Unpin) + continue; + + return false; } return true; } @@ -160,8 +204,15 @@ public: RemovePin.push_back(P); } - for (auto P: RemovePin) + if (RemovePin.empty()) { + DEBUG(llvm::dbgs() << " No pins to invalidate!\n"); + return; + } + + for (auto *P : RemovePin) { + DEBUG(llvm::dbgs() << " Invalidating Pin: " << *P); AvailablePins.erase(P); + } } }; } diff --git a/test/SILPasses/remove_pins.sil b/test/SILPasses/remove_pins.sil index 301355939f8..8d8f14ed3be 100644 --- a/test/SILPasses/remove_pins.sil +++ b/test/SILPasses/remove_pins.sil @@ -93,3 +93,42 @@ bb0(%0 : $*ArrayInt): strong_unpin %6 : $Optional return %7 : $Bool } + +// CHECK-LABEL: sil @remove_pins_with_inert_rcidentity_uses : $@thin (@inout ArrayInt) -> () { +// CHECK-NOT: strong_pin +// CHECK-NOT: strong_unpin +sil @remove_pins_with_inert_rcidentity_uses : $@thin (@inout ArrayInt) -> () { +bb0(%0 : $*ArrayInt): + %1 = load %0: $*ArrayInt + %2 = struct_extract %1 : $ArrayInt, #ArrayInt.buffer + %3 = struct_extract %2 : $ArrayIntBuffer, #ArrayIntBuffer.storage + %4 = unchecked_trivial_bit_cast %3 : $Builtin.NativeObject to $UnsafePointer + %5 = function_ref @_swift_isUniquelyReferencedOrPinned_nonNull_native : $@cc(cdecl) @thin (UnsafePointer) -> Bool + %6 = strong_pin %3 : $Builtin.NativeObject + %7 = integer_literal $Builtin.Int1, 0 + %8 = tuple(%6 : $Optional, %7 : $Builtin.Int1) + %9 = tuple_extract %8 : $(Optional, Builtin.Int1), 0 + strong_unpin %9 : $Optional + %9999 = tuple() + return %9999 : $() +} + +// CHECK-LABEL: sil @remove_pins_with_noninert_rcidentity_uses : $@thin (@inout ArrayInt) -> Bool { +// CHECK: strong_pin +// CHECK: strong_unpin +sil @remove_pins_with_noninert_rcidentity_uses : $@thin (@inout ArrayInt) -> Bool { +bb0(%0 : $*ArrayInt): + %1 = load %0: $*ArrayInt + %2 = struct_extract %1 : $ArrayInt, #ArrayInt.buffer + %3 = struct_extract %2 : $ArrayIntBuffer, #ArrayIntBuffer.storage + %4 = unchecked_trivial_bit_cast %3 : $Builtin.NativeObject to $UnsafePointer + %5 = function_ref @_swift_isUniquelyReferencedOrPinned_nonNull_native : $@cc(cdecl) @thin (UnsafePointer) -> Bool + %6 = strong_pin %3 : $Builtin.NativeObject + %7 = integer_literal $Builtin.Int1, 0 + %8 = tuple(%6 : $Optional, %7 : $Builtin.Int1) + %9 = tuple_extract %8 : $(Optional, Builtin.Int1), 0 + strong_unpin %9 : $Optional + %10 = unchecked_trivial_bit_cast %9 : $Optional to $UnsafePointer + %11 = apply %5(%10) : $@cc(cdecl) @thin (UnsafePointer) -> Bool + return %11 : $Bool +}