mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[AutoDiff upstream] Serialize derivative function configurations. (#30672)
Serialize derivative function configurations per module. `@differentiable` and `@derivative` attributes register derivatives for `AbstractFunctionDecl`s for a particular "derivative function configuration": parameter indices and dervative generic signature. To find `@derivative` functions registered in other Swift modules, derivative function configurations must be serialized per module. When configurations for a `AbstractFunctionDecl` are requested, all configurations from imported modules are deserialized. This module serialization technique has precedent: it is used for protocol conformances (e.g. extension declarations for a nominal type) and Obj-C members for a class type. Add `AbstractFunctionDecl::getDerivativeFunctionConfigurations` entry point for accessing derivative function configurations. In the differentiation transform: use `AbstractFunctionDecl::getDerivativeFunctionConfigurations` to implement `findMinimalDerivativeConfiguration` for canonical derivative function configuration lookup, replacing `getMinimalASTDifferentiableAttr`. Resolves TF-1100.
This commit is contained in:
@@ -910,6 +910,66 @@ ModuleFile::readObjCMethodTable(ArrayRef<uint64_t> fields, StringRef blobData) {
|
||||
base + sizeof(uint32_t), base));
|
||||
}
|
||||
|
||||
/// Used to deserialize entries in the on-disk derivative function configuration
|
||||
/// table.
|
||||
class ModuleFile::DerivativeFunctionConfigTableInfo {
|
||||
public:
|
||||
using internal_key_type = StringRef;
|
||||
using external_key_type = internal_key_type;
|
||||
using data_type = SmallVector<std::pair<std::string, GenericSignatureID>, 8>;
|
||||
using hash_value_type = uint32_t;
|
||||
using offset_type = unsigned;
|
||||
|
||||
external_key_type GetExternalKey(internal_key_type ID) { return ID; }
|
||||
|
||||
internal_key_type GetInternalKey(external_key_type ID) { return ID; }
|
||||
|
||||
hash_value_type ComputeHash(internal_key_type key) {
|
||||
return llvm::djbHash(key, SWIFTMODULE_HASH_SEED);
|
||||
}
|
||||
|
||||
static bool EqualKey(internal_key_type lhs, internal_key_type rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
|
||||
static std::pair<unsigned, unsigned> ReadKeyDataLength(const uint8_t *&data) {
|
||||
unsigned keyLength = endian::readNext<uint16_t, little, unaligned>(data);
|
||||
unsigned dataLength = endian::readNext<uint16_t, little, unaligned>(data);
|
||||
return {keyLength, dataLength};
|
||||
}
|
||||
|
||||
static internal_key_type ReadKey(const uint8_t *data, unsigned length) {
|
||||
return StringRef(reinterpret_cast<const char *>(data), length);
|
||||
}
|
||||
|
||||
static data_type ReadData(internal_key_type key, const uint8_t *data,
|
||||
unsigned length) {
|
||||
data_type result;
|
||||
const uint8_t *limit = data + length;
|
||||
while (data < limit) {
|
||||
DeclID genSigId = endian::readNext<uint32_t, little, unaligned>(data);
|
||||
int32_t nameLength = endian::readNext<int32_t, little, unaligned>(data);
|
||||
StringRef mangledName(reinterpret_cast<const char *>(data), nameLength);
|
||||
data += nameLength;
|
||||
result.push_back({mangledName, genSigId});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<ModuleFile::SerializedDerivativeFunctionConfigTable>
|
||||
ModuleFile::readDerivativeFunctionConfigTable(ArrayRef<uint64_t> fields,
|
||||
StringRef blobData) {
|
||||
uint32_t tableOffset;
|
||||
index_block::DerivativeFunctionConfigTableLayout::readRecord(fields,
|
||||
tableOffset);
|
||||
auto base = reinterpret_cast<const uint8_t *>(blobData.data());
|
||||
|
||||
using OwnedTable = std::unique_ptr<SerializedDerivativeFunctionConfigTable>;
|
||||
return OwnedTable(SerializedDerivativeFunctionConfigTable::Create(
|
||||
base + tableOffset, base + sizeof(uint32_t), base));
|
||||
}
|
||||
|
||||
bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) {
|
||||
if (llvm::Error Err = cursor.EnterSubBlock(INDEX_BLOCK_ID)) {
|
||||
// FIXME this drops the error on the floor.
|
||||
@@ -1015,6 +1075,10 @@ bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) {
|
||||
case index_block::OBJC_METHODS:
|
||||
ObjCMethods = readObjCMethodTable(scratch, blobData);
|
||||
break;
|
||||
case index_block::DERIVATIVE_FUNCTION_CONFIGURATIONS:
|
||||
DerivativeFunctionConfigurations =
|
||||
readDerivativeFunctionConfigTable(scratch, blobData);
|
||||
break;
|
||||
case index_block::ENTRY_POINT:
|
||||
assert(blobData.empty());
|
||||
setEntryPointClassID(scratch.front());
|
||||
@@ -2405,6 +2469,34 @@ void ModuleFile::loadObjCMethods(
|
||||
}
|
||||
}
|
||||
|
||||
void ModuleFile::loadDerivativeFunctionConfigurations(
|
||||
AbstractFunctionDecl *originalAFD,
|
||||
llvm::SetVector<AutoDiffConfig> &results) {
|
||||
if (!DerivativeFunctionConfigurations)
|
||||
return;
|
||||
auto &ctx = originalAFD->getASTContext();
|
||||
Mangle::ASTMangler Mangler;
|
||||
auto mangledName = Mangler.mangleDeclAsUSR(originalAFD, "");
|
||||
auto configs = DerivativeFunctionConfigurations->find(mangledName);
|
||||
if (configs == DerivativeFunctionConfigurations->end())
|
||||
return;
|
||||
for (auto entry : *configs) {
|
||||
auto *parameterIndices = IndexSubset::getFromString(ctx, entry.first);
|
||||
auto derivativeGenSigOrError = getGenericSignatureChecked(entry.second);
|
||||
if (!derivativeGenSigOrError) {
|
||||
if (!getContext().LangOpts.EnableDeserializationRecovery)
|
||||
fatal(derivativeGenSigOrError.takeError());
|
||||
llvm::consumeError(derivativeGenSigOrError.takeError());
|
||||
}
|
||||
auto derivativeGenSig = derivativeGenSigOrError.get();
|
||||
// NOTE(TF-1038): Result indices are currently unsupported in derivative
|
||||
// registration attributes. In the meantime, always use `{0}` (wrt the
|
||||
// first and only result).
|
||||
auto resultIndices = IndexSubset::get(ctx, 1, {0});
|
||||
results.insert({parameterIndices, resultIndices, derivativeGenSig});
|
||||
}
|
||||
}
|
||||
|
||||
TinyPtrVector<ValueDecl *>
|
||||
ModuleFile::loadNamedMembers(const IterableDeclContext *IDC, DeclBaseName N,
|
||||
uint64_t contextData) {
|
||||
|
||||
Reference in New Issue
Block a user