[Reflection] Improve size computation for remote type metadata reading.

Take advantage of the ability to iterate over TrailingObjects when computing the size of certain kinds of type metadata. This avoids the occasional issue where a new trailing object is added to a type, the remote metadata reader isn't fully updated for it, and doesn't read enough data. This change fixes an issue with function type metadata where we didn't read the global actor field.

Not all type metadata is amenable to this, as some don't use TrailingObjects for their trailing data (e.g. various nominal types) and extended existentials need to dereference their Shape pointer to determine the number of TrailingObjects, which needs some additional code when done remotely. We are able to automatically calculate the sizes of Existential and Function.

rdar://162855053
This commit is contained in:
Mike Ash
2025-11-12 14:10:58 -05:00
parent a1b41acf8d
commit d6fc306b16
3 changed files with 148 additions and 96 deletions

View File

@@ -421,6 +421,15 @@ public:
} }
}; };
// Helper function to determine at build time if a type has TrailingObjects.
// This is useful for determining if trailingTypeCount and
// sizeWithTrailingTypeCount are available for code that reads TrailingObjects
// values in a generalized fashion.
template <typename T>
static constexpr bool typeHasTrailingObjects() {
return std::is_base_of_v<trailing_objects_internal::TrailingObjectsBase, T>;
}
} // end namespace ABI } // end namespace ABI
} // end namespace swift } // end namespace swift

View File

@@ -1493,33 +1493,33 @@ public:
break; break;
case ContextDescriptorKind::Extension: case ContextDescriptorKind::Extension:
success = success =
readFullContextDescriptor<TargetExtensionContextDescriptor<Runtime>>( readFullTrailingObjects<TargetExtensionContextDescriptor<Runtime>>(
remoteAddress, ptr); remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break; break;
case ContextDescriptorKind::Anonymous: case ContextDescriptorKind::Anonymous:
success = success =
readFullContextDescriptor<TargetAnonymousContextDescriptor<Runtime>>( readFullTrailingObjects<TargetAnonymousContextDescriptor<Runtime>>(
remoteAddress, ptr); remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break; break;
case ContextDescriptorKind::Class: case ContextDescriptorKind::Class:
success = readFullContextDescriptor<TargetClassDescriptor<Runtime>>( success = readFullTrailingObjects<TargetClassDescriptor<Runtime>>(
remoteAddress, ptr); remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break; break;
case ContextDescriptorKind::Enum: case ContextDescriptorKind::Enum:
success = readFullContextDescriptor<TargetEnumDescriptor<Runtime>>( success = readFullTrailingObjects<TargetEnumDescriptor<Runtime>>(
remoteAddress, ptr); remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break; break;
case ContextDescriptorKind::Struct: case ContextDescriptorKind::Struct:
success = readFullContextDescriptor<TargetStructDescriptor<Runtime>>( success = readFullTrailingObjects<TargetStructDescriptor<Runtime>>(
remoteAddress, ptr); remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break; break;
case ContextDescriptorKind::Protocol: case ContextDescriptorKind::Protocol:
success = readFullContextDescriptor<TargetProtocolDescriptor<Runtime>>( success = readFullTrailingObjects<TargetProtocolDescriptor<Runtime>>(
remoteAddress, ptr); remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break; break;
case ContextDescriptorKind::OpaqueType: case ContextDescriptorKind::OpaqueType:
success = readFullContextDescriptor<TargetOpaqueTypeDescriptor<Runtime>>( success = readFullTrailingObjects<TargetOpaqueTypeDescriptor<Runtime>>(
remoteAddress, ptr); remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break; break;
default: default:
// We don't know about this kind of context. // We don't know about this kind of context.
@@ -1535,12 +1535,18 @@ public:
return ContextDescriptorRef(remoteAddress, descriptor); return ContextDescriptorRef(remoteAddress, descriptor);
} }
template <typename DescriptorTy> /// Read all memory occupied by a value with TrailingObjects. This will
bool readFullContextDescriptor(RemoteAddress address, /// incrementally read pieces of the object to figure out the full size of it.
MemoryReader::ReadBytesResult &ptr) { /// - address: The address of the value.
/// - ptr: The bytes that have been read so far. On return, the full object.
/// - existingByteCount: The number of bytes in ptr.
template <typename BaseTy>
bool readFullTrailingObjects(RemoteAddress address,
MemoryReader::ReadBytesResult &ptr,
size_t existingByteCount) {
// Read the full base descriptor if it's bigger than what we have so far. // Read the full base descriptor if it's bigger than what we have so far.
if (sizeof(DescriptorTy) > sizeof(TargetContextDescriptor<Runtime>)) { if (sizeof(BaseTy) > existingByteCount) {
ptr = Reader->template readObj<DescriptorTy>(address); ptr = Reader->template readObj<BaseTy>(address);
if (!ptr) if (!ptr)
return false; return false;
} }
@@ -1556,13 +1562,17 @@ public:
// size. Once we've walked through all the trailing objects, we've read // size. Once we've walked through all the trailing objects, we've read
// everything. // everything.
size_t sizeSoFar = sizeof(DescriptorTy); size_t sizeSoFar = sizeof(BaseTy);
for (size_t i = 0; i < DescriptorTy::trailingTypeCount(); i++) { for (size_t i = 0; i < BaseTy::trailingTypeCount(); i++) {
const DescriptorTy *descriptorSoFar = const BaseTy *descriptorSoFar =
reinterpret_cast<const DescriptorTy *>(ptr.get()); reinterpret_cast<const BaseTy *>(ptr.get());
size_t thisSize = descriptorSoFar->sizeWithTrailingTypeCount(i); size_t thisSize = descriptorSoFar->sizeWithTrailingTypeCount(i);
if (thisSize > sizeSoFar) { if (thisSize > sizeSoFar) {
// Make sure we haven't ended up with a ridiculous size.
if (thisSize > MaxMetadataSize)
return false;
ptr = Reader->readBytes(address, thisSize); ptr = Reader->readBytes(address, thisSize);
if (!ptr) if (!ptr)
return false; return false;
@@ -2141,45 +2151,22 @@ protected:
switch (getEnumeratedMetadataKind(KindValue)) { switch (getEnumeratedMetadataKind(KindValue)) {
case MetadataKind::Class: case MetadataKind::Class:
return _readMetadataFixedSize<TargetClassMetadataType>(address);
return _readMetadata<TargetClassMetadataType>(address);
case MetadataKind::Enum: case MetadataKind::Enum:
return _readMetadata<TargetEnumMetadata>(address); return _readMetadataFixedSize<TargetEnumMetadata>(address);
case MetadataKind::ErrorObject: case MetadataKind::ErrorObject:
return _readMetadata<TargetEnumMetadata>(address); return _readMetadataFixedSize<TargetMetadata>(address);
case MetadataKind::Existential: { case MetadataKind::Existential:
RemoteAddress flagsAddress = address + sizeof(StoredPointer); return _readMetadataVariableSize<TargetExistentialTypeMetadata>(
address);
ExistentialTypeFlags::int_type flagsData;
if (!Reader->readInteger(flagsAddress, &flagsData))
return nullptr;
ExistentialTypeFlags flags(flagsData);
RemoteAddress numProtocolsAddress = flagsAddress + sizeof(flagsData);
uint32_t numProtocols;
if (!Reader->readInteger(numProtocolsAddress, &numProtocols))
return nullptr;
// Make sure the number of protocols is reasonable
if (numProtocols >= 256)
return nullptr;
auto totalSize = sizeof(TargetExistentialTypeMetadata<Runtime>)
+ numProtocols *
sizeof(ConstTargetMetadataPointer<Runtime, TargetProtocolDescriptor>);
if (flags.hasSuperclassConstraint())
totalSize += sizeof(StoredPointer);
return _readMetadata(address, totalSize);
}
case MetadataKind::ExistentialMetatype: case MetadataKind::ExistentialMetatype:
return _readMetadata<TargetExistentialMetatypeMetadata>(address); return _readMetadataFixedSize<TargetExistentialMetatypeMetadata>(
address);
case MetadataKind::ExtendedExistential: { case MetadataKind::ExtendedExistential: {
// We need to read the shape in order to figure out how large // We need to read the shape in order to figure out how large
// the generalization arguments are. // the generalization arguments are. This prevents us from using
// _readMetadataVariableSize, which requires the Shape field to be
// dereferenceable here.
RemoteAddress shapeAddress = address + sizeof(StoredPointer); RemoteAddress shapeAddress = address + sizeof(StoredPointer);
RemoteAddress signedShapePtr; RemoteAddress signedShapePtr;
if (!Reader->template readRemoteAddress<StoredPointer>(shapeAddress, if (!Reader->template readRemoteAddress<StoredPointer>(shapeAddress,
@@ -2198,46 +2185,24 @@ protected:
return _readMetadata(address, totalSize); return _readMetadata(address, totalSize);
} }
case MetadataKind::ForeignClass: case MetadataKind::ForeignClass:
return _readMetadata<TargetForeignClassMetadata>(address); return _readMetadataFixedSize<TargetForeignClassMetadata>(address);
case MetadataKind::ForeignReferenceType: case MetadataKind::ForeignReferenceType:
return _readMetadata<TargetForeignReferenceTypeMetadata>(address); return _readMetadataFixedSize<TargetForeignReferenceTypeMetadata>(
case MetadataKind::Function: { address);
StoredSize flagsValue; case MetadataKind::Function:
auto flagsAddr = return _readMetadataVariableSize<TargetFunctionTypeMetadata>(address);
address + TargetFunctionTypeMetadata<Runtime>::OffsetToFlags;
if (!Reader->readInteger(flagsAddr, &flagsValue))
return nullptr;
auto flags =
TargetFunctionTypeFlags<StoredSize>::fromIntValue(flagsValue);
auto totalSize =
sizeof(TargetFunctionTypeMetadata<Runtime>) +
flags.getNumParameters() * sizeof(FunctionTypeMetadata::Parameter);
if (flags.hasParameterFlags())
totalSize += flags.getNumParameters() * sizeof(uint32_t);
if (flags.isDifferentiable())
totalSize = roundUpToAlignment(totalSize, sizeof(StoredPointer)) +
sizeof(TargetFunctionMetadataDifferentiabilityKind<
typename Runtime::StoredSize>);
return _readMetadata(address,
roundUpToAlignment(totalSize, sizeof(StoredPointer)));
}
case MetadataKind::HeapGenericLocalVariable: case MetadataKind::HeapGenericLocalVariable:
return _readMetadata<TargetGenericBoxHeapMetadata>(address); return _readMetadataFixedSize<TargetGenericBoxHeapMetadata>(address);
case MetadataKind::HeapLocalVariable: case MetadataKind::HeapLocalVariable:
return _readMetadata<TargetHeapLocalVariableMetadata>(address); return _readMetadataFixedSize<TargetHeapLocalVariableMetadata>(address);
case MetadataKind::Metatype: case MetadataKind::Metatype:
return _readMetadata<TargetMetatypeMetadata>(address); return _readMetadataFixedSize<TargetMetatypeMetadata>(address);
case MetadataKind::ObjCClassWrapper: case MetadataKind::ObjCClassWrapper:
return _readMetadata<TargetObjCClassWrapperMetadata>(address); return _readMetadataFixedSize<TargetObjCClassWrapperMetadata>(address);
case MetadataKind::Optional: case MetadataKind::Optional:
return _readMetadata<TargetEnumMetadata>(address); return _readMetadataFixedSize<TargetEnumMetadata>(address);
case MetadataKind::Struct: case MetadataKind::Struct:
return _readMetadata<TargetStructMetadata>(address); return _readMetadataFixedSize<TargetStructMetadata>(address);
case MetadataKind::Tuple: { case MetadataKind::Tuple: {
auto numElementsAddress = address + auto numElementsAddress = address +
TargetTupleTypeMetadata<Runtime>::getOffsetToNumElements(); TargetTupleTypeMetadata<Runtime>::getOffsetToNumElements();
@@ -2255,7 +2220,7 @@ protected:
} }
case MetadataKind::Opaque: case MetadataKind::Opaque:
default: default:
return _readMetadata<TargetOpaqueMetadata>(address); return _readMetadataFixedSize<TargetOpaqueMetadata>(address);
} }
// We can fall out here if the value wasn't actually a valid // We can fall out here if the value wasn't actually a valid
@@ -2333,20 +2298,39 @@ protected:
private: private:
template <template <class R> class M> template <template <class R> class M>
MetadataRef _readMetadata(RemoteAddress address) { MetadataRef _readMetadataFixedSize(RemoteAddress address) {
static_assert(!ABI::typeHasTrailingObjects<M<Runtime>>(),
"Type must not have trailing objects. Use "
"_readMetadataVariableSize for types that have them.");
return _readMetadata(address, sizeof(M<Runtime>)); return _readMetadata(address, sizeof(M<Runtime>));
} }
template <template <class R> class M>
MetadataRef _readMetadataVariableSize(RemoteAddress address) {
static_assert(ABI::typeHasTrailingObjects<M<Runtime>>(),
"Type must have trailing objects. Use _readMetadataFixedSize "
"for types that don't.");
MemoryReader::ReadBytesResult bytes;
auto readResult = readFullTrailingObjects<M<Runtime>>(address, bytes, 0);
if (!readResult)
return nullptr;
return _cacheMetadata(address, bytes);
}
MetadataRef _readMetadata(RemoteAddress address, size_t sizeAfter) { MetadataRef _readMetadata(RemoteAddress address, size_t sizeAfter) {
if (sizeAfter > MaxMetadataSize) if (sizeAfter > MaxMetadataSize)
return nullptr; return nullptr;
auto readResult = Reader->readBytes(address, sizeAfter); auto readResult = Reader->readBytes(address, sizeAfter);
if (!readResult) return _cacheMetadata(address, readResult);
return nullptr; }
MetadataRef _cacheMetadata(RemoteAddress address,
MemoryReader::ReadBytesResult &bytes) {
auto metadata = auto metadata =
reinterpret_cast<const TargetMetadata<Runtime> *>(readResult.get()); reinterpret_cast<const TargetMetadata<Runtime> *>(bytes.get());
MetadataCache.insert(std::make_pair(address, std::move(readResult))); MetadataCache.insert(std::make_pair(address, std::move(bytes)));
return MetadataRef(address, metadata); return MetadataRef(address, metadata);
} }

View File

@@ -0,0 +1,59 @@
// Target Swift 5.5 so functions with global actors are encoded as a proper
// mangled name, not a function accessor. Remote Mirror can't call function
// accessors and will fail to read the type.
// RUN: %empty-directory(%t)
// RUN: %target-build-swift -target %target-swift-5.5-abi-triple -lswiftSwiftReflectionTest %s -o %t/function_types
// RUN: %target-codesign %t/function_types
// RUN: %target-run %target-swift-reflection-test %t/function_types | %FileCheck %s --check-prefix=CHECK
// REQUIRES: reflection_test_support
// REQUIRES: executable_test
// UNSUPPORTED: use_os_stdlib
// UNSUPPORTED: asan
import SwiftReflectionTest
struct S {
var f = { @MainActor in }
}
// This Task is necessary to ensure that the concurrency runtime is brought in.
// Without that, the type lookup for @MainActor may fail.
Task {}
// CHECK: Type reference:
// CHECK: (struct function_types.S)
// CHECK: Type info:
// CHECK: (struct size=
// CHECK: (field name=f offset=0
// CHECK: (thick_function size=
// CHECK: (field name=function offset=0
// CHECK: (builtin size=
// CHECK: (field name=context
// CHECK: (reference kind=strong refcounting=native)))))
// CHECK: Mangled name: $s14function_types1SV
// CHECK: Demangled name: function_types.S
reflect(any: S())
// CHECK: Type reference:
// CHECK: (function
// CHECK: (global-actor
// CHECK: (class Swift.MainActor))
// CHECK: (parameters)
// CHECK: (result
// CHECK: (tuple))
// CHECK: Type info:
// CHECK: (thick_function size=
// CHECK: (field name=function offset=0
// CHECK: (builtin size=
// CHECK: (field name=context offset=
// CHECK: (reference kind=strong refcounting=native)))
// CHECK: Mangled name: $syyScMYcc
// CHECK: Demangled name: @Swift.MainActor () -> ()
reflect(any: S().f)
doneReflecting()