//===--- LoopUnroll.cpp - Loop unrolling ----------------------------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// #define DEBUG_TYPE "sil-loopunroll" #include "llvm/ADT/DepthFirstIterator.h" #include "swift/SIL/PatternMatch.h" #include "swift/SIL/SILCloner.h" #include "swift/SILOptimizer/Analysis/LoopAnalysis.h" #include "swift/SILOptimizer/PassManager/Passes.h" #include "swift/SILOptimizer/PassManager/Transforms.h" #include "swift/SILOptimizer/Utils/PerformanceInlinerUtils.h" #include "swift/SILOptimizer/Utils/SILInliner.h" #include "swift/SILOptimizer/Utils/SILSSAUpdater.h" using namespace swift; using namespace swift::PatternMatch; using llvm::DenseMap; using llvm::MapVector; static const uint64_t SILLoopUnrollThreshold = 250; namespace { /// Clone the basic blocks in a loop. class LoopCloner : public SILCloner { SILLoop *Loop; friend class SILInstructionVisitor; friend class SILCloner; public: LoopCloner(SILLoop *Loop) : SILCloner(*Loop->getHeader()->getParent()), Loop(Loop) {} /// Clone the basic blocks in the loop. void cloneLoop(); /// Get a map from basic blocks or the original loop to the cloned loop. MapVector &getBBMap() { return BBMap; } DenseMap &getValueMap() { return ValueMap; } protected: SILValue remapValue(SILValue V) { if (auto *BB = V->getParentBlock()) { if (!Loop->contains(BB)) return V; } return SILCloner::remapValue(V); } void postProcess(SILInstruction *Orig, SILInstruction *Cloned) { SILCloner::postProcess(Orig, Cloned); } }; } // end anonymous namespace void LoopCloner::cloneLoop() { auto *Header = Loop->getHeader(); auto *CurFun = Loop->getHeader()->getParent(); SmallVector ExitBlocks; Loop->getExitBlocks(ExitBlocks); for (auto *ExitBB : ExitBlocks) BBMap[ExitBB] = ExitBB; auto *ClonedHeader = CurFun->createBasicBlock(); BBMap[Header] = ClonedHeader; // Clone the arguments. for (auto *Arg : Header->getArguments()) { SILValue MappedArg = ClonedHeader->createPHIArgument( getOpType(Arg->getType()), ValueOwnershipKind::Owned); ValueMap.insert(std::make_pair(Arg, MappedArg)); } // Clone the instructions in this basic block and recursively clone // successor blocks. getBuilder().setInsertionPoint(ClonedHeader); visitSILBasicBlock(Header); // Fix-up terminators. for (auto BBPair : BBMap) if (BBPair.first != BBPair.second) { getBuilder().setInsertionPoint(BBPair.second); visit(BBPair.first->getTerminator()); } } /// Determine the number of iterations the loop is at most executed. The loop /// might contain early exits so this is the maximum if no early exits are /// taken. static Optional getMaxLoopTripCount(SILLoop *Loop, SILBasicBlock *Preheader, SILBasicBlock *Header, SILBasicBlock *Latch) { // Skip a split backedge. SILBasicBlock *OrigLatch = Latch; if (!Loop->isLoopExiting(Latch) && !(Latch = Latch->getSinglePredecessorBlock())) return None; if (!Loop->isLoopExiting(Latch)) return None; // Get the loop exit condition. auto *CondBr = dyn_cast(Latch->getTerminator()); if (!CondBr) return None; // Match an add 1 recurrence. SILPHIArgument *RecArg; IntegerLiteralInst *End; SILValue RecNext; unsigned Adjust = 0; if (!match(CondBr->getCondition(), m_BuiltinInst(BuiltinValueKind::ICMP_EQ, m_SILValue(RecNext), m_IntegerLiteralInst(End))) && !match(CondBr->getCondition(), m_BuiltinInst(BuiltinValueKind::ICMP_SGE, m_SILValue(RecNext), m_IntegerLiteralInst(End)))) { if (!match(CondBr->getCondition(), m_BuiltinInst(BuiltinValueKind::ICMP_SGT, m_SILValue(RecNext), m_IntegerLiteralInst(End)))) return None; // Otherwise, we have a greater than comparison. else Adjust = 1; } if (!match(RecNext, m_TupleExtractInst(m_ApplyInst(BuiltinValueKind::SAddOver, m_SILPHIArgument(RecArg), m_One()), 0))) return None; if (RecArg->getParent() != Header) return None; auto *Start = dyn_cast_or_null( RecArg->getIncomingValue(Preheader)); if (!Start) return None; if (RecNext != RecArg->getIncomingValue(OrigLatch)) return None; auto StartVal = Start->getValue(); auto EndVal = End->getValue(); if (StartVal.sgt(EndVal)) return None; auto Dist = EndVal - StartVal; if (Dist.getBitWidth() > 64) return None; if (Dist == 0) return None; return Dist.getZExtValue() + Adjust; } /// Check whether we can duplicate the instructions in the loop and use a /// heuristic that looks at the trip count and the cost of the instructions in /// the loop to determine whether we should unroll this loop. static bool canAndShouldUnrollLoop(SILLoop *Loop, uint64_t TripCount) { assert(Loop->getSubLoops().empty() && "Expect innermost loops"); if (TripCount > 32) return false; // We can unroll a loop if we can duplicate the instructions it holds. uint64_t Cost = 0; // Average number of instructions per basic block. // It is used to estimate the cost of the callee // inside a loop. const uint64_t InsnsPerBB = 4; for (auto *BB : Loop->getBlocks()) { for (auto &Inst : *BB) { if (!Loop->canDuplicate(&Inst)) return false; if (instructionInlineCost(Inst) != InlineCost::Free) ++Cost; if (auto AI = FullApplySite::isa(&Inst)) { auto Callee = AI.getCalleeFunction(); if (Callee && getEligibleFunction(AI, InlineSelection::Everything)) { // If callee is rather big and potentialy inlineable, it may be better // not to unroll, so that the body of the calle can be inlined later. Cost += Callee->size() * InsnsPerBB; } } if (Cost * TripCount > SILLoopUnrollThreshold) return false; } } return true; } /// Redirect the terminator of the current loop iteration's latch to the next /// iterations header or if this is the last iteration remove the backedge to /// the header. static void redirectTerminator(SILBasicBlock *Latch, unsigned CurLoopIter, unsigned LastLoopIter, SILBasicBlock *OrigHeader, SILBasicBlock *NextIterationsHeader) { auto *CurrentTerminator = Latch->getTerminator(); // We can either have a split backedge as our latch terminator. // HeaderBlock: // ... // cond_br %cond, ExitBlock, BackedgeBlock // // BackedgeBlock: // br HeaderBlock: // // Or a conditional branch back to the header. // HeaderBlock: // ... // cond_br %cond, ExitBlock, HeaderBlock // // Redirect the HeaderBlock target to the unrolled successor. In the // unrolled block of the last iteration unconditionally jump to the // ExitBlock instead. // Handle the split backedge case. if (auto *Br = dyn_cast(CurrentTerminator)) { // On the last iteration change the conditional exit to an unconditional // one. if (CurLoopIter == LastLoopIter) { auto *CondBr = cast( Latch->getSinglePredecessorBlock()->getTerminator()); if (CondBr->getTrueBB() != Latch) SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getTrueBB(), CondBr->getTrueArgs()); else SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getFalseBB(), CondBr->getFalseArgs()); CondBr->eraseFromParent(); return; } // Otherwise, branch to the next iteration's header. SILBuilder(Br).createBranch(Br->getLoc(), NextIterationsHeader, Br->getArgs()); Br->eraseFromParent(); return; } // Otherwise, we have a conditional branch to the header. auto *CondBr = cast(CurrentTerminator); // On the last iteration change the conditional exit to an unconditional // one. if (CurLoopIter == LastLoopIter) { if (CondBr->getTrueBB() != OrigHeader) SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getTrueBB(), CondBr->getTrueArgs()); else SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getFalseBB(), CondBr->getFalseArgs()); CondBr->eraseFromParent(); return; } // Otherwise, branch to the next iteration's header. if (CondBr->getTrueBB() == OrigHeader) { SILBuilder(CondBr).createCondBranch( CondBr->getLoc(), CondBr->getCondition(), NextIterationsHeader, CondBr->getTrueArgs(), CondBr->getFalseBB(), CondBr->getFalseArgs()); } else { SILBuilder(CondBr).createCondBranch( CondBr->getLoc(), CondBr->getCondition(), CondBr->getTrueBB(), CondBr->getTrueArgs(), NextIterationsHeader, CondBr->getFalseArgs()); } CondBr->eraseFromParent(); } /// Collect all the loop live out values in the map that maps original live out /// value to live out value in the cloned loop. static void collectLoopLiveOutValues( DenseMap> &LoopLiveOutValues, SILLoop *Loop, DenseMap &ClonedValues) { for (auto *Block : Loop->getBlocks()) { // Look at block arguments. for (auto *Arg : Block->getArguments()) { for (auto *Op : Arg->getUses()) { // Is this use outside the loop? if (!Loop->contains(Op->getUser())) { auto ArgumentValue = SILValue(Arg); assert(ClonedValues.count(ArgumentValue) && "Unmapped Argument!"); if (!LoopLiveOutValues.count(ArgumentValue)) LoopLiveOutValues[ArgumentValue].push_back( ClonedValues[ArgumentValue]); } } } // And the instructions. for (auto &Inst : *Block) { for (SILValue result : Inst.getResults()) { for (auto *Op : result->getUses()) { // Ignore uses inside the loop. if (Loop->contains(Op->getUser())) continue; auto UsedValue = Op->get(); assert(UsedValue == result && "Instructions must match"); if (!LoopLiveOutValues.count(UsedValue)) LoopLiveOutValues[UsedValue].push_back(ClonedValues[result]); } } } } } static void updateSSA(SILLoop *Loop, DenseMap> &LoopLiveOutValues) { SILSSAUpdater SSAUp; for (auto &MapEntry : LoopLiveOutValues) { // Collect out of loop uses of this value. auto OrigValue = MapEntry.first; SmallVector UseList; for (auto Use : OrigValue->getUses()) if (!Loop->contains(Use->getUser()->getParent())) UseList.push_back(UseWrapper(Use)); // Update SSA of use with the available values. SSAUp.Initialize(OrigValue->getType()); SSAUp.AddAvailableValue(OrigValue->getParentBlock(), OrigValue); for (auto NewValue : MapEntry.second) SSAUp.AddAvailableValue(NewValue->getParentBlock(), NewValue); for (auto U : UseList) { Operand *Use = U; SSAUp.RewriteUse(*Use); } } } /// Try to fully unroll the loop if we can determine the trip count and the trip /// count lis below a threshold. static bool tryToUnrollLoop(SILLoop *Loop) { assert(Loop->getSubLoops().empty() && "Expecting innermost loops"); auto *Preheader = Loop->getLoopPreheader(); if (!Preheader) return false; auto *Latch = Loop->getLoopLatch(); if (!Latch) return false; auto *Header = Loop->getHeader(); Optional MaxTripCount = getMaxLoopTripCount(Loop, Preheader, Header, Latch); if (!MaxTripCount) return false; if (!canAndShouldUnrollLoop(Loop, MaxTripCount.getValue())) return false; // TODO: We need to split edges from non-condbr exits for the SSA updater. For // now just don't handle loops containing such exits. SmallVector ExitingBlocks; Loop->getExitingBlocks(ExitingBlocks); for (auto &Exit : ExitingBlocks) if (!isa(Exit->getTerminator())) return false; DEBUG(llvm::dbgs() << "Unrolling loop in " << Header->getParent()->getName() << " " << *Loop << "\n"); SmallVector Headers; Headers.push_back(Header); SmallVector Latches; Latches.push_back(Latch); DenseMap> LoopLiveOutValues; // Copy the body MaxTripCount-1 times. for (uint64_t Cnt = 1; Cnt < *MaxTripCount; ++Cnt) { // Clone the blocks in the loop. LoopCloner Cloner(Loop); Cloner.cloneLoop(); Headers.push_back(Cloner.getBBMap()[Header]); Latches.push_back(Cloner.getBBMap()[Latch]); // Collect values defined in the loop but used outside. On the first // iteration we populate the map from original loop to cloned loop. On // subsequent iterations we only need to update this map with the values // from the new iteration's clone. if (Cnt == 1) collectLoopLiveOutValues(LoopLiveOutValues, Loop, Cloner.getValueMap()); else { for (auto &MapEntry : LoopLiveOutValues) { // Look it up in the value map. SILValue MappedValue = Cloner.getValueMap()[MapEntry.first]; MapEntry.second.push_back(MappedValue); assert(MapEntry.second.size() == Cnt); } } } // Thread the loop clones by redirecting the loop latches to the successor // iteration's header. for (unsigned Iteration = 0, End = Latches.size(); Iteration != End; ++Iteration) { auto *CurrentLatch = Latches[Iteration]; auto LastIteration = End - 1; auto *OriginalHeader = Headers[0]; auto *NextIterationsHeader = Iteration == LastIteration ? nullptr : Headers[Iteration + 1]; redirectTerminator(CurrentLatch, Iteration, LastIteration, OriginalHeader, NextIterationsHeader); } // Fixup SSA form for loop values used outside the loop. updateSSA(Loop, LoopLiveOutValues); return true; } // ============================================================================= // Driver // ============================================================================= namespace { class LoopUnrolling : public SILFunctionTransform { void run() override { bool Changed = false; auto *Fun = getFunction(); SILLoopInfo *LoopInfo = PM->getAnalysis()->get(Fun); // Collect innermost loops. SmallVector InnermostLoops; for (auto *Loop : *LoopInfo) { SmallVector Worklist; Worklist.push_back(Loop); for (unsigned i = 0; i < Worklist.size(); ++i) { auto *L = Worklist[i]; for (auto *SubLoop : *L) Worklist.push_back(SubLoop); if (L->getSubLoops().empty()) InnermostLoops.push_back(L); } } // Try to unroll innermost loops. for (auto *Loop : InnermostLoops) Changed |= tryToUnrollLoop(Loop); if (Changed) { invalidateAnalysis(SILAnalysis::InvalidationKind::FunctionBody); } } }; } // end anonymous namespace SILTransform *swift::createLoopUnroll() { return new LoopUnrolling(); }