IRGen: fix failing unconditional class casts

When unconditionally casting from a base to a final derived class, e.g. `base as! Derived`, the program did not abort with a trap.
Instead the resulting null-pointer caused a crash later in the program.
This fix inserts a trap condition for the failing case of such a cast.

rdar://151462303
This commit is contained in:
Erik Eckstein
2025-05-19 12:41:10 +02:00
parent 371e4ebd88
commit a768037d0b
5 changed files with 101 additions and 14 deletions

View File

@@ -1032,7 +1032,7 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF,
}
if (llvm::Value *fastResult = emitFastClassCastIfPossible(
IGF, instance, sourceFormalType, targetFormalType,
IGF, instance, sourceFormalType, targetFormalType, mode,
sourceWrappedInOptional, nilCheckBB, nilMergeBB)) {
Explosion fastExplosion;
fastExplosion.add(fastResult);
@@ -1054,7 +1054,7 @@ void irgen::emitScalarCheckedCast(IRGenFunction &IGF,
/// not required that the metadata is fully initialized.
llvm::Value *irgen::emitFastClassCastIfPossible(
IRGenFunction &IGF, llvm::Value *instance, CanType sourceFormalType,
CanType targetFormalType, bool sourceWrappedInOptional,
CanType targetFormalType, CheckedCastMode mode, bool sourceWrappedInOptional,
llvm::BasicBlock *&nilCheckBB, llvm::BasicBlock *&nilMergeBB) {
if (!doesCastPreserveOwnershipForTypes(IGF.IGM.getSILModule(),
sourceFormalType, targetFormalType)) {
@@ -1089,15 +1089,18 @@ llvm::Value *irgen::emitFastClassCastIfPossible(
// If the source was originally wrapped in an Optional, check it for nil now.
if (sourceWrappedInOptional) {
auto isNotNil = IGF.Builder.CreateICmpNE(
auto isNil = IGF.Builder.CreateICmpEQ(
instance, llvm::ConstantPointerNull::get(
cast<llvm::PointerType>(instance->getType())));
auto *isNotNilContBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
nilMergeBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
nilCheckBB = IGF.Builder.GetInsertBlock();
IGF.Builder.CreateCondBr(isNotNil, isNotNilContBB, nilMergeBB);
IGF.Builder.emitBlock(isNotNilContBB);
if (mode == CheckedCastMode::Unconditional) {
IGF.emitConditionalTrap(isNil, "Unexpectedly found nil while unwrapping an Optional value");
} else {
auto *isNotNilContBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
nilMergeBB = llvm::BasicBlock::Create(IGF.IGM.getLLVMContext());
nilCheckBB = IGF.Builder.GetInsertBlock();
IGF.Builder.CreateCondBr(isNil, nilMergeBB, isNotNilContBB);
IGF.Builder.emitBlock(isNotNilContBB);
}
}
// Get the metadata pointer of the destination class type.
@@ -1121,11 +1124,15 @@ llvm::Value *irgen::emitFastClassCastIfPossible(
llvm::Value *rhs = IGF.Builder.CreateBitCast(objMetadata, IGF.IGM.Int8PtrTy);
// return isa_ptr == metadata_ptr ? instance : nullptr
llvm::Value *isEqual = IGF.Builder.CreateCmp(llvm::CmpInst::Predicate::ICMP_EQ,
llvm::Value *isNotEqual = IGF.Builder.CreateCmp(llvm::CmpInst::Predicate::ICMP_NE,
lhs, rhs);
if (mode == CheckedCastMode::Unconditional) {
IGF.emitConditionalTrap(isNotEqual, "Unconditional cast failed");
return instance;
}
auto *instanceTy = cast<llvm::PointerType>(instance->getType());
auto *nullPtr = llvm::ConstantPointerNull::get(instanceTy);
auto *select = IGF.Builder.CreateSelect(isEqual, instance, nullPtr);
auto *select = IGF.Builder.CreateSelect(isNotEqual, nullPtr, instance);
llvm::Type *destTy = IGF.getTypeInfoForUnlowered(targetFormalType).getStorageType();
return IGF.Builder.CreateBitCast(select, destTy);
}

View File

@@ -61,7 +61,7 @@ namespace irgen {
llvm::Value *emitFastClassCastIfPossible(
IRGenFunction &IGF, llvm::Value *instance, CanType sourceFormalType,
CanType targetFormalType, bool sourceWrappedInOptional,
CanType targetFormalType, CheckedCastMode mode, bool sourceWrappedInOptional,
llvm::BasicBlock *&nilCheckBB, llvm::BasicBlock *&nilMergeBB);
/// Convert a class object to the given destination type,

View File

@@ -142,4 +142,35 @@ CastTrapsTestSuite.test("Unexpected Obj-C null")
}
#endif
class Base {}
final class Derived: Base {}
final class Other: Base {}
@inline(never)
func getDerived(_ v: Base) -> Derived {
return v as! Derived
}
@inline(never)
func getDerivedFromOptional(_ v: Base?) -> Derived {
return v as! Derived
}
CastTrapsTestSuite.test("unconditinal fast class cast") {
let c = Other()
expectCrashLater()
_ = getDerived(c)
}
CastTrapsTestSuite.test("unconditinal optional fast class cast") {
let c = Other()
expectCrashLater()
_ = getDerivedFromOptional(c)
}
CastTrapsTestSuite.test("unconditinal optional nil fast class cast") {
expectCrashLater()
_ = getDerivedFromOptional(nil)
}
runAllTests()

View File

@@ -56,6 +56,15 @@ func unconditionalCastToFinal(_ b: Classes.Base) -> Classes.Final {
return b as! Classes.Final
}
// CHECK-LABEL: define {{.*}} @"$s4Main32unconditionalOptionalCastToFinaly7Classes0F0CAC4BaseCSgF"
// CHECK-NOT: call {{.*}}@object_getClass
// CHECK-NOT: @swift_dynamicCastClass
// CHECK: }
@inline(never)
func unconditionalOptionalCastToFinal(_ b: Classes.Base?) -> Classes.Final {
return b as! Classes.Final
}
// CHECK-LABEL: define {{.*}} @"$s4Main20castToResilientFinaly0D7Classes0E0CSgAC4BaseCF"
// CHECK: @swift_dynamicCastClass
// CHECK: }
@@ -132,7 +141,9 @@ func test() {
// CHECK-OUTPUT: Optional(Classes.Final)
print(castToFinal(Classes.Final()) as Any)
// CHECK-OUTPUT: Classes.Final
print(unconditionalCastToFinal(Classes.Final()) as Any)
print(unconditionalCastToFinal(Classes.Final()))
// CHECK-OUTPUT: Classes.Final
print(unconditionalOptionalCastToFinal(Classes.Final()))
// CHECK-OUTPUT: nil
print(castToResilientFinal(ResilientClasses.Base()) as Any)

View File

@@ -1,6 +1,6 @@
// RUN: %target-swift-frontend %s -emit-ir -enable-objc-interop -disable-objc-attr-requires-foundation-module | %FileCheck %s -DINT=i%target-ptrsize
// REQUIRES: CPU=i386 || CPU=x86_64
// REQUIRES: CPU=i386 || CPU=x86_64 || CPU=arm64
sil_stage canonical
@@ -11,9 +11,11 @@ struct NotClass {}
class A {}
class B: A {}
final class F: A {}
sil_vtable A {}
sil_vtable B {}
sil_vtable F {}
// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc ptr @unchecked_addr_cast(ptr noalias {{(nocapture|captures\(none\))}} dereferenceable({{.*}}) %0) {{.*}} {
sil @unchecked_addr_cast : $(@in A) -> B {
@@ -115,6 +117,42 @@ entry(%a : $@thick Any.Type):
return %p : $@thick (CP & OP & CP2).Type
}
// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc ptr @unconditional_fast_class_cast(ptr %0)
// CHECK: [[ISA:%.*]] = load ptr, ptr %0
// CHECK: [[NE:%.*]] = icmp ne {{.*}}, [[ISA]]
// CHECK: [[E:%.*]] = call i1 @llvm.expect.i1(i1 [[NE]], i1 false)
// CHECK: br i1 [[E]], label %[[TRAPBB:[0-9]*]], label %[[RETBB:[0-9]*]]
// CHECK: [[RETBB]]:
// CHECK-NEXT: ret ptr %0
// CHECK: [[TRAPBB]]:
// CHECK-NEXT: call void @llvm.trap()
sil @unconditional_fast_class_cast : $@convention(thin) (@owned A) -> @owned F {
entry(%0 : $A):
%1 = unconditional_checked_cast %0 to F
return %1
}
// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc ptr @unconditional_optional_fast_class_cast(i64 %0)
// CHECK: [[PTR:%.*]] = inttoptr i64 %0 to ptr
// CHECK: [[ISNULL:%.*]] = icmp eq ptr [[PTR]], null
// CHECK: [[ENN:%.*]] = call i1 @llvm.expect.i1(i1 [[ISNULL]], i1 false)
// CHECK: br i1 [[ENN]], label %[[NULLTRAPBB:[0-9]*]], label %[[CONTBB:[0-9]*]]
// CHECK: [[CONTBB]]:
// CHECK: [[ISA:%.*]] = load ptr, ptr [[PTR]]
// CHECK: [[NE:%.*]] = icmp ne {{.*}}, [[ISA]]
// CHECK: [[E:%.*]] = call i1 @llvm.expect.i1(i1 [[NE]], i1 false)
// CHECK: br i1 [[E]], label %[[TRAPBB:[0-9]*]], label %[[RETBB:[0-9]*]]
// CHECK: [[RETBB]]:
// CHECK-NEXT: ret ptr [[PTR]]
// CHECK: [[NULLTRAPBB]]:
// CHECK-NEXT: call void @llvm.trap()
// CHECK: [[TRAPBB]]:
// CHECK-NEXT: call void @llvm.trap()
sil @unconditional_optional_fast_class_cast : $@convention(thin) (@owned Optional<A>) -> @owned F {
entry(%0 : $Optional<A>):
%1 = unconditional_checked_cast %0 to F
return %1
}
// CHECK-LABEL: define{{( dllexport)?}}{{( protected)?}} swiftcc { ptr, ptr } @c_cast_to_class_existential(ptr %0)
// CHECK: call { ptr, ptr } @dynamic_cast_existential_1_conditional(ptr {{.*}}, ptr %.Type, {{.*}} @"$s5casts2CPMp"