[Reflection] Improve size computation for remote type metadata reading.

Take advantage of the ability to iterate over TrailingObjects when computing the size of certain kinds of type metadata. This avoids the occasional issue where a new trailing object is added to a type, the remote metadata reader isn't fully updated for it, and doesn't read enough data. This change fixes an issue with function type metadata where we didn't read the global actor field.

Not all type metadata is amenable to this, as some don't use TrailingObjects for their trailing data (e.g. various nominal types) and extended existentials need to dereference their Shape pointer to determine the number of TrailingObjects, which needs some additional code when done remotely. We are able to automatically calculate the sizes of Existential and Function.

rdar://162855053
This commit is contained in:
Mike Ash
2025-11-12 14:10:58 -05:00
parent a1b41acf8d
commit d6fc306b16
3 changed files with 148 additions and 96 deletions

View File

@@ -421,6 +421,15 @@ public:
}
};
// Helper function to determine at build time if a type has TrailingObjects.
// This is useful for determining if trailingTypeCount and
// sizeWithTrailingTypeCount are available for code that reads TrailingObjects
// values in a generalized fashion.
template <typename T>
static constexpr bool typeHasTrailingObjects() {
return std::is_base_of_v<trailing_objects_internal::TrailingObjectsBase, T>;
}
} // end namespace ABI
} // end namespace swift

View File

@@ -1493,33 +1493,33 @@ public:
break;
case ContextDescriptorKind::Extension:
success =
readFullContextDescriptor<TargetExtensionContextDescriptor<Runtime>>(
remoteAddress, ptr);
readFullTrailingObjects<TargetExtensionContextDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::Anonymous:
success =
readFullContextDescriptor<TargetAnonymousContextDescriptor<Runtime>>(
remoteAddress, ptr);
readFullTrailingObjects<TargetAnonymousContextDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::Class:
success = readFullContextDescriptor<TargetClassDescriptor<Runtime>>(
remoteAddress, ptr);
success = readFullTrailingObjects<TargetClassDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::Enum:
success = readFullContextDescriptor<TargetEnumDescriptor<Runtime>>(
remoteAddress, ptr);
success = readFullTrailingObjects<TargetEnumDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::Struct:
success = readFullContextDescriptor<TargetStructDescriptor<Runtime>>(
remoteAddress, ptr);
success = readFullTrailingObjects<TargetStructDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::Protocol:
success = readFullContextDescriptor<TargetProtocolDescriptor<Runtime>>(
remoteAddress, ptr);
success = readFullTrailingObjects<TargetProtocolDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::OpaqueType:
success = readFullContextDescriptor<TargetOpaqueTypeDescriptor<Runtime>>(
remoteAddress, ptr);
success = readFullTrailingObjects<TargetOpaqueTypeDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
default:
// We don't know about this kind of context.
@@ -1535,12 +1535,18 @@ public:
return ContextDescriptorRef(remoteAddress, descriptor);
}
template <typename DescriptorTy>
bool readFullContextDescriptor(RemoteAddress address,
MemoryReader::ReadBytesResult &ptr) {
/// Read all memory occupied by a value with TrailingObjects. This will
/// incrementally read pieces of the object to figure out the full size of it.
/// - address: The address of the value.
/// - ptr: The bytes that have been read so far. On return, the full object.
/// - existingByteCount: The number of bytes in ptr.
template <typename BaseTy>
bool readFullTrailingObjects(RemoteAddress address,
MemoryReader::ReadBytesResult &ptr,
size_t existingByteCount) {
// Read the full base descriptor if it's bigger than what we have so far.
if (sizeof(DescriptorTy) > sizeof(TargetContextDescriptor<Runtime>)) {
ptr = Reader->template readObj<DescriptorTy>(address);
if (sizeof(BaseTy) > existingByteCount) {
ptr = Reader->template readObj<BaseTy>(address);
if (!ptr)
return false;
}
@@ -1556,13 +1562,17 @@ public:
// size. Once we've walked through all the trailing objects, we've read
// everything.
size_t sizeSoFar = sizeof(DescriptorTy);
size_t sizeSoFar = sizeof(BaseTy);
for (size_t i = 0; i < DescriptorTy::trailingTypeCount(); i++) {
const DescriptorTy *descriptorSoFar =
reinterpret_cast<const DescriptorTy *>(ptr.get());
for (size_t i = 0; i < BaseTy::trailingTypeCount(); i++) {
const BaseTy *descriptorSoFar =
reinterpret_cast<const BaseTy *>(ptr.get());
size_t thisSize = descriptorSoFar->sizeWithTrailingTypeCount(i);
if (thisSize > sizeSoFar) {
// Make sure we haven't ended up with a ridiculous size.
if (thisSize > MaxMetadataSize)
return false;
ptr = Reader->readBytes(address, thisSize);
if (!ptr)
return false;
@@ -2141,45 +2151,22 @@ protected:
switch (getEnumeratedMetadataKind(KindValue)) {
case MetadataKind::Class:
return _readMetadata<TargetClassMetadataType>(address);
return _readMetadataFixedSize<TargetClassMetadataType>(address);
case MetadataKind::Enum:
return _readMetadata<TargetEnumMetadata>(address);
return _readMetadataFixedSize<TargetEnumMetadata>(address);
case MetadataKind::ErrorObject:
return _readMetadata<TargetEnumMetadata>(address);
case MetadataKind::Existential: {
RemoteAddress flagsAddress = address + sizeof(StoredPointer);
ExistentialTypeFlags::int_type flagsData;
if (!Reader->readInteger(flagsAddress, &flagsData))
return nullptr;
ExistentialTypeFlags flags(flagsData);
RemoteAddress numProtocolsAddress = flagsAddress + sizeof(flagsData);
uint32_t numProtocols;
if (!Reader->readInteger(numProtocolsAddress, &numProtocols))
return nullptr;
// Make sure the number of protocols is reasonable
if (numProtocols >= 256)
return nullptr;
auto totalSize = sizeof(TargetExistentialTypeMetadata<Runtime>)
+ numProtocols *
sizeof(ConstTargetMetadataPointer<Runtime, TargetProtocolDescriptor>);
if (flags.hasSuperclassConstraint())
totalSize += sizeof(StoredPointer);
return _readMetadata(address, totalSize);
}
return _readMetadataFixedSize<TargetMetadata>(address);
case MetadataKind::Existential:
return _readMetadataVariableSize<TargetExistentialTypeMetadata>(
address);
case MetadataKind::ExistentialMetatype:
return _readMetadata<TargetExistentialMetatypeMetadata>(address);
return _readMetadataFixedSize<TargetExistentialMetatypeMetadata>(
address);
case MetadataKind::ExtendedExistential: {
// We need to read the shape in order to figure out how large
// the generalization arguments are.
// the generalization arguments are. This prevents us from using
// _readMetadataVariableSize, which requires the Shape field to be
// dereferenceable here.
RemoteAddress shapeAddress = address + sizeof(StoredPointer);
RemoteAddress signedShapePtr;
if (!Reader->template readRemoteAddress<StoredPointer>(shapeAddress,
@@ -2198,46 +2185,24 @@ protected:
return _readMetadata(address, totalSize);
}
case MetadataKind::ForeignClass:
return _readMetadata<TargetForeignClassMetadata>(address);
return _readMetadataFixedSize<TargetForeignClassMetadata>(address);
case MetadataKind::ForeignReferenceType:
return _readMetadata<TargetForeignReferenceTypeMetadata>(address);
case MetadataKind::Function: {
StoredSize flagsValue;
auto flagsAddr =
address + TargetFunctionTypeMetadata<Runtime>::OffsetToFlags;
if (!Reader->readInteger(flagsAddr, &flagsValue))
return nullptr;
auto flags =
TargetFunctionTypeFlags<StoredSize>::fromIntValue(flagsValue);
auto totalSize =
sizeof(TargetFunctionTypeMetadata<Runtime>) +
flags.getNumParameters() * sizeof(FunctionTypeMetadata::Parameter);
if (flags.hasParameterFlags())
totalSize += flags.getNumParameters() * sizeof(uint32_t);
if (flags.isDifferentiable())
totalSize = roundUpToAlignment(totalSize, sizeof(StoredPointer)) +
sizeof(TargetFunctionMetadataDifferentiabilityKind<
typename Runtime::StoredSize>);
return _readMetadata(address,
roundUpToAlignment(totalSize, sizeof(StoredPointer)));
}
return _readMetadataFixedSize<TargetForeignReferenceTypeMetadata>(
address);
case MetadataKind::Function:
return _readMetadataVariableSize<TargetFunctionTypeMetadata>(address);
case MetadataKind::HeapGenericLocalVariable:
return _readMetadata<TargetGenericBoxHeapMetadata>(address);
return _readMetadataFixedSize<TargetGenericBoxHeapMetadata>(address);
case MetadataKind::HeapLocalVariable:
return _readMetadata<TargetHeapLocalVariableMetadata>(address);
return _readMetadataFixedSize<TargetHeapLocalVariableMetadata>(address);
case MetadataKind::Metatype:
return _readMetadata<TargetMetatypeMetadata>(address);
return _readMetadataFixedSize<TargetMetatypeMetadata>(address);
case MetadataKind::ObjCClassWrapper:
return _readMetadata<TargetObjCClassWrapperMetadata>(address);
return _readMetadataFixedSize<TargetObjCClassWrapperMetadata>(address);
case MetadataKind::Optional:
return _readMetadata<TargetEnumMetadata>(address);
return _readMetadataFixedSize<TargetEnumMetadata>(address);
case MetadataKind::Struct:
return _readMetadata<TargetStructMetadata>(address);
return _readMetadataFixedSize<TargetStructMetadata>(address);
case MetadataKind::Tuple: {
auto numElementsAddress = address +
TargetTupleTypeMetadata<Runtime>::getOffsetToNumElements();
@@ -2255,7 +2220,7 @@ protected:
}
case MetadataKind::Opaque:
default:
return _readMetadata<TargetOpaqueMetadata>(address);
return _readMetadataFixedSize<TargetOpaqueMetadata>(address);
}
// We can fall out here if the value wasn't actually a valid
@@ -2333,20 +2298,39 @@ protected:
private:
template <template <class R> class M>
MetadataRef _readMetadata(RemoteAddress address) {
MetadataRef _readMetadataFixedSize(RemoteAddress address) {
static_assert(!ABI::typeHasTrailingObjects<M<Runtime>>(),
"Type must not have trailing objects. Use "
"_readMetadataVariableSize for types that have them.");
return _readMetadata(address, sizeof(M<Runtime>));
}
template <template <class R> class M>
MetadataRef _readMetadataVariableSize(RemoteAddress address) {
static_assert(ABI::typeHasTrailingObjects<M<Runtime>>(),
"Type must have trailing objects. Use _readMetadataFixedSize "
"for types that don't.");
MemoryReader::ReadBytesResult bytes;
auto readResult = readFullTrailingObjects<M<Runtime>>(address, bytes, 0);
if (!readResult)
return nullptr;
return _cacheMetadata(address, bytes);
}
MetadataRef _readMetadata(RemoteAddress address, size_t sizeAfter) {
if (sizeAfter > MaxMetadataSize)
return nullptr;
auto readResult = Reader->readBytes(address, sizeAfter);
if (!readResult)
return nullptr;
return _cacheMetadata(address, readResult);
}
MetadataRef _cacheMetadata(RemoteAddress address,
MemoryReader::ReadBytesResult &bytes) {
auto metadata =
reinterpret_cast<const TargetMetadata<Runtime> *>(readResult.get());
MetadataCache.insert(std::make_pair(address, std::move(readResult)));
reinterpret_cast<const TargetMetadata<Runtime> *>(bytes.get());
MetadataCache.insert(std::make_pair(address, std::move(bytes)));
return MetadataRef(address, metadata);
}

View File

@@ -0,0 +1,59 @@
// Target Swift 5.5 so functions with global actors are encoded as a proper
// mangled name, not a function accessor. Remote Mirror can't call function
// accessors and will fail to read the type.
// RUN: %empty-directory(%t)
// RUN: %target-build-swift -target %target-swift-5.5-abi-triple -lswiftSwiftReflectionTest %s -o %t/function_types
// RUN: %target-codesign %t/function_types
// RUN: %target-run %target-swift-reflection-test %t/function_types | %FileCheck %s --check-prefix=CHECK
// REQUIRES: reflection_test_support
// REQUIRES: executable_test
// UNSUPPORTED: use_os_stdlib
// UNSUPPORTED: asan
import SwiftReflectionTest
struct S {
var f = { @MainActor in }
}
// This Task is necessary to ensure that the concurrency runtime is brought in.
// Without that, the type lookup for @MainActor may fail.
Task {}
// CHECK: Type reference:
// CHECK: (struct function_types.S)
// CHECK: Type info:
// CHECK: (struct size=
// CHECK: (field name=f offset=0
// CHECK: (thick_function size=
// CHECK: (field name=function offset=0
// CHECK: (builtin size=
// CHECK: (field name=context
// CHECK: (reference kind=strong refcounting=native)))))
// CHECK: Mangled name: $s14function_types1SV
// CHECK: Demangled name: function_types.S
reflect(any: S())
// CHECK: Type reference:
// CHECK: (function
// CHECK: (global-actor
// CHECK: (class Swift.MainActor))
// CHECK: (parameters)
// CHECK: (result
// CHECK: (tuple))
// CHECK: Type info:
// CHECK: (thick_function size=
// CHECK: (field name=function offset=0
// CHECK: (builtin size=
// CHECK: (field name=context offset=
// CHECK: (reference kind=strong refcounting=native)))
// CHECK: Mangled name: $syyScMYcc
// CHECK: Demangled name: @Swift.MainActor () -> ()
reflect(any: S().f)
doneReflecting()