From d6fc306b167ce0c68f296c522a92ca7e8db4fdb3 Mon Sep 17 00:00:00 2001 From: Mike Ash Date: Wed, 12 Nov 2025 14:10:58 -0500 Subject: [PATCH] [Reflection] Improve size computation for remote type metadata reading. Take advantage of the ability to iterate over TrailingObjects when computing the size of certain kinds of type metadata. This avoids the occasional issue where a new trailing object is added to a type, the remote metadata reader isn't fully updated for it, and doesn't read enough data. This change fixes an issue with function type metadata where we didn't read the global actor field. Not all type metadata is amenable to this, as some don't use TrailingObjects for their trailing data (e.g. various nominal types) and extended existentials need to dereference their Shape pointer to determine the number of TrailingObjects, which needs some additional code when done remotely. We are able to automatically calculate the sizes of Existential and Function. rdar://162855053 --- include/swift/ABI/TrailingObjects.h | 9 + include/swift/Remote/MetadataReader.h | 176 ++++++++---------- .../Reflection/function_types.swift | 59 ++++++ 3 files changed, 148 insertions(+), 96 deletions(-) create mode 100644 validation-test/Reflection/function_types.swift diff --git a/include/swift/ABI/TrailingObjects.h b/include/swift/ABI/TrailingObjects.h index 4ed23b7de61..e9b64230ce4 100644 --- a/include/swift/ABI/TrailingObjects.h +++ b/include/swift/ABI/TrailingObjects.h @@ -421,6 +421,15 @@ public: } }; +// Helper function to determine at build time if a type has TrailingObjects. +// This is useful for determining if trailingTypeCount and +// sizeWithTrailingTypeCount are available for code that reads TrailingObjects +// values in a generalized fashion. +template +static constexpr bool typeHasTrailingObjects() { + return std::is_base_of_v; +} + } // end namespace ABI } // end namespace swift diff --git a/include/swift/Remote/MetadataReader.h b/include/swift/Remote/MetadataReader.h index 770024ecc4e..8ea3580a5e7 100644 --- a/include/swift/Remote/MetadataReader.h +++ b/include/swift/Remote/MetadataReader.h @@ -1493,33 +1493,33 @@ public: break; case ContextDescriptorKind::Extension: success = - readFullContextDescriptor>( - remoteAddress, ptr); + readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::Anonymous: success = - readFullContextDescriptor>( - remoteAddress, ptr); + readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::Class: - success = readFullContextDescriptor>( - remoteAddress, ptr); + success = readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::Enum: - success = readFullContextDescriptor>( - remoteAddress, ptr); + success = readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::Struct: - success = readFullContextDescriptor>( - remoteAddress, ptr); + success = readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::Protocol: - success = readFullContextDescriptor>( - remoteAddress, ptr); + success = readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::OpaqueType: - success = readFullContextDescriptor>( - remoteAddress, ptr); + success = readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; default: // We don't know about this kind of context. @@ -1535,12 +1535,18 @@ public: return ContextDescriptorRef(remoteAddress, descriptor); } - template - bool readFullContextDescriptor(RemoteAddress address, - MemoryReader::ReadBytesResult &ptr) { + /// Read all memory occupied by a value with TrailingObjects. This will + /// incrementally read pieces of the object to figure out the full size of it. + /// - address: The address of the value. + /// - ptr: The bytes that have been read so far. On return, the full object. + /// - existingByteCount: The number of bytes in ptr. + template + bool readFullTrailingObjects(RemoteAddress address, + MemoryReader::ReadBytesResult &ptr, + size_t existingByteCount) { // Read the full base descriptor if it's bigger than what we have so far. - if (sizeof(DescriptorTy) > sizeof(TargetContextDescriptor)) { - ptr = Reader->template readObj(address); + if (sizeof(BaseTy) > existingByteCount) { + ptr = Reader->template readObj(address); if (!ptr) return false; } @@ -1556,13 +1562,17 @@ public: // size. Once we've walked through all the trailing objects, we've read // everything. - size_t sizeSoFar = sizeof(DescriptorTy); + size_t sizeSoFar = sizeof(BaseTy); - for (size_t i = 0; i < DescriptorTy::trailingTypeCount(); i++) { - const DescriptorTy *descriptorSoFar = - reinterpret_cast(ptr.get()); + for (size_t i = 0; i < BaseTy::trailingTypeCount(); i++) { + const BaseTy *descriptorSoFar = + reinterpret_cast(ptr.get()); size_t thisSize = descriptorSoFar->sizeWithTrailingTypeCount(i); if (thisSize > sizeSoFar) { + // Make sure we haven't ended up with a ridiculous size. + if (thisSize > MaxMetadataSize) + return false; + ptr = Reader->readBytes(address, thisSize); if (!ptr) return false; @@ -2141,45 +2151,22 @@ protected: switch (getEnumeratedMetadataKind(KindValue)) { case MetadataKind::Class: - - return _readMetadata(address); - + return _readMetadataFixedSize(address); case MetadataKind::Enum: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::ErrorObject: - return _readMetadata(address); - case MetadataKind::Existential: { - RemoteAddress flagsAddress = address + sizeof(StoredPointer); - - ExistentialTypeFlags::int_type flagsData; - if (!Reader->readInteger(flagsAddress, &flagsData)) - return nullptr; - - ExistentialTypeFlags flags(flagsData); - - RemoteAddress numProtocolsAddress = flagsAddress + sizeof(flagsData); - uint32_t numProtocols; - if (!Reader->readInteger(numProtocolsAddress, &numProtocols)) - return nullptr; - - // Make sure the number of protocols is reasonable - if (numProtocols >= 256) - return nullptr; - - auto totalSize = sizeof(TargetExistentialTypeMetadata) - + numProtocols * - sizeof(ConstTargetMetadataPointer); - - if (flags.hasSuperclassConstraint()) - totalSize += sizeof(StoredPointer); - - return _readMetadata(address, totalSize); - } + return _readMetadataFixedSize(address); + case MetadataKind::Existential: + return _readMetadataVariableSize( + address); case MetadataKind::ExistentialMetatype: - return _readMetadata(address); + return _readMetadataFixedSize( + address); case MetadataKind::ExtendedExistential: { // We need to read the shape in order to figure out how large - // the generalization arguments are. + // the generalization arguments are. This prevents us from using + // _readMetadataVariableSize, which requires the Shape field to be + // dereferenceable here. RemoteAddress shapeAddress = address + sizeof(StoredPointer); RemoteAddress signedShapePtr; if (!Reader->template readRemoteAddress(shapeAddress, @@ -2198,46 +2185,24 @@ protected: return _readMetadata(address, totalSize); } case MetadataKind::ForeignClass: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::ForeignReferenceType: - return _readMetadata(address); - case MetadataKind::Function: { - StoredSize flagsValue; - auto flagsAddr = - address + TargetFunctionTypeMetadata::OffsetToFlags; - if (!Reader->readInteger(flagsAddr, &flagsValue)) - return nullptr; - - auto flags = - TargetFunctionTypeFlags::fromIntValue(flagsValue); - - auto totalSize = - sizeof(TargetFunctionTypeMetadata) + - flags.getNumParameters() * sizeof(FunctionTypeMetadata::Parameter); - - if (flags.hasParameterFlags()) - totalSize += flags.getNumParameters() * sizeof(uint32_t); - - if (flags.isDifferentiable()) - totalSize = roundUpToAlignment(totalSize, sizeof(StoredPointer)) + - sizeof(TargetFunctionMetadataDifferentiabilityKind< - typename Runtime::StoredSize>); - - return _readMetadata(address, - roundUpToAlignment(totalSize, sizeof(StoredPointer))); - } + return _readMetadataFixedSize( + address); + case MetadataKind::Function: + return _readMetadataVariableSize(address); case MetadataKind::HeapGenericLocalVariable: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::HeapLocalVariable: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::Metatype: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::ObjCClassWrapper: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::Optional: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::Struct: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::Tuple: { auto numElementsAddress = address + TargetTupleTypeMetadata::getOffsetToNumElements(); @@ -2255,7 +2220,7 @@ protected: } case MetadataKind::Opaque: default: - return _readMetadata(address); + return _readMetadataFixedSize(address); } // We can fall out here if the value wasn't actually a valid @@ -2333,20 +2298,39 @@ protected: private: template