Standardize serialization of DeclNameRefs

There are four attributes which serialize out a DeclNameRef, sometimes by dropping some of its components. Standardize them with a representation that can handle module selectors.
This commit is contained in:
Becca Royal-Gordon
2025-05-10 01:53:28 -07:00
parent e71f837b1f
commit b6be85222f
6 changed files with 139 additions and 100 deletions

View File

@@ -213,12 +213,12 @@ OTHER(DERIVATIVE_FUNCTION_CONFIGURATION, 154)
OTHER(ERROR_FLAG, 155)
OTHER(ABI_ONLY_COUNTERPART, 156)
OTHER(DECL_NAME_REF, 157)
TRAILING_INFO(CONDITIONAL_SUBSTITUTION)
TRAILING_INFO(CONDITIONAL_SUBSTITUTION_COND)
OTHER(LIFETIME_DEPENDENCE, 160)
TRAILING_INFO(INHERITED_PROTOCOLS)
#ifndef DECL_ATTR

View File

@@ -3498,6 +3498,58 @@ public:
/// offsets in \c customAttrOffsets.
llvm::Error deserializeCustomAttrs();
DeclNameRef deserializeDeclNameRefIfPresent() {
using namespace decls_block;
SmallVector<uint64_t, 64> scratch;
StringRef blobData;
BCOffsetRAII restoreOffset(MF.DeclTypeCursor);
llvm::BitstreamEntry entry =
MF.fatalIfUnexpected(MF.DeclTypeCursor.advance());
unsigned recordID = MF.fatalIfUnexpected(
MF.DeclTypeCursor.readRecord(entry.ID, scratch, &blobData));
if (recordID != DECL_NAME_REF)
// This is normal--it just means there isn't a DeclNameRef here.
return { DeclNameRef() };
bool isCompoundName;
bool hasModuleSelector;
ArrayRef<uint64_t> rawPieceIDs;
DeclNameRefLayout::readRecord(scratch, isCompoundName, hasModuleSelector,
rawPieceIDs);
restoreOffset.cancel();
Identifier moduleSelector;
DeclBaseName baseName;
unsigned restIndex = 0;
ASSERT(rawPieceIDs.size() > 0);
if (hasModuleSelector) {
moduleSelector = MF.getIdentifier(rawPieceIDs[restIndex]);
restIndex++;
}
ASSERT(rawPieceIDs.size() > restIndex);
baseName = MF.getDeclBaseName(rawPieceIDs[restIndex]);
restIndex++;
if (isCompoundName) {
SmallVector<Identifier, 8> argLabels;
for (auto rawArgLabel : rawPieceIDs.drop_front(restIndex))
argLabels.push_back(MF.getIdentifier(rawArgLabel));
return DeclNameRef(ctx, moduleSelector, baseName, argLabels);
}
ASSERT(rawPieceIDs.size() == restIndex);
return DeclNameRef(ctx, moduleSelector, baseName);
}
Expected<Decl *> getDeclCheckedImpl(
llvm::function_ref<bool(DeclAttributes)> matchAttributes = nullptr);
@@ -6141,56 +6193,32 @@ llvm::Error DeclDeserializer::deserializeDeclCommon() {
unsigned specializationKindVal;
GenericSignatureID specializedSigID;
ArrayRef<uint64_t> rawPieceIDs;
uint64_t numArgs;
ArrayRef<uint64_t> rawTrailingIDs;
uint64_t numSPIGroups;
uint64_t numAvailabilityAttrs;
uint64_t numTypeErasedParams;
DeclID targetFunID;
serialization::decls_block::SpecializeDeclAttrLayout::readRecord(
scratch, exported, specializationKindVal, specializedSigID,
targetFunID, numArgs, numSPIGroups, numAvailabilityAttrs,
numTypeErasedParams, rawPieceIDs);
targetFunID, numSPIGroups, numAvailabilityAttrs, rawTrailingIDs);
assert(rawPieceIDs.size() == numArgs + numSPIGroups + numTypeErasedParams ||
rawPieceIDs.size() == (numArgs - 1 + numSPIGroups + numTypeErasedParams));
specializationKind = specializationKindVal
? SpecializeAttr::SpecializationKind::Partial
: SpecializeAttr::SpecializationKind::Full;
// The 'target' parameter.
DeclNameRef replacedFunctionName;
if (numArgs) {
bool numArgumentLabels = (numArgs == 1) ? 0 : numArgs - 2;
auto baseName = MF.getDeclBaseName(rawPieceIDs[0]);
SmallVector<Identifier, 4> pieces;
if (numArgumentLabels) {
for (auto pieceID : rawPieceIDs.slice(1, numArgumentLabels))
pieces.push_back(MF.getIdentifier(pieceID));
}
replacedFunctionName = (numArgs == 1)
? DeclNameRef({baseName}) // simple name
: DeclNameRef({ctx, baseName, pieces});
}
auto specializedSig = MF.getGenericSignature(specializedSigID);
// Take `numSPIGroups` trailing identifiers for the SPI groups.
SmallVector<Identifier, 4> spis;
if (numSPIGroups) {
auto numTargetFunctionPiecesToSkip =
(rawPieceIDs.size() == numArgs + numSPIGroups + numTypeErasedParams) ? numArgs
: numArgs - 1;
for (auto id : rawPieceIDs.slice(numTargetFunctionPiecesToSkip))
for (auto id : rawTrailingIDs.take_front(numSPIGroups))
spis.push_back(MF.getIdentifier(id));
}
// Take the rest for type-erased parameters.
SmallVector<Type, 4> typeErasedParams;
if (numTypeErasedParams) {
auto numTargetFunctionPiecesToSkip =
(rawPieceIDs.size() == numArgs + numSPIGroups + numTypeErasedParams) ? numArgs + numSPIGroups
: numArgs - 1 + numSPIGroups;
for (auto id : rawPieceIDs.slice(numTargetFunctionPiecesToSkip))
typeErasedParams.push_back(MF.getType(id));
}
for (auto id : rawTrailingIDs.drop_front(numSPIGroups))
typeErasedParams.push_back(MF.getType(id));
// Read availability attrs.
SmallVector<AvailableAttr *, 4> availabilityAttrs;
while (numAvailabilityAttrs) {
// Prepare to read the next record.
@@ -6220,10 +6248,12 @@ llvm::Error DeclDeserializer::deserializeDeclCommon() {
--numAvailabilityAttrs;
}
auto specializedSig = MF.getGenericSignature(specializedSigID);
// Read target function DeclNameRef, if present.
DeclNameRef targetFunName = deserializeDeclNameRefIfPresent();
Attr = SpecializeAttr::create(ctx, exported != 0, specializationKind,
spis, availabilityAttrs, typeErasedParams,
specializedSig, replacedFunctionName, &MF,
specializedSig, targetFunName, &MF,
targetFunID);
break;
}
@@ -6254,21 +6284,15 @@ llvm::Error DeclDeserializer::deserializeDeclCommon() {
case decls_block::DynamicReplacement_DECL_ATTR: {
bool isImplicit;
uint64_t numArgs;
ArrayRef<uint64_t> rawPieceIDs;
DeclID replacedFunID;
serialization::decls_block::DynamicReplacementDeclAttrLayout::
readRecord(scratch, isImplicit, replacedFunID, numArgs, rawPieceIDs);
readRecord(scratch, isImplicit, replacedFunID);
auto baseName = MF.getDeclBaseName(rawPieceIDs[0]);
SmallVector<Identifier, 4> pieces;
for (auto pieceID : rawPieceIDs.slice(1))
pieces.push_back(MF.getIdentifier(pieceID));
DeclNameRef replacedFunName = deserializeDeclNameRefIfPresent();
assert(numArgs != 0);
assert(!isImplicit && "Need to update for implicit");
Attr = DynamicReplacementAttr::create(
ctx, DeclNameRef({ ctx, baseName, pieces }), &MF, replacedFunID);
Attr = DynamicReplacementAttr::create(ctx, replacedFunName, &MF,
replacedFunID);
break;
}
@@ -6339,7 +6363,6 @@ llvm::Error DeclDeserializer::deserializeDeclCommon() {
case decls_block::Derivative_DECL_ATTR: {
bool isImplicit;
uint64_t origNameId;
bool hasAccessorKind;
uint64_t rawAccessorKind;
DeclID origDeclId;
@@ -6347,7 +6370,7 @@ llvm::Error DeclDeserializer::deserializeDeclCommon() {
ArrayRef<uint64_t> parameters;
serialization::decls_block::DerivativeDeclAttrLayout::readRecord(
scratch, isImplicit, origNameId, hasAccessorKind, rawAccessorKind,
scratch, isImplicit, hasAccessorKind, rawAccessorKind,
origDeclId, rawDerivativeKind, parameters);
std::optional<AccessorKind> accessorKind = std::nullopt;
@@ -6358,8 +6381,6 @@ llvm::Error DeclDeserializer::deserializeDeclCommon() {
accessorKind = *maybeAccessorKind;
}
DeclNameRefWithLoc origName{DeclNameRef(MF.getDeclBaseName(origNameId)),
DeclNameLoc(), accessorKind};
auto derivativeKind =
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
if (!derivativeKind)
@@ -6369,9 +6390,14 @@ llvm::Error DeclDeserializer::deserializeDeclCommon() {
parametersBitVector[i] = parameters[i];
auto *indices = IndexSubset::get(ctx, parametersBitVector);
auto origName = deserializeDeclNameRefIfPresent();
DeclNameRefWithLoc origNameWithLoc{origName, DeclNameLoc(),
accessorKind};
auto *derivativeAttr =
DerivativeAttr::create(ctx, isImplicit, SourceLoc(), SourceRange(),
/*baseType*/ nullptr, origName, indices);
/*baseType*/ nullptr, origNameWithLoc,
indices);
derivativeAttr->setOriginalFunctionResolver(&MF, origDeclId);
derivativeAttr->setDerivativeKind(*derivativeKind);
Attr = derivativeAttr;
@@ -6380,20 +6406,21 @@ llvm::Error DeclDeserializer::deserializeDeclCommon() {
case decls_block::Transpose_DECL_ATTR: {
bool isImplicit;
uint64_t origNameId;
DeclID origDeclId;
ArrayRef<uint64_t> parameters;
serialization::decls_block::TransposeDeclAttrLayout::readRecord(
scratch, isImplicit, origNameId, origDeclId, parameters);
scratch, isImplicit, origDeclId, parameters);
DeclNameRefWithLoc origName{DeclNameRef(MF.getDeclBaseName(origNameId)),
DeclNameLoc(), std::nullopt};
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
llvm::SmallBitVector parametersBitVector(parameters.size());
for (unsigned i : indices(parameters))
parametersBitVector[i] = parameters[i];
auto *indices = IndexSubset::get(ctx, parametersBitVector);
auto origNameRef = deserializeDeclNameRefIfPresent();
DeclNameRefWithLoc origName{origNameRef, DeclNameLoc(), std::nullopt};
auto *transposeAttr =
TransposeAttr::create(ctx, isImplicit, SourceLoc(), SourceRange(),
/*baseTypeRepr*/ nullptr, origName, indices);

View File

@@ -58,7 +58,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
const uint16_t SWIFTMODULE_VERSION_MINOR = 949; // vector_base_addr instruction
const uint16_t SWIFTMODULE_VERSION_MINOR = 950; // DeclNameRef module selectors
/// A standard hash seed used for all string hashes in a serialized module.
///
@@ -1258,6 +1258,15 @@ namespace decls_block {
DeclIDField // API decl
>;
/// A field containing the pieces of a \c DeclNameRef and the information
/// needed to reconstruct it.
using DeclNameRefLayout = BCRecordLayout<
DECL_NAME_REF,
BCFixed<1>, // isCompoundName
BCFixed<1>, // hasModuleSelector
BCArray<IdentifierIDField> // module selector, base name, arg labels
>;
/// A placeholder for invalid types
TYPE_LAYOUT(ErrorTypeLayout,
ERROR_TYPE,
@@ -2444,11 +2453,9 @@ namespace decls_block {
BCFixed<1>, // specialization kind
GenericSignatureIDField, // specialized signature
DeclIDField, // target function
BCVBR<4>, // # of arguments (+1) or 1 if simple decl name, 0 if no target
BCVBR<4>, // # of SPI groups
BCVBR<4>, // # of availability attributes
BCVBR<4>, // # of type erased parameters
BCArray<IdentifierIDField> // target function pieces, spi groups, type erased params
BCArray<IdentifierIDField> // spi groups, type erased params
>;
using StorageRestrictionsDeclAttrLayout = BCRecordLayout<
@@ -2468,7 +2475,6 @@ namespace decls_block {
using DerivativeDeclAttrLayout = BCRecordLayout<
Derivative_DECL_ATTR,
BCFixed<1>, // Implicit flag.
IdentifierIDField, // Original name.
BCFixed<1>, // Has original accessor kind?
AccessorKindField, // Original accessor kind.
DeclIDField, // Original function declaration.
@@ -2479,7 +2485,6 @@ namespace decls_block {
using TransposeDeclAttrLayout = BCRecordLayout<
Transpose_DECL_ATTR,
BCFixed<1>, // Implicit flag.
IdentifierIDField, // Original name.
DeclIDField, // Original function declaration.
BCArray<BCFixed<1>> // Transposed parameter indices' bitvector.
>;
@@ -2494,9 +2499,7 @@ namespace decls_block {
using DynamicReplacementDeclAttrLayout = BCRecordLayout<
DynamicReplacement_DECL_ATTR,
BCFixed<1>, // implicit flag
DeclIDField, // replaced function
BCVBR<4>, // # of arguments (+1) or zero if no name
BCArray<IdentifierIDField>
DeclIDField // replaced function
>;
using TypeEraserDeclAttrLayout = BCRecordLayout<

View File

@@ -3204,49 +3204,34 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
case DeclAttrKind::Specialize: {
auto abbrCode = S.DeclTypeAbbrCodes[SpecializeDeclAttrLayout::Code];
auto attr = cast<SpecializeAttr>(DA);
auto targetFun = attr->getTargetFunctionName();
auto *afd = cast<AbstractFunctionDecl>(D);
auto *targetFunDecl = attr->getTargetFunctionDecl(afd);
SmallVector<IdentifierID, 4> pieces;
// encodes whether this a a simple or compound name by adding one.
size_t numArgs = 0;
if (targetFun) {
pieces.push_back(S.addDeclBaseNameRef(targetFun.getBaseName()));
for (auto argName : targetFun.getArgumentNames())
pieces.push_back(S.addDeclBaseNameRef(argName));
if (targetFun.isSimpleName()) {
assert(pieces.size() == 1);
numArgs = 1;
} else
numArgs = pieces.size() + 1;
}
// SPI groups
auto numSPIGroups = attr->getSPIGroups().size();
for (auto spi : attr->getSPIGroups()) {
assert(!spi.empty() && "Empty SPI name");
pieces.push_back(S.addDeclBaseNameRef(spi));
}
// Type-erased params
for (auto ty : attr->getTypeErasedParams()) {
pieces.push_back(S.addTypeRef(ty));
}
auto numSPIGroups = attr->getSPIGroups().size();
auto numTypeErasedParams = attr->getTypeErasedParams().size();
assert(pieces.size() == numArgs + numSPIGroups + numTypeErasedParams ||
pieces.size() == (numArgs - 1 + numSPIGroups + numTypeErasedParams));
auto numAvailabilityAttrs = attr->getAvailableAttrs().size();
SpecializeDeclAttrLayout::emitRecord(
S.Out, S.ScratchRecord, abbrCode, (unsigned)attr->isExported(),
(unsigned)attr->getSpecializationKind(),
S.addGenericSignatureRef(attr->getSpecializedSignature(afd)),
S.addDeclRef(targetFunDecl), numArgs, numSPIGroups,
numAvailabilityAttrs, numTypeErasedParams,
S.addDeclRef(targetFunDecl), numSPIGroups, numAvailabilityAttrs,
pieces);
for (auto availAttr : attr->getAvailableAttrs()) {
writeDeclAttribute(D, availAttr);
}
writeDeclNameRefIfNeeded(attr->getTargetFunctionName());
return;
}
@@ -3278,16 +3263,13 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
auto abbrCode =
S.DeclTypeAbbrCodes[DynamicReplacementDeclAttrLayout::Code];
auto theAttr = cast<DynamicReplacementAttr>(DA);
auto replacedFun = theAttr->getReplacedFunctionName();
SmallVector<IdentifierID, 4> pieces;
pieces.push_back(S.addDeclBaseNameRef(replacedFun.getBaseName()));
for (auto argName : replacedFun.getArgumentNames())
pieces.push_back(S.addDeclBaseNameRef(argName));
auto *afd = cast<ValueDecl>(D)->getDynamicallyReplacedDecl();
assert(afd && "Missing replaced decl!");
DynamicReplacementDeclAttrLayout::emitRecord(
S.Out, S.ScratchRecord, abbrCode, false, /*implicit flag*/
S.addDeclRef(afd), pieces.size(), pieces);
S.addDeclRef(afd));
writeDeclNameRefIfNeeded(theAttr->getReplacedFunctionName());
return;
}
@@ -3360,8 +3342,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
"`@derivative` attribute should have original declaration set "
"during construction or parsing");
auto origDeclNameRef = attr->getOriginalFunctionName();
auto origName = origDeclNameRef.Name.getBaseName();
IdentifierID origNameId = S.addDeclBaseNameRef(origName);
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction(ctx));
auto derivativeKind =
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
@@ -3375,9 +3356,10 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
for (unsigned i : range(parameterIndices->getCapacity()))
paramIndicesVector.push_back(parameterIndices->contains(i));
DerivativeDeclAttrLayout::emitRecord(
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(),
origAccessorKind.has_value(), rawAccessorKind, origDeclID,
derivativeKind, paramIndicesVector);
writeDeclNameRefIfNeeded(origDeclNameRef.Name);
return;
}
@@ -3387,8 +3369,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
assert(attr->getOriginalFunction() &&
"`@transpose` attribute should have original declaration set "
"during construction or parsing");
auto origName = attr->getOriginalFunctionName().Name.getBaseName();
IdentifierID origNameId = S.addDeclBaseNameRef(origName);
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction());
auto *parameterIndices = attr->getParameterIndices();
assert(parameterIndices && "Parameter indices must be resolved");
@@ -3396,8 +3377,9 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
for (unsigned i : range(parameterIndices->getCapacity()))
paramIndicesVector.push_back(parameterIndices->contains(i));
TransposeDeclAttrLayout::emitRecord(
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
origDeclID, paramIndicesVector);
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origDeclID,
paramIndicesVector);
writeDeclNameRefIfNeeded(attr->getOriginalFunctionName().Name);
return;
}
@@ -3621,6 +3603,32 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
}
}
void writeDeclNameRefIfNeeded(DeclNameRef name) {
using namespace decls_block;
// DeclNameRefs are always optional and write nothing when absent.
if (!name)
return;
bool isCompoundName = name.isCompoundName();
bool hasModuleSelector = name.hasModuleSelector();
SmallVector<IdentifierID, 8> rawPieceIDs;
if (hasModuleSelector)
rawPieceIDs.push_back(S.addDeclBaseNameRef(name.getModuleSelector()));
rawPieceIDs.push_back(S.addDeclBaseNameRef(name.getBaseName()));
if (isCompoundName)
for (auto argName : name.getArgumentNames())
rawPieceIDs.push_back(S.addDeclBaseNameRef(argName));
auto abbrCode = S.DeclTypeAbbrCodes[DeclNameRefLayout::Code];
DeclNameRefLayout::emitRecord(S.Out, S.ScratchRecord, abbrCode,
isCompoundName, hasModuleSelector,
rawPieceIDs);
}
size_t addConformances(const IterableDeclContext *declContext,
ConformanceLookupKind lookupKind,
SmallVectorImpl<uint64_t> &data) {
@@ -6337,6 +6345,7 @@ void Serializer::writeAllDeclsAndTypes() {
registerDeclTypeAbbr<ErrorFlagLayout>();
registerDeclTypeAbbr<ErrorTypeLayout>();
registerDeclTypeAbbr<ABIOnlyCounterpartLayout>();
registerDeclTypeAbbr<DeclNameRefLayout>();
registerDeclTypeAbbr<ClangTypeLayout>();

View File

@@ -141,7 +141,7 @@ extension S {
self
}
// CHECK: @derivative(of: subscript, wrt: self)
// CHECK: @derivative(of: subscript(_:), wrt: self)
@derivative(of: subscript(_:), wrt: self)
func derivativeSubscript<T: Differentiable>(x: T) -> (value: S, differential: (S) -> S) {
(self, { $0 })

View File

@@ -89,7 +89,7 @@ extension S {
extension S {
subscript<T: Differentiable>(x: T) -> Self { self }
// CHECK: @transpose(of: subscript, wrt: self)
// CHECK: @transpose(of: subscript(_:), wrt: self)
@transpose(of: subscript(_:), wrt: self)
static func transposeSubscript<T: Differentiable>(x: T, t: Self) -> Self {
t