LoadableByAddress: convert types of functions which are contained in structs inside global static initializers

We missed converting such types inside static initializers of global variables.
This results in ptrauth crashes when ptrauth is enabled.

rdar://108165425
This commit is contained in:
Erik Eckstein
2023-04-28 08:46:08 +02:00
parent caeab5889a
commit 40a7527b67
2 changed files with 74 additions and 56 deletions

View File

@@ -26,6 +26,7 @@
#include "swift/SIL/DebugUtils.h"
#include "swift/SIL/SILArgument.h"
#include "swift/SIL/SILBuilder.h"
#include "swift/SIL/SILCloner.h"
#include "swift/SIL/SILUndef.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"
#include "swift/SILOptimizer/Utils/InstOptUtils.h"
@@ -1713,6 +1714,8 @@ private:
bool shouldTransformGlobal(SILGlobalVariable *global);
bool shouldTransformInitExprOfGlobal(SILGlobalVariable *global);
private:
llvm::SetVector<SILFunction *> modFuncs;
llvm::SetVector<SingleValueInstruction *> conversionInstrs;
@@ -1724,6 +1727,8 @@ private:
llvm::SetVector<StoreInst *> storeToBlockStorageInstrs;
llvm::SetVector<SILInstruction *> modApplies;
llvm::MapVector<SILInstruction *, SILValue> allApplyRetToAllocMap;
public:
LargeSILTypeMapper MapperCache;
};
} // end anonymous namespace
@@ -2940,6 +2945,42 @@ bool LoadableByAddress::shouldTransformGlobal(SILGlobalVariable *global) {
return false;
}
bool LoadableByAddress::shouldTransformInitExprOfGlobal(SILGlobalVariable *global) {
for (const SILInstruction &initInst : *global) {
if (auto *fri = dyn_cast<FunctionRefBaseInst>(&initInst)) {
SILFunction *refF = fri->getInitiallyReferencedFunction();
if (modFuncs.count(refF) != 0)
return true;
}
}
return false;
}
namespace {
class GlobalInitCloner : public SILCloner<GlobalInitCloner> {
LoadableByAddress *pass;
IRGenModule *irgenModule;
public:
GlobalInitCloner(SILGlobalVariable *global, LoadableByAddress *pass,
IRGenModule *irgenModule)
: SILCloner<GlobalInitCloner>(global), pass(pass), irgenModule(irgenModule) {
}
SILType remapType(SILType ty) {
if (auto fnType = ty.getAs<SILFunctionType>()) {
GenericEnvironment *genEnv = getSubstGenericEnvironment(fnType);
return SILType::getPrimitiveObjectType(
pass->MapperCache.getNewSILFunctionType(genEnv, fnType, *irgenModule));
}
return ty;
}
void clone(SILInstruction *inst) {
visit(inst);
}
};
}
/// The entry point to this function transformation.
void LoadableByAddress::run() {
// Set the SIL state before the PassManager has a chance to run
@@ -3078,72 +3119,34 @@ void LoadableByAddress::run() {
updateLoweredTypes(F);
}
auto computeNewResultType = [&](SILType ty, IRGenModule *mod) -> SILType {
auto currSILFunctionType = ty.castTo<SILFunctionType>();
GenericEnvironment *genEnv =
getSubstGenericEnvironment(currSILFunctionType);
return SILType::getPrimitiveObjectType(
MapperCache.getNewSILFunctionType(genEnv, currSILFunctionType, *mod));
};
// Update globals' initializer.
SmallVector<SILGlobalVariable *, 16> deadGlobals;
for (SILGlobalVariable &global : getModule()->getSILGlobals()) {
SILInstruction *init = global.getStaticInitializerValue();
if (!init)
continue;
auto silTy = global.getLoweredType();
if (!isa<SILFunctionType>(silTy.getASTType()))
continue;
auto *decl = global.getDecl();
IRGenModule *currIRMod = getIRGenModule()->IRGen.getGenModule(
decl ? decl->getDeclContext() : nullptr);
auto silFnTy = global.getLoweredFunctionType();
GenericEnvironment *genEnv = getSubstGenericEnvironment(silFnTy);
if (shouldTransformInitExprOfGlobal(&global)) {
auto *decl = global.getDecl();
IRGenModule *currIRMod = getIRGenModule()->IRGen.getGenModule(
decl ? decl->getDeclContext() : nullptr);
// Update the global's type.
if (MapperCache.shouldTransformFunctionType(genEnv, silFnTy, *currIRMod)) {
auto newSILFnType =
MapperCache.getNewSILFunctionType(genEnv, silFnTy, *currIRMod);
global.unsafeSetLoweredType(
SILType::getPrimitiveObjectType(newSILFnType));
auto silTy = global.getLoweredType();
if (isa<SILFunctionType>(silTy.getASTType())) {
auto silFnTy = global.getLoweredFunctionType();
GenericEnvironment *genEnv = getSubstGenericEnvironment(silFnTy);
if (MapperCache.shouldTransformFunctionType(genEnv, silFnTy, *currIRMod)) {
auto newSILFnType =
MapperCache.getNewSILFunctionType(genEnv, silFnTy, *currIRMod);
global.unsafeSetLoweredType(SILType::getPrimitiveObjectType(newSILFnType));
}
}
// Rewrite the init basic block...
SmallVector<SILInstruction *, 8> initBlockInsts;
for (auto it = global.begin(), end = global.end(); it != end; ++it) {
initBlockInsts.push_back(const_cast<SILInstruction *>(&*it));
}
GlobalInitCloner cloner(&global, this, currIRMod);
for (auto *oldInst : initBlockInsts) {
if (auto *f = dyn_cast<FunctionRefInst>(oldInst)) {
SILBuilder builder(&global);
auto *newInst = builder.createFunctionRef(
f->getLoc(), f->getInitiallyReferencedFunction(), f->getKind());
f->replaceAllUsesWith(newInst);
global.unsafeRemove(f, *getModule());
} else if (auto *cvt = dyn_cast<ConvertFunctionInst>(oldInst)) {
auto newType = computeNewResultType(cvt->getType(), currIRMod);
SILBuilder builder(&global);
auto *newInst = builder.createConvertFunction(
cvt->getLoc(), cvt->getOperand(), newType,
cvt->withoutActuallyEscaping());
cvt->replaceAllUsesWith(newInst);
global.unsafeRemove(cvt, *getModule());
} else if (auto *thinToThick =
dyn_cast<ThinToThickFunctionInst>(oldInst)) {
auto newType =
computeNewResultType(thinToThick->getType(), currIRMod);
SILBuilder builder(&global);
auto *newInstr = builder.createThinToThickFunction(
thinToThick->getLoc(), thinToThick->getOperand(), newType);
thinToThick->replaceAllUsesWith(newInstr);
global.unsafeRemove(thinToThick, *getModule());
} else {
auto *sv = cast<SingleValueInstruction>(oldInst);
auto *newInst = cast<SingleValueInstruction>(oldInst->clone());
global.unsafeAppend(newInst);
sv->replaceAllUsesWith(newInst);
global.unsafeRemove(oldInst, *getModule());
}
cloner.clone(oldInst);
global.unsafeRemove(oldInst, *getModule());
}
}
}