RLE: better handling of ref_element/tail_addr [immutable]

Re-run RLE with the base addresses of loads cut off at `ref_element_addr [immutable]` / `ref_tail_addr [immutable]`. This increases the chance of catching loads of immutable COW class properties or elements.
Erik Eckstein
2021-11-24 16:32:52 +01:00
parent b89f58de6d
commit f97876c9e7
7 changed files with 131 additions and 49 deletions
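
At the source level, the pattern this helps with looks roughly like the sketch below (hypothetical Swift; the `Node`/`consume` names are made up for illustration, and it assumes that reads of a class's `let` stored property lower to `ref_element_addr [immutable]`, which is the SIL shape the new test at the end of this commit exercises). Because no call can mutate `[immutable]` storage, the second read is redundant even across an opaque call:

final class Node {
  // A `let` stored property: immutable after initialization, so reads of it
  // are projected through `ref_element_addr [immutable]` in SIL.
  let payload: Int
  init(payload: Int) { self.payload = payload }
}

@inline(never) func consume(_ x: Int) {}

func readTwice(_ n: Node) -> Int {
  let first = n.payload   // first load of the immutable property
  consume(first)          // opaque call; it cannot change `payload`
  return n.payload        // redundant load: RLE can forward `first` here
}

The second RLE run only considers loads whose base address is such an immutable projection, which is what the new `onlyImmutableLoads` mode below implements.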

@@ -26,8 +26,6 @@ namespace swift {
 /// nothing left to strip.
 SILValue getUnderlyingObject(SILValue V);
 
-SILValue getUnderlyingObjectStopAtMarkDependence(SILValue V);
-
 SILValue stripSinglePredecessorArgs(SILValue V);
 
 /// Return the underlying SILValue after stripping off all casts from the

@@ -379,21 +379,37 @@ public:
   static void reduce(LSLocation Base, SILModule *Mod,
                      TypeExpansionContext context, LSLocationList &Locs);
 
+  /// Gets the base address for `v`.
+  /// If `stopAtImmutable` is true, the base address is only calculated up to
+  /// a `ref_element_addr [immutable]` or a `ref_tail_addr [immutable]`.
+  /// Return the base address and true if such an immutable class projection
+  /// is found.
+  static std::pair<SILValue, bool>
+  getBaseAddressOrObject(SILValue v, bool stopAtImmutable);
+
   /// Enumerate the given Mem LSLocation.
-  static void enumerateLSLocation(TypeExpansionContext context, SILModule *M,
+  /// If `stopAtImmutable` is true, the base address is only calculated up to
+  /// a `ref_element_addr [immutable]` or a `ref_tail_addr [immutable]`.
+  /// Returns true if it's an immutable location.
+  static bool enumerateLSLocation(TypeExpansionContext context, SILModule *M,
                                   SILValue Mem,
                                   std::vector<LSLocation> &LSLocationVault,
                                   LSLocationIndexMap &LocToBit,
                                   LSLocationBaseMap &BaseToLoc,
-                                  TypeExpansionAnalysis *TE);
+                                  TypeExpansionAnalysis *TE,
+                                  bool stopAtImmutable);
 
   /// Enumerate all the locations in the function.
+  /// If `stopAtImmutable` is true, the base addresses are only calculated up to
+  /// a `ref_element_addr [immutable]` or a `ref_tail_addr [immutable]`.
   static void enumerateLSLocations(SILFunction &F,
                                    std::vector<LSLocation> &LSLocationVault,
                                    LSLocationIndexMap &LocToBit,
                                    LSLocationBaseMap &BaseToLoc,
                                    TypeExpansionAnalysis *TE,
-                                   std::pair<int, int> &LSCount);
+                                   bool stopAtImmutable,
+                                   int &numLoads, int &numStores,
+                                   bool &immutableLoadsFound);
 };
 
 static inline llvm::hash_code hash_value(const LSLocation &L) {

@@ -62,18 +62,6 @@ SILValue swift::getUnderlyingObject(SILValue v) {
   }
 }
 
-SILValue swift::getUnderlyingObjectStopAtMarkDependence(SILValue v) {
-  while (true) {
-    SILValue v2 = stripCastsWithoutMarkDependence(v);
-    v2 = stripAddressProjections(v2);
-    v2 = stripIndexingInsts(v2);
-    v2 = lookThroughOwnershipInsts(v2);
-    if (v2 == v)
-      return v2;
-    v = v2;
-  }
-}
-
 /// Return the underlying SILValue after stripping off identity SILArguments if
 /// we belong to a BB with one predecessor.
 SILValue swift::stripSinglePredecessorArgs(SILValue V) {

@@ -1182,14 +1182,17 @@ void DSEContext::runIterativeDSE() {
 }
 
 bool DSEContext::run() {
-  std::pair<int, int> LSCount = std::make_pair(0, 0);
+  int numLoads = 0, numStores = 0;
+  bool immutableLoadsFound = false;
   // Walk over the function and find all the locations accessed by
   // this function.
   LSLocation::enumerateLSLocations(*F, LocationVault, LocToBitIndex,
-                                   BaseToLocIndex, TE, LSCount);
+                                   BaseToLocIndex, TE,
+                                   /*stopAtImmutable*/ false,
+                                   numLoads, numStores, immutableLoadsFound);
 
   // Check how to optimize this function.
-  ProcessKind Kind = getProcessFunctionKind(LSCount.second);
+  ProcessKind Kind = getProcessFunctionKind(numStores);
 
   // We do not optimize this function at all.
   if (Kind == ProcessKind::ProcessNone)

@@ -486,14 +486,21 @@ private:
   /// If set, RLE ignores loads from that array type.
   NominalTypeDecl *ArrayType;
 
+  /// Set to true if loads with a `ref_element_addr [immutable]` or
+  /// `ref_tail_addr [immutable]` base address are found.
+  bool immutableLoadsFound = false;
+
+  /// Only optimize loads with a base address of `ref_element_addr [immutable]`
+  /// or `ref_tail_addr [immutable]`.
+  bool onlyImmutableLoads;
+
 #ifndef NDEBUG
   SILPrintContext printCtx;
 #endif
 
 public:
-  RLEContext(SILFunction *F, SILPassManager *PM, AliasAnalysis *AA,
-             TypeExpansionAnalysis *TE, PostOrderFunctionInfo *PO,
-             EpilogueARCFunctionInfo *EAFI, bool disableArrayLoads);
+  RLEContext(SILFunction *F, SILPassManager *PM,
+             bool disableArrayLoads, bool onlyImmutableLoads);
 
   RLEContext(const RLEContext &) = delete;
   RLEContext(RLEContext &&) = delete;
@@ -504,6 +511,8 @@ public:
   /// Entry point to redundant load elimination.
   bool run();
 
+  bool shouldOptimizeImmutableLoads() const { return immutableLoadsFound; }
+
   SILFunction *getFunction() const { return Fn; }
 
   /// Use a set of ad hoc rules to tell whether we should run a pessimistic
@@ -570,6 +579,11 @@ public:
         LI->getType().getNominalOrBoundGenericNominal() != ArrayType) {
       return LI;
     }
+    if (onlyImmutableLoads &&
+        !LSLocation::getBaseAddressOrObject(LI->getOperand(),
+                                            /*stopAtImmutable*/ true).second) {
+      return nullptr;
+    }
   }
   return nullptr;
 }
@@ -1200,14 +1214,17 @@ void BlockState::dump(RLEContext &Ctx) {
 // RLEContext Implementation
 //===----------------------------------------------------------------------===//
 
-RLEContext::RLEContext(SILFunction *F, SILPassManager *PM, AliasAnalysis *AA,
-                       TypeExpansionAnalysis *TE, PostOrderFunctionInfo *PO,
-                       EpilogueARCFunctionInfo *EAFI, bool disableArrayLoads)
-    : Fn(F), PM(PM), AA(AA), TE(TE), PO(PO), EAFI(EAFI), BBToLocState(F),
-      BBWithLoads(F),
+RLEContext::RLEContext(SILFunction *F, SILPassManager *PM,
+                       bool disableArrayLoads, bool onlyImmutableLoads)
+    : Fn(F), PM(PM), AA(PM->getAnalysis<AliasAnalysis>(F)),
+      TE(PM->getAnalysis<TypeExpansionAnalysis>()),
+      PO(PM->getAnalysis<PostOrderAnalysis>()->get(F)),
+      EAFI(PM->getAnalysis<EpilogueARCAnalysis>()->get(F)),
+      BBToLocState(F), BBWithLoads(F),
       ArrayType(disableArrayLoads
                     ? F->getModule().getASTContext().getArrayDecl()
-                    : nullptr)
+                    : nullptr),
+      onlyImmutableLoads(onlyImmutableLoads)
 #ifndef NDEBUG
       ,
       printCtx(llvm::dbgs(), /*Verbose=*/false, /*Sorted=*/true)
@@ -1567,14 +1584,15 @@ bool RLEContext::run() {
   // Phase 4. we perform the redundant load elimination.
   // Walk over the function and find all the locations accessed by
   // this function.
-  std::pair<int, int> LSCount = std::make_pair(0, 0);
+  int numLoads = 0, numStores = 0;
   LSLocation::enumerateLSLocations(*Fn, LocationVault,
                                    LocToBitIndex,
                                    BaseToLocIndex, TE,
-                                   LSCount);
+                                   /*stopAtImmutable*/ onlyImmutableLoads,
+                                   numLoads, numStores, immutableLoadsFound);
 
   // Check how to optimize this function.
-  ProcessKind Kind = getProcessFunctionKind(LSCount.first, LSCount.second);
+  ProcessKind Kind = getProcessFunctionKind(numLoads, numStores);
 
   // We do not optimize this function at all.
   if (Kind == ProcessKind::ProcessNone)
@@ -1681,15 +1699,21 @@ public:
     LLVM_DEBUG(llvm::dbgs() << "*** RLE on function: " << F->getName()
                             << " ***\n");
 
-    auto *AA = PM->getAnalysis<AliasAnalysis>(F);
-    auto *TE = PM->getAnalysis<TypeExpansionAnalysis>();
-    auto *PO = PM->getAnalysis<PostOrderAnalysis>()->get(F);
-    auto *EAFI = PM->getAnalysis<EpilogueARCAnalysis>()->get(F);
-
-    RLEContext RLE(F, PM, AA, TE, PO, EAFI, disableArrayLoads);
+    RLEContext RLE(F, PM, disableArrayLoads,
+                   /*onlyImmutableLoads*/ false);
     if (RLE.run()) {
       invalidateAnalysis(SILAnalysis::InvalidationKind::Instructions);
     }
+    if (RLE.shouldOptimizeImmutableLoads()) {
+      /// Re-running RLE with cutting base addresses off at
+      /// `ref_element_addr [immutable]` or `ref_tail_addr [immutable]` can
+      /// expose additional opportunities.
+      RLEContext RLE2(F, PM, disableArrayLoads,
+                      /*onlyImmutableLoads*/ true);
+      if (RLE2.run()) {
+        invalidateAnalysis(SILAnalysis::InvalidationKind::Instructions);
+      }
+    }
   }
 };

@@ -245,28 +245,60 @@ void LSLocation::reduce(LSLocation Base, SILModule *M,
   replaceSubLocations(Base, M, context, Locs, SubLocations);
 }
 
-void LSLocation::enumerateLSLocation(TypeExpansionContext context, SILModule *M,
+std::pair<SILValue, bool>
+LSLocation::getBaseAddressOrObject(SILValue v, bool stopAtImmutable) {
+  bool isImmutable = false;
+  while (true) {
+    if (auto *rea = dyn_cast<RefElementAddrInst>(v)) {
+      if (rea->isImmutable()) {
+        isImmutable = true;
+        if (stopAtImmutable)
+          return {v, true};
+      }
+    }
+    if (auto *rta = dyn_cast<RefTailAddrInst>(v)) {
+      if (rta->isImmutable()) {
+        isImmutable = true;
+        if (stopAtImmutable)
+          return {v, true};
+      }
+    }
+    SILValue v2 = stripCastsWithoutMarkDependence(v);
+    v2 = stripSinglePredecessorArgs(v2);
+    if (Projection::isAddressProjection(v2))
+      v2 = cast<SingleValueInstruction>(v2)->getOperand(0);
+    v2 = stripIndexingInsts(v2);
+    v2 = lookThroughOwnershipInsts(v2);
+    if (v2 == v)
+      return {v2, isImmutable};
+    v = v2;
+  }
+}
+
+bool LSLocation::enumerateLSLocation(TypeExpansionContext context, SILModule *M,
                                      SILValue Mem,
                                      std::vector<LSLocation> &Locations,
                                      LSLocationIndexMap &IndexMap,
                                      LSLocationBaseMap &BaseMap,
-                                     TypeExpansionAnalysis *TypeCache) {
+                                     TypeExpansionAnalysis *TypeCache,
+                                     bool stopAtImmutable) {
   // We have processed this SILValue before.
   if (BaseMap.find(Mem) != BaseMap.end())
-    return;
+    return false;
 
   // Construct a Location to represent the memory written by this instruction.
   // ProjectionPath currently does not handle mark_dependence so stop our
   // underlying object search at these instructions.
   // We still get a benefit if we cse mark_dependence instructions and then
   // merge loads from them.
-  SILValue UO = getUnderlyingObjectStopAtMarkDependence(Mem);
+  auto baseAndImmutable = getBaseAddressOrObject(Mem, stopAtImmutable);
+  SILValue UO = baseAndImmutable.first;
   LSLocation L(UO, ProjectionPath::getProjectionPath(UO, Mem));
 
   // If we can't figure out the Base or Projection Path for the memory location,
   // simply ignore it for now.
   if (!L.isValid())
-    return;
+    return false;
 
   // Record the SILValue to location mapping.
   BaseMap[Mem] = L;
@@ -281,6 +313,7 @@ void LSLocation::enumerateLSLocation(TypeExpansionContext context, SILModule *M,
     IndexMap[Loc] = Locations.size();
     Locations.push_back(Loc);
   }
+  return baseAndImmutable.second;
 }
 
 void
@@ -289,22 +322,26 @@ LSLocation::enumerateLSLocations(SILFunction &F,
                                  LSLocationIndexMap &IndexMap,
                                  LSLocationBaseMap &BaseMap,
                                  TypeExpansionAnalysis *TypeCache,
-                                 std::pair<int, int> &LSCount) {
+                                 bool stopAtImmutable,
+                                 int &numLoads, int &numStores,
+                                 bool &immutableLoadsFound) {
   // Enumerate all locations accessed by the loads or stores.
   for (auto &B : F) {
     for (auto &I : B) {
       if (auto *LI = dyn_cast<LoadInst>(&I)) {
-        enumerateLSLocation(F.getTypeExpansionContext(), &I.getModule(),
-                            LI->getOperand(), Locations, IndexMap, BaseMap,
-                            TypeCache);
-        ++LSCount.first;
+        if (enumerateLSLocation(F.getTypeExpansionContext(), &I.getModule(),
+                                LI->getOperand(), Locations, IndexMap, BaseMap,
+                                TypeCache, stopAtImmutable)) {
+          immutableLoadsFound = true;
+        }
+        ++numLoads;
         continue;
       }
       if (auto *SI = dyn_cast<StoreInst>(&I)) {
         enumerateLSLocation(F.getTypeExpansionContext(), &I.getModule(),
                             SI->getDest(), Locations, IndexMap, BaseMap,
-                            TypeCache);
-        ++LSCount.second;
+                            TypeCache, stopAtImmutable);
+        ++numStores;
         continue;
       }
     }

@@ -163,6 +163,22 @@ bb0(%0 : @owned $AB):
   return %5 : $Int
 }
 
+// CHECK-LABEL: sil [ossa] @forward_load_of_immutable_class_property
+// CHECK: [[L:%[0-9]+]] = load
+// CHECK: apply %{{[0-9]+}}([[L]])
+// CHECK-NOT: load
+// CHECK: return [[L]]
+// CHECK-LABEL: } // end sil function 'forward_load_of_immutable_class_property'
+sil [ossa] @forward_load_of_immutable_class_property : $@convention(thin) (@guaranteed AB) -> Int {
+bb0(%0 : @guaranteed $AB):
+  %1 = ref_element_addr [immutable] %0 : $AB, #AB.value
+  %2 = load [trivial] %1 : $*Int
+  %3 = function_ref @use_Int : $@convention(thin) (Int) -> ()
+  apply %3(%2) : $@convention(thin) (Int) -> ()
+  %5 = load [trivial] %1 : $*Int
+  return %5 : $Int
+}
+
 // CHECK-LABEL: sil hidden [ossa] @load_forward_across_end_cow_mutation :
 // CHECK-NOT: = load
 // CHECK-LABEL: } // end sil function 'load_forward_across_end_cow_mutation'