mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Unroll loops with known short trip count
This enables array value propagation in array literal loops like:
for e in [2,3,4] {
r += e
}
Allowing us to completely get rid of the array.
rdar://19958821
SR-203
This commit is contained in:
464
lib/SILOptimizer/LoopTransforms/LoopUnroll.cpp
Normal file
464
lib/SILOptimizer/LoopTransforms/LoopUnroll.cpp
Normal file
@@ -0,0 +1,464 @@
|
||||
//===--------- LoopUnroll.cpp - Loop unrolling ------------*- C++ -*-------===//
|
||||
//
|
||||
// This source file is part of the Swift.org open source project
|
||||
//
|
||||
// Copyright (c) 2014 - 2015 Apple Inc. and the Swift project authors
|
||||
// Licensed under Apache License v2.0 with Runtime Library Exception
|
||||
//
|
||||
// See http://swift.org/LICENSE.txt for license information
|
||||
// See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define DEBUG_TYPE "sil-loopunroll"
|
||||
|
||||
#include "llvm/ADT/DepthFirstIterator.h"
|
||||
|
||||
#include "swift/SIL/PatternMatch.h"
|
||||
#include "swift/SIL/SILCloner.h"
|
||||
#include "swift/SILOptimizer/Analysis/LoopAnalysis.h"
|
||||
#include "swift/SILOptimizer/PassManager/Passes.h"
|
||||
#include "swift/SILOptimizer/PassManager/Transforms.h"
|
||||
#include "swift/SILOptimizer/Utils/SILInliner.h"
|
||||
#include "swift/SILOptimizer/Utils/SILSSAUpdater.h"
|
||||
|
||||
using namespace swift;
|
||||
using namespace swift::PatternMatch;
|
||||
|
||||
using llvm::DenseMap;
|
||||
using llvm::MapVector;
|
||||
|
||||
static const uint64_t SILLoopUnrollThreshold = 250;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Clone the basic blocks in a loop.
|
||||
class LoopCloner : public SILCloner<LoopCloner> {
|
||||
SILLoop *Loop;
|
||||
|
||||
friend class SILVisitor<LoopCloner>;
|
||||
friend class SILCloner<LoopCloner>;
|
||||
|
||||
public:
|
||||
LoopCloner(SILLoop *Loop)
|
||||
: SILCloner<LoopCloner>(*Loop->getHeader()->getParent()), Loop(Loop) {}
|
||||
|
||||
/// Clone the basic blocks in the loop.
|
||||
void cloneLoop();
|
||||
|
||||
/// Get a map from basic blocks or the original loop to the cloned loop.
|
||||
MapVector<SILBasicBlock *, SILBasicBlock *> &getBBMap() { return BBMap; }
|
||||
|
||||
DenseMap<SILValue, SILValue> &getValueMap() { return ValueMap; }
|
||||
DenseMap<SILInstruction *, SILInstruction *> &getInstMap() {
|
||||
return InstructionMap;
|
||||
}
|
||||
|
||||
protected:
|
||||
SILValue remapValue(SILValue V) {
|
||||
if (auto *BB = V.getDef()->getParentBB()) {
|
||||
if (!Loop->contains(BB))
|
||||
return V;
|
||||
}
|
||||
return SILCloner<LoopCloner>::remapValue(V);
|
||||
}
|
||||
void postProcess(SILInstruction *Orig, SILInstruction *Cloned) {
|
||||
SILCloner<LoopCloner>::postProcess(Orig, Cloned);
|
||||
}
|
||||
};
|
||||
|
||||
} // End anonymous namespace.
|
||||
|
||||
void LoopCloner::cloneLoop() {
|
||||
auto *Header = Loop->getHeader();
|
||||
auto *CurFun = Loop->getHeader()->getParent();
|
||||
auto &Mod = CurFun->getModule();
|
||||
|
||||
SmallVector<SILBasicBlock *, 16> ExitBlocks;
|
||||
Loop->getExitBlocks(ExitBlocks);
|
||||
for (auto *ExitBB : ExitBlocks)
|
||||
BBMap[ExitBB] = ExitBB;
|
||||
|
||||
auto *ClonedHeader = new (Mod) SILBasicBlock(CurFun);
|
||||
BBMap[Header] = ClonedHeader;
|
||||
|
||||
// Clone the arguments.
|
||||
for (auto *Arg : Header->getBBArgs()) {
|
||||
SILValue MappedArg =
|
||||
new (Mod) SILArgument(ClonedHeader, getOpType(Arg->getType()));
|
||||
ValueMap.insert(std::make_pair(Arg, MappedArg));
|
||||
}
|
||||
|
||||
// Clone the instructions in this basic block and recursively clone
|
||||
// successor blocks.
|
||||
getBuilder().setInsertionPoint(ClonedHeader);
|
||||
visitSILBasicBlock(Header);
|
||||
// Fix-up terminators.
|
||||
for (auto BBPair : BBMap)
|
||||
if (BBPair.first != BBPair.second) {
|
||||
getBuilder().setInsertionPoint(BBPair.second);
|
||||
visit(BBPair.first->getTerminator());
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine the number of iterations the loop is at most executed. The loop
|
||||
/// might contain early exits so this is the maximum if no early exits are
|
||||
/// taken.
|
||||
static Optional<uint64_t> getMaxLoopTripCount(SILLoop *Loop,
|
||||
SILBasicBlock *Preheader,
|
||||
SILBasicBlock *Header,
|
||||
SILBasicBlock *Latch) {
|
||||
|
||||
// Skip a split backedge.
|
||||
SILBasicBlock *OrigLatch = Latch;
|
||||
if (!Loop->isLoopExiting(Latch) && !(Latch = Latch->getSinglePredecessor()))
|
||||
return None;
|
||||
if (!Loop->isLoopExiting(Latch))
|
||||
return None;
|
||||
|
||||
// Get the loop exit condition.
|
||||
auto *CondBr = dyn_cast<CondBranchInst>(Latch->getTerminator());
|
||||
if (!CondBr)
|
||||
return None;
|
||||
|
||||
// Match an add 1 recurrence.
|
||||
SILArgument *RecArg;
|
||||
IntegerLiteralInst *End;
|
||||
SILValue RecNext;
|
||||
|
||||
if (!match(CondBr->getCondition(),
|
||||
m_BuiltinInst(BuiltinValueKind::ICMP_EQ, m_SILValue(RecNext),
|
||||
m_IntegerLiteralInst(End))))
|
||||
return None;
|
||||
if (!match(RecNext,
|
||||
m_TupleExtractInst(m_ApplyInst(BuiltinValueKind::SAddOver,
|
||||
m_SILArgument(RecArg), m_One()),
|
||||
0)))
|
||||
return None;
|
||||
|
||||
if (RecArg->getParent() != Header)
|
||||
return None;
|
||||
|
||||
auto *Start = dyn_cast_or_null<IntegerLiteralInst>(
|
||||
RecArg->getIncomingValue(Preheader).getDef());
|
||||
if (!Start)
|
||||
return None;
|
||||
|
||||
if (RecNext != RecArg->getIncomingValue(OrigLatch))
|
||||
return None;
|
||||
|
||||
auto StartVal = Start->getValue();
|
||||
auto EndVal = End->getValue();
|
||||
if (StartVal.sgt(EndVal))
|
||||
return None;
|
||||
|
||||
auto Dist = EndVal - StartVal;
|
||||
if (Dist.getBitWidth() > 64)
|
||||
return None;
|
||||
|
||||
if (Dist == 0)
|
||||
return None;
|
||||
|
||||
return Dist.getZExtValue();
|
||||
}
|
||||
|
||||
/// Check whether we can duplicate the instructions in the loop and use a
|
||||
/// heuristic that looks at the trip count and the cost of the instructions in
|
||||
/// the loop to determine whether we should unroll this loop.
|
||||
static bool canAndShouldUnrollLoop(SILLoop *Loop, uint64_t TripCount) {
|
||||
assert(Loop->getSubLoops().empty() && "Expect innermost loops");
|
||||
if (TripCount > 32)
|
||||
return false;
|
||||
|
||||
// We can unroll a loop if we can duplicate the instructions it holds.
|
||||
uint64_t Cost = 0;
|
||||
for (auto *BB : Loop->getBlocks()) {
|
||||
for (auto &Inst : *BB) {
|
||||
if (!Loop->canDuplicate(&Inst))
|
||||
return false;
|
||||
if (instructionInlineCost(Inst) != InlineCost::Free)
|
||||
++Cost;
|
||||
if (Cost * TripCount > SILLoopUnrollThreshold)
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Redirect the terminator of the current loop iteration's latch to the next
|
||||
/// iterations header or if this is the last iteration remove the backedge to
|
||||
/// the header.
|
||||
static void redirectTerminator(SILBasicBlock *Latch, unsigned CurLoopIter,
|
||||
unsigned LastLoopIter, SILBasicBlock *OrigHeader,
|
||||
SILBasicBlock *NextIterationsHeader) {
|
||||
|
||||
auto *CurrentTerminator = Latch->getTerminator();
|
||||
|
||||
// We can either have a split backedge as our latch terminator.
|
||||
// HeaderBlock:
|
||||
// ...
|
||||
// cond_br %cond, ExitBlock, BackegdeBlock
|
||||
//
|
||||
// BackegdeBlock:
|
||||
// br HeaderBlock:
|
||||
//
|
||||
// Or a a conditional branch back to the header.
|
||||
// HeaderBlock:
|
||||
// ...
|
||||
// cond_br %cond, ExitBlock, HeaderBlock
|
||||
//
|
||||
// Redirect the HeaderBlock target to the unrolled successor. In the
|
||||
// unrolled block of the last iteration unconditionally jump to the
|
||||
// ExitBlock instead.
|
||||
|
||||
// Handle the split backedge case.
|
||||
if (auto *Br = dyn_cast<BranchInst>(CurrentTerminator)) {
|
||||
// On the last iteration change the conditional exit to an unconditional
|
||||
// one.
|
||||
if (CurLoopIter == LastLoopIter) {
|
||||
auto *CondBr =
|
||||
cast<CondBranchInst>(Latch->getSinglePredecessor()->getTerminator());
|
||||
if (CondBr->getTrueBB() != Latch)
|
||||
SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getTrueBB(),
|
||||
CondBr->getTrueArgs());
|
||||
else
|
||||
SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getFalseBB(),
|
||||
CondBr->getFalseArgs());
|
||||
CondBr->eraseFromParent();
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, branch to the next iteration's header.
|
||||
SILBuilder(Br).createBranch(Br->getLoc(), NextIterationsHeader,
|
||||
Br->getArgs());
|
||||
Br->eraseFromParent();
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, we have a conditional branch to the header.
|
||||
auto *CondBr = cast<CondBranchInst>(CurrentTerminator);
|
||||
// On the last iteration change the conditional exit to an unconditional
|
||||
// one.
|
||||
if (CurLoopIter == LastLoopIter) {
|
||||
if (CondBr->getTrueBB() != OrigHeader)
|
||||
SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getTrueBB(),
|
||||
CondBr->getTrueArgs());
|
||||
else
|
||||
SILBuilder(CondBr).createBranch(CondBr->getLoc(), CondBr->getFalseBB(),
|
||||
CondBr->getFalseArgs());
|
||||
CondBr->eraseFromParent();
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, branch to the next iteration's header.
|
||||
if (CondBr->getTrueBB() == OrigHeader) {
|
||||
SILBuilder(CondBr).createCondBranch(
|
||||
CondBr->getLoc(), CondBr->getCondition(), NextIterationsHeader,
|
||||
CondBr->getTrueArgs(), CondBr->getFalseBB(), CondBr->getFalseArgs());
|
||||
} else {
|
||||
SILBuilder(CondBr).createCondBranch(
|
||||
CondBr->getLoc(), CondBr->getCondition(), CondBr->getTrueBB(),
|
||||
CondBr->getTrueArgs(), NextIterationsHeader, CondBr->getFalseArgs());
|
||||
}
|
||||
CondBr->eraseFromParent();
|
||||
}
|
||||
|
||||
/// Collect all the loop live out values in the map that maps original live out
|
||||
/// value to live out value in the cloned loop.
|
||||
static void collectLoopLiveOutValues(
|
||||
DenseMap<SILValue, SmallVector<SILValue, 8>> &LoopLiveOutValues,
|
||||
SILLoop *Loop, DenseMap<SILValue, SILValue> &ClonedValues,
|
||||
DenseMap<SILInstruction *, SILInstruction *> &ClonedInstructions) {
|
||||
for (auto *Block : Loop->getBlocks()) {
|
||||
// Look at block arguments.
|
||||
for (auto *Arg : Block->getBBArgs()) {
|
||||
for (auto *Op : Arg->getUses()) {
|
||||
// Is this use outside the loop?
|
||||
if (!Loop->contains(Op->getUser())) {
|
||||
auto ArgumentValue = SILValue(Arg);
|
||||
assert(ClonedValues.count(ArgumentValue) && "Unmapped Argument!");
|
||||
|
||||
if (!LoopLiveOutValues.count(ArgumentValue))
|
||||
LoopLiveOutValues[ArgumentValue].push_back(
|
||||
ClonedValues[ArgumentValue]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// And the instructions.
|
||||
for (auto &Inst : *Block) {
|
||||
for (auto *Op : Inst.getUses()) {
|
||||
// Is this use outside the loop.
|
||||
if (!Loop->contains(Op->getUser())) {
|
||||
auto UsedValue = Op->get();
|
||||
assert(UsedValue.getDef() == &Inst && "Instructions must match");
|
||||
assert(ClonedInstructions.count(&Inst) && "Unmapped instruction!");
|
||||
|
||||
if (!LoopLiveOutValues.count(UsedValue))
|
||||
LoopLiveOutValues[UsedValue].push_back(SILValue(
|
||||
ClonedInstructions[&Inst], UsedValue.getResultNumber()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
updateSSA(SILLoop *Loop,
|
||||
DenseMap<SILValue, SmallVector<SILValue, 8>> &LoopLiveOutValues) {
|
||||
SILSSAUpdater SSAUp;
|
||||
for (auto &MapEntry : LoopLiveOutValues) {
|
||||
// Collect out of loop uses of this value.
|
||||
auto OrigValue = MapEntry.first;
|
||||
SmallVector<UseWrapper, 16> UseList;
|
||||
for (auto Use : OrigValue.getUses())
|
||||
if (!Loop->contains(Use->getUser()->getParent()))
|
||||
UseList.push_back(UseWrapper(Use));
|
||||
// Update SSA of use with the available values.
|
||||
SSAUp.Initialize(OrigValue.getType());
|
||||
SSAUp.AddAvailableValue(OrigValue->getParentBB(), OrigValue);
|
||||
for (auto NewValue : MapEntry.second)
|
||||
SSAUp.AddAvailableValue(NewValue->getParentBB(), NewValue);
|
||||
for (auto U : UseList) {
|
||||
Operand *Use = U;
|
||||
SSAUp.RewriteUse(*Use);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to fully unroll the loop if we can determine the trip count and the trip
|
||||
/// count lis below a threshold.
|
||||
static bool tryToUnrollLoop(SILLoop *Loop) {
|
||||
assert(Loop->getSubLoops().empty() && "Expecting innermost loops");
|
||||
|
||||
auto *Preheader = Loop->getLoopPreheader();
|
||||
if (!Preheader)
|
||||
return false;
|
||||
|
||||
auto *Latch = Loop->getLoopLatch();
|
||||
if (!Latch)
|
||||
return false;
|
||||
|
||||
auto *Header = Loop->getHeader();
|
||||
|
||||
Optional<uint64_t> MaxTripCount =
|
||||
getMaxLoopTripCount(Loop, Preheader, Header, Latch);
|
||||
if (!MaxTripCount)
|
||||
return false;
|
||||
|
||||
if (!canAndShouldUnrollLoop(Loop, MaxTripCount.getValue()))
|
||||
return false;
|
||||
|
||||
// TODO: We need to split edges from non-condbr exits for the SSA updater. For
|
||||
// now just don't handle loops containing such exits.
|
||||
SmallVector<SILBasicBlock *, 16> ExitingBlocks;
|
||||
Loop->getExitingBlocks(ExitingBlocks);
|
||||
for (auto &Exit : ExitingBlocks)
|
||||
if (!isa<CondBranchInst>(Exit->getTerminator()))
|
||||
return false;
|
||||
|
||||
SmallVector<SILBasicBlock *, 16> Headers;
|
||||
Headers.push_back(Header);
|
||||
|
||||
SmallVector<SILBasicBlock *, 16> Latches;
|
||||
Latches.push_back(Latch);
|
||||
|
||||
DenseMap<SILValue, SmallVector<SILValue, 8>> LoopLiveOutValues;
|
||||
|
||||
// Copy the body MaxTripCount-1 times.
|
||||
for (uint64_t Cnt = 1; Cnt < *MaxTripCount; ++Cnt) {
|
||||
// Clone the blocks in the loop.
|
||||
LoopCloner Cloner(Loop);
|
||||
Cloner.cloneLoop();
|
||||
Headers.push_back(Cloner.getBBMap()[Header]);
|
||||
Latches.push_back(Cloner.getBBMap()[Latch]);
|
||||
|
||||
// Collect values defined in the loop but used outside. On the first
|
||||
// iteration we populate the map from original loop to cloned loop. On
|
||||
// subsequent iterations we only need to update this map with the values
|
||||
// from the new iteration's clone.
|
||||
if (Cnt == 1)
|
||||
collectLoopLiveOutValues(LoopLiveOutValues, Loop, Cloner.getValueMap(),
|
||||
Cloner.getInstMap());
|
||||
else {
|
||||
for (auto &MapEntry : LoopLiveOutValues) {
|
||||
// If this is an argument look up the value in the value map.
|
||||
SILValue MappedValue;
|
||||
if (isa<SILArgument>(MapEntry.first))
|
||||
MappedValue = Cloner.getValueMap()[MapEntry.first];
|
||||
// Otherwise, consult the instruction map.
|
||||
else
|
||||
MappedValue = SILValue(
|
||||
Cloner
|
||||
.getInstMap()[cast<SILInstruction>(MapEntry.first.getDef())],
|
||||
MapEntry.first.getResultNumber());
|
||||
MapEntry.second.push_back(MappedValue);
|
||||
assert(MapEntry.second.size() == Cnt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Thread the loop clones by redirecting the loop latches to the successor
|
||||
// iteration's header.
|
||||
for (unsigned Iteration = 0, End = Latches.size(); Iteration != End;
|
||||
++Iteration) {
|
||||
auto *CurrentLatch = Latches[Iteration];
|
||||
auto LastIteration = End - 1;
|
||||
auto *OriginalHeader = Headers[0];
|
||||
auto *NextIterationsHeader =
|
||||
Iteration == LastIteration ? nullptr : Headers[Iteration + 1];
|
||||
|
||||
redirectTerminator(CurrentLatch, Iteration, LastIteration, OriginalHeader,
|
||||
NextIterationsHeader);
|
||||
}
|
||||
|
||||
// Fixup SSA form for loop values used outside the loop.
|
||||
updateSSA(Loop, LoopLiveOutValues);
|
||||
return true;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Driver
|
||||
// =============================================================================
|
||||
|
||||
namespace {
|
||||
|
||||
class LoopUnrolling : public SILFunctionTransform {
|
||||
|
||||
StringRef getName() override { return "SIL Loop Unrolling"; }
|
||||
|
||||
void run() override {
|
||||
bool Changed = false;
|
||||
|
||||
auto *Fun = getFunction();
|
||||
SILLoopInfo *LoopInfo = PM->getAnalysis<SILLoopAnalysis>()->get(Fun);
|
||||
|
||||
// Collect innermost loops.
|
||||
SmallVector<SILLoop *, 16> InnermostLoops;
|
||||
for (auto *Loop : *LoopInfo) {
|
||||
SmallVector<SILLoop *, 8> Worklist;
|
||||
Worklist.push_back(Loop);
|
||||
|
||||
for (unsigned i = 0; i < Worklist.size(); ++i) {
|
||||
auto *L = Worklist[i];
|
||||
for (auto *SubLoop : *L)
|
||||
Worklist.push_back(SubLoop);
|
||||
if (L->getSubLoops().empty())
|
||||
InnermostLoops.push_back(L);
|
||||
}
|
||||
}
|
||||
|
||||
// Try to unroll innermost loops.
|
||||
for (auto *Loop : InnermostLoops)
|
||||
Changed |= tryToUnrollLoop(Loop);
|
||||
|
||||
if (Changed) {
|
||||
invalidateAnalysis(SILAnalysis::InvalidationKind::FunctionBody);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace.
|
||||
|
||||
SILTransform *swift::createLoopUnroll() {
|
||||
return new LoopUnrolling();
|
||||
}
|
||||
Reference in New Issue
Block a user