[AutoDiff upstream] Serialize derivative function configurations. (#30672)

Serialize derivative function configurations per module.

`@differentiable` and `@derivative` attributes register derivatives for
`AbstractFunctionDecl`s for a particular "derivative function configuration":
parameter indices and dervative generic signature.

To find `@derivative` functions registered in other Swift modules, derivative
function configurations must be serialized per module. When configurations for
a `AbstractFunctionDecl` are requested, all configurations from imported
modules are deserialized. This module serialization technique has precedent: it
is used for protocol conformances (e.g. extension declarations for a nominal
type) and Obj-C members for a class type.

Add `AbstractFunctionDecl::getDerivativeFunctionConfigurations` entry point
for accessing derivative function configurations.

In the differentiation transform: use
`AbstractFunctionDecl::getDerivativeFunctionConfigurations` to implement
`findMinimalDerivativeConfiguration` for canonical derivative function
configuration lookup, replacing `getMinimalASTDifferentiableAttr`.

Resolves TF-1100.
This commit is contained in:
Dan Zheng
2020-03-27 06:40:27 -07:00
committed by GitHub
parent bbe86e908d
commit 28315487dc
15 changed files with 346 additions and 13 deletions

View File

@@ -910,6 +910,66 @@ ModuleFile::readObjCMethodTable(ArrayRef<uint64_t> fields, StringRef blobData) {
base + sizeof(uint32_t), base));
}
/// Used to deserialize entries in the on-disk derivative function configuration
/// table.
class ModuleFile::DerivativeFunctionConfigTableInfo {
public:
using internal_key_type = StringRef;
using external_key_type = internal_key_type;
using data_type = SmallVector<std::pair<std::string, GenericSignatureID>, 8>;
using hash_value_type = uint32_t;
using offset_type = unsigned;
external_key_type GetExternalKey(internal_key_type ID) { return ID; }
internal_key_type GetInternalKey(external_key_type ID) { return ID; }
hash_value_type ComputeHash(internal_key_type key) {
return llvm::djbHash(key, SWIFTMODULE_HASH_SEED);
}
static bool EqualKey(internal_key_type lhs, internal_key_type rhs) {
return lhs == rhs;
}
static std::pair<unsigned, unsigned> ReadKeyDataLength(const uint8_t *&data) {
unsigned keyLength = endian::readNext<uint16_t, little, unaligned>(data);
unsigned dataLength = endian::readNext<uint16_t, little, unaligned>(data);
return {keyLength, dataLength};
}
static internal_key_type ReadKey(const uint8_t *data, unsigned length) {
return StringRef(reinterpret_cast<const char *>(data), length);
}
static data_type ReadData(internal_key_type key, const uint8_t *data,
unsigned length) {
data_type result;
const uint8_t *limit = data + length;
while (data < limit) {
DeclID genSigId = endian::readNext<uint32_t, little, unaligned>(data);
int32_t nameLength = endian::readNext<int32_t, little, unaligned>(data);
StringRef mangledName(reinterpret_cast<const char *>(data), nameLength);
data += nameLength;
result.push_back({mangledName, genSigId});
}
return result;
}
};
std::unique_ptr<ModuleFile::SerializedDerivativeFunctionConfigTable>
ModuleFile::readDerivativeFunctionConfigTable(ArrayRef<uint64_t> fields,
StringRef blobData) {
uint32_t tableOffset;
index_block::DerivativeFunctionConfigTableLayout::readRecord(fields,
tableOffset);
auto base = reinterpret_cast<const uint8_t *>(blobData.data());
using OwnedTable = std::unique_ptr<SerializedDerivativeFunctionConfigTable>;
return OwnedTable(SerializedDerivativeFunctionConfigTable::Create(
base + tableOffset, base + sizeof(uint32_t), base));
}
bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) {
if (llvm::Error Err = cursor.EnterSubBlock(INDEX_BLOCK_ID)) {
// FIXME this drops the error on the floor.
@@ -1015,6 +1075,10 @@ bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) {
case index_block::OBJC_METHODS:
ObjCMethods = readObjCMethodTable(scratch, blobData);
break;
case index_block::DERIVATIVE_FUNCTION_CONFIGURATIONS:
DerivativeFunctionConfigurations =
readDerivativeFunctionConfigTable(scratch, blobData);
break;
case index_block::ENTRY_POINT:
assert(blobData.empty());
setEntryPointClassID(scratch.front());
@@ -2405,6 +2469,34 @@ void ModuleFile::loadObjCMethods(
}
}
void ModuleFile::loadDerivativeFunctionConfigurations(
AbstractFunctionDecl *originalAFD,
llvm::SetVector<AutoDiffConfig> &results) {
if (!DerivativeFunctionConfigurations)
return;
auto &ctx = originalAFD->getASTContext();
Mangle::ASTMangler Mangler;
auto mangledName = Mangler.mangleDeclAsUSR(originalAFD, "");
auto configs = DerivativeFunctionConfigurations->find(mangledName);
if (configs == DerivativeFunctionConfigurations->end())
return;
for (auto entry : *configs) {
auto *parameterIndices = IndexSubset::getFromString(ctx, entry.first);
auto derivativeGenSigOrError = getGenericSignatureChecked(entry.second);
if (!derivativeGenSigOrError) {
if (!getContext().LangOpts.EnableDeserializationRecovery)
fatal(derivativeGenSigOrError.takeError());
llvm::consumeError(derivativeGenSigOrError.takeError());
}
auto derivativeGenSig = derivativeGenSigOrError.get();
// NOTE(TF-1038): Result indices are currently unsupported in derivative
// registration attributes. In the meantime, always use `{0}` (wrt the
// first and only result).
auto resultIndices = IndexSubset::get(ctx, 1, {0});
results.insert({parameterIndices, resultIndices, derivativeGenSig});
}
}
TinyPtrVector<ValueDecl *>
ModuleFile::loadNamedMembers(const IterableDeclContext *IDC, DeclBaseName N,
uint64_t contextData) {