IRGen: Properly adjust the closure type of a partial_apply of an objc_method

It needs to match with the (large loadable) lowered closure type in the rest of
the program: Large types in the signature need to be passed indirectly.

rdar://127367321
This commit is contained in:
Arnold Schwaighofer
2024-05-21 10:05:09 -07:00
parent 55a2a412dc
commit d89bf2893b
5 changed files with 83 additions and 16 deletions

View File

@@ -81,10 +81,12 @@ public:
irgen::IRGenModule &IGM);
SmallVector<SILResultInfo, 2> getNewResults(GenericEnvironment *GenericEnv,
CanSILFunctionType fnType,
irgen::IRGenModule &Mod);
irgen::IRGenModule &Mod,
bool mustTransform = false);
CanSILFunctionType getNewSILFunctionType(GenericEnvironment *env,
CanSILFunctionType fnType,
irgen::IRGenModule &IGM);
irgen::IRGenModule &IGM,
bool mustTransform = false);
SILType getNewOptionalFunctionType(GenericEnvironment *GenericEnv,
SILType storageType,
irgen::IRGenModule &Mod);
@@ -241,8 +243,9 @@ bool LargeSILTypeMapper::newResultsDiffer(GenericEnvironment *GenericEnv,
static bool modNonFuncTypeResultType(GenericEnvironment *genEnv,
CanSILFunctionType loweredTy,
irgen::IRGenModule &Mod) {
if (!modifiableFunction(loweredTy)) {
irgen::IRGenModule &Mod,
bool mustTransform = false) {
if (!modifiableFunction(loweredTy) && !mustTransform) {
return false;
}
if (loweredTy->getNumResults() != 1) {
@@ -259,7 +262,8 @@ static bool modNonFuncTypeResultType(GenericEnvironment *genEnv,
SmallVector<SILResultInfo, 2>
LargeSILTypeMapper::getNewResults(GenericEnvironment *GenericEnv,
CanSILFunctionType fnType,
irgen::IRGenModule &Mod) {
irgen::IRGenModule &Mod,
bool mustTransform) {
// Get new SIL Function results - same as old results UNLESS:
// 1) Function type results might have a different signature
// 2) Large loadables are replaced by @out version
@@ -268,7 +272,7 @@ LargeSILTypeMapper::getNewResults(GenericEnvironment *GenericEnv,
for (auto result : origResults) {
SILType currResultTy = result.getSILStorageInterfaceType();
SILType newSILType = getNewSILType(GenericEnv, currResultTy, Mod);
if (modNonFuncTypeResultType(GenericEnv, fnType, Mod)) {
if (modNonFuncTypeResultType(GenericEnv, fnType, Mod, mustTransform)) {
// Case (2) Above
SILResultInfo newSILResultInfo(newSILType.getASTType(),
ResultConvention::Indirect);
@@ -288,8 +292,9 @@ LargeSILTypeMapper::getNewResults(GenericEnvironment *GenericEnv,
CanSILFunctionType
LargeSILTypeMapper::getNewSILFunctionType(GenericEnvironment *env,
CanSILFunctionType fnType,
irgen::IRGenModule &IGM) {
if (!modifiableFunction(fnType)) {
irgen::IRGenModule &IGM,
bool mustTransform) {
if (!modifiableFunction(fnType) && !mustTransform) {
return fnType;
}
@@ -301,7 +306,7 @@ LargeSILTypeMapper::getNewSILFunctionType(GenericEnvironment *env,
auto newParams = getNewParameters(env, fnType, IGM);
auto newYields = getNewYields(env, fnType, IGM);
auto newResults = getNewResults(env, fnType, IGM);
auto newResults = getNewResults(env, fnType, IGM, mustTransform);
auto newFnType = SILFunctionType::get(
fnType->getInvocationGenericSignature(),
fnType->getExtInfo(),
@@ -2623,7 +2628,20 @@ void LoadableByAddress::recreateSingleApply(
// Change the type of the Closure
auto partialApplyConvention = castedApply->getCalleeConvention();
auto resultIsolation = castedApply->getResultIsolation();
// We do need to update the closure's funtion type to match with the other
// uses inside of the binary. Pointer auth cares about the SIL function
// type.
if (callee->getType().castTo<SILFunctionType>()->getExtInfo().getRepresentation() ==
SILFunctionTypeRepresentation::ObjCMethod) {
CanSILFunctionType newFnType =
MapperCache.getNewSILFunctionType(
genEnv,
callee->getType().castTo<SILFunctionType>(), *currIRMod,
/*mustTransform*/ true);
SILType newType = SILType::getPrimitiveObjectType(newFnType);
callee = applyBuilder.createConvertFunction(castedApply->getLoc(),
callee, newType, false);
}
auto newApply = applyBuilder.createPartialApply(
castedApply->getLoc(), callee, applySite.getSubstitutionMap(), callArgs,
partialApplyConvention, resultIsolation, castedApply->isOnStack());