[AutoDiff upstream] Add @differentiable attribute serialization. (#30605)

Serialize "is linear?" flag, differentiability parameter indices, and
differentiability generic signature.

Deserialization has some ad-hoc logic for setting the original declaration and
parameter indices for `@differentiable` attributes because
`DeclDeserializer::deserializeDeclAttributes` does not have access to the
original declaration.

Resolves TF-836.
This commit is contained in:
Dan Zheng
2020-03-24 08:22:56 -07:00
committed by GitHub
parent 0873622b4b
commit a856d59623
3 changed files with 83 additions and 24 deletions

View File

@@ -2395,23 +2395,20 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
case DAK_Differentiable: {
auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
auto *attr = cast<DifferentiableAttr>(DA);
auto paramIndices = attr->getParameterIndices();
// NOTE(TF-836): `@differentiable` attribute serialization is blocked by
// `@differentiable` attribute type-checking (TF-828), which resolves
// parameter indices (`IndexSubset *`).
if (!paramIndices)
return;
assert(attr->getOriginalDeclaration() &&
"`@differentiable` attribute should have original declaration set "
"during construction or parsing");
auto *paramIndices = attr->getParameterIndices();
assert(paramIndices && "Parameter indices must be resolved");
SmallVector<bool, 4> indices;
SmallVector<bool, 4> paramIndicesVector;
for (unsigned i : range(paramIndices->getCapacity()))
indices.push_back(paramIndices->contains(i));
paramIndicesVector.push_back(paramIndices->contains(i));
DifferentiableDeclAttrLayout::emitRecord(
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(),
attr->isLinear(),
S.addGenericSignatureRef(attr->getDerivativeGenericSignature()),
indices);
paramIndicesVector);
return;
}
@@ -2428,12 +2425,12 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
auto *parameterIndices = attr->getParameterIndices();
assert(parameterIndices && "Parameter indices must be resolved");
SmallVector<bool, 4> indices;
SmallVector<bool, 4> paramIndicesVector;
for (unsigned i : range(parameterIndices->getCapacity()))
indices.push_back(parameterIndices->contains(i));
paramIndicesVector.push_back(parameterIndices->contains(i));
DerivativeDeclAttrLayout::emitRecord(
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
origDeclID, derivativeKind, indices);
origDeclID, derivativeKind, paramIndicesVector);
return;
}
@@ -2453,12 +2450,12 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction());
auto *parameterIndices = attr->getParameterIndices();
assert(parameterIndices && "Parameter indices must be resolved");
SmallVector<bool, 4> indices;
SmallVector<bool, 4> paramIndicesVector;
for (unsigned i : range(parameterIndices->getCapacity()))
indices.push_back(parameterIndices->contains(i));
paramIndicesVector.push_back(parameterIndices->contains(i));
TransposeDeclAttrLayout::emitRecord(
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
origDeclID, indices);
origDeclID, paramIndicesVector);
return;
}
}