mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[AutoDiff upstream] Add @differentiable attribute serialization. (#30605)
Serialize "is linear?" flag, differentiability parameter indices, and differentiability generic signature. Deserialization has some ad-hoc logic for setting the original declaration and parameter indices for `@differentiable` attributes because `DeclDeserializer::deserializeDeclAttributes` does not have access to the original declaration. Resolves TF-836.
This commit is contained in:
@@ -2395,23 +2395,20 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
|
||||
case DAK_Differentiable: {
|
||||
auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
|
||||
auto *attr = cast<DifferentiableAttr>(DA);
|
||||
|
||||
auto paramIndices = attr->getParameterIndices();
|
||||
// NOTE(TF-836): `@differentiable` attribute serialization is blocked by
|
||||
// `@differentiable` attribute type-checking (TF-828), which resolves
|
||||
// parameter indices (`IndexSubset *`).
|
||||
if (!paramIndices)
|
||||
return;
|
||||
assert(attr->getOriginalDeclaration() &&
|
||||
"`@differentiable` attribute should have original declaration set "
|
||||
"during construction or parsing");
|
||||
auto *paramIndices = attr->getParameterIndices();
|
||||
assert(paramIndices && "Parameter indices must be resolved");
|
||||
SmallVector<bool, 4> indices;
|
||||
SmallVector<bool, 4> paramIndicesVector;
|
||||
for (unsigned i : range(paramIndices->getCapacity()))
|
||||
indices.push_back(paramIndices->contains(i));
|
||||
paramIndicesVector.push_back(paramIndices->contains(i));
|
||||
|
||||
DifferentiableDeclAttrLayout::emitRecord(
|
||||
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(),
|
||||
attr->isLinear(),
|
||||
S.addGenericSignatureRef(attr->getDerivativeGenericSignature()),
|
||||
indices);
|
||||
paramIndicesVector);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -2428,12 +2425,12 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
|
||||
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
|
||||
auto *parameterIndices = attr->getParameterIndices();
|
||||
assert(parameterIndices && "Parameter indices must be resolved");
|
||||
SmallVector<bool, 4> indices;
|
||||
SmallVector<bool, 4> paramIndicesVector;
|
||||
for (unsigned i : range(parameterIndices->getCapacity()))
|
||||
indices.push_back(parameterIndices->contains(i));
|
||||
paramIndicesVector.push_back(parameterIndices->contains(i));
|
||||
DerivativeDeclAttrLayout::emitRecord(
|
||||
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
|
||||
origDeclID, derivativeKind, indices);
|
||||
origDeclID, derivativeKind, paramIndicesVector);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -2453,12 +2450,12 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
|
||||
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction());
|
||||
auto *parameterIndices = attr->getParameterIndices();
|
||||
assert(parameterIndices && "Parameter indices must be resolved");
|
||||
SmallVector<bool, 4> indices;
|
||||
SmallVector<bool, 4> paramIndicesVector;
|
||||
for (unsigned i : range(parameterIndices->getCapacity()))
|
||||
indices.push_back(parameterIndices->contains(i));
|
||||
paramIndicesVector.push_back(parameterIndices->contains(i));
|
||||
TransposeDeclAttrLayout::emitRecord(
|
||||
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
|
||||
origDeclID, indices);
|
||||
origDeclID, paramIndicesVector);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user