Merge pull request #76951 from kovdan01/issue60102

[AutoDiff] Enhance performance of custom derivatives lookup
This commit is contained in:
Doug Gregor
2024-11-11 14:55:54 -08:00
committed by GitHub
10 changed files with 160 additions and 57 deletions

View File

@@ -166,6 +166,7 @@ class swift::SourceLookupCache {
ValueDeclMap TopLevelValues;
ValueDeclMap ClassMembers;
bool MemberCachePopulated = false;
llvm::SmallVector<AbstractFunctionDecl *, 0> CustomDerivatives;
DeclName UniqueMacroNamePlaceholder;
template<typename T>
@@ -173,8 +174,9 @@ class swift::SourceLookupCache {
OperatorMap<OperatorDecl> Operators;
OperatorMap<PrecedenceGroupDecl> PrecedenceGroups;
template<typename Range>
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators);
template <typename Range>
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators,
bool onlyDerivatives);
template<typename Range>
void addToMemberCache(Range decls);
@@ -205,6 +207,10 @@ public:
/// guaranteed to be meaningful.
void getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl *> &results);
/// Retrieves all the function decls marked as @derivative. The order of the
/// results is not guaranteed to be meaningful.
llvm::SmallVector<AbstractFunctionDecl *, 0> getCustomDerivativeDecls();
/// Look up an operator declaration.
///
/// \param name The operator name ("+", ">>", etc.)
@@ -249,9 +255,10 @@ static Decl *getAsDecl(Decl *decl) { return decl; }
static Expr *getAsExpr(ASTNode node) { return node.dyn_cast<Expr *>(); }
static Decl *getAsDecl(ASTNode node) { return node.dyn_cast<Decl *>(); }
template<typename Range>
template <typename Range>
void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
bool onlyOperators) {
bool onlyOperators,
bool onlyDerivatives) {
for (auto item : items) {
// In script mode, we'll see macro expansion expressions for freestanding
// macros.
@@ -268,19 +275,36 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
continue;
if (auto *VD = dyn_cast<ValueDecl>(D)) {
if (onlyOperators ? VD->isOperator() : VD->hasName()) {
// Cache the value under both its compound name and its full name.
auto getDerivative = [onlyDerivatives, VD]() -> AbstractFunctionDecl * {
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(VD))
if (AFD->getAttrs().hasAttribute<DerivativeAttr>())
return AFD;
return nullptr;
};
if (onlyOperators && VD->isOperator())
TopLevelValues.add(VD);
if (!onlyOperators && VD->getAttrs().hasAttribute<CustomAttr>()) {
if (onlyDerivatives)
if (AbstractFunctionDecl *AFD = getDerivative())
CustomDerivatives.push_back(AFD);
if (!onlyOperators && !onlyDerivatives && VD->hasName()) {
TopLevelValues.add(VD);
if (VD->getAttrs().hasAttribute<CustomAttr>())
MayHaveAuxiliaryDecls.push_back(VD);
}
if (AbstractFunctionDecl *AFD = getDerivative())
CustomDerivatives.push_back(AFD);
}
}
if (auto *NTD = dyn_cast<NominalTypeDecl>(D))
if (!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations())
addToUnqualifiedLookupCache(NTD->getMembers(), true);
if (auto *NTD = dyn_cast<NominalTypeDecl>(D)) {
bool onlyOperatorsArg =
(!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations());
bool onlyDerivativesArg =
(!NTD->hasUnparsedMembers() || NTD->maybeHasDerivativeDeclarations());
if (onlyOperatorsArg || onlyDerivativesArg) {
addToUnqualifiedLookupCache(NTD->getMembers(), onlyOperatorsArg,
onlyDerivativesArg);
}
}
if (auto *ED = dyn_cast<ExtensionDecl>(D)) {
// Avoid populating the cache with the members of invalid extension
@@ -292,8 +316,14 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
MayHaveAuxiliaryDecls.push_back(ED);
}
if (!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations())
addToUnqualifiedLookupCache(ED->getMembers(), true);
bool onlyOperatorsArg =
(!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations());
bool onlyDerivativesArg =
(!ED->hasUnparsedMembers() || ED->maybeHasDerivativeDeclarations());
if (onlyOperatorsArg || onlyDerivativesArg) {
addToUnqualifiedLookupCache(ED->getMembers(), onlyOperatorsArg,
onlyDerivativesArg);
}
}
if (auto *OD = dyn_cast<OperatorDecl>(D))
@@ -307,7 +337,8 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
MayHaveAuxiliaryDecls.push_back(MED);
} else if (auto TLCD = dyn_cast<TopLevelCodeDecl>(D)) {
if (auto body = TLCD->getBody()){
addToUnqualifiedLookupCache(body->getElements(), onlyOperators);
addToUnqualifiedLookupCache(body->getElements(), onlyOperators,
onlyDerivatives);
}
}
}
@@ -488,8 +519,8 @@ SourceLookupCache::SourceLookupCache(const SourceFile &SF)
{
FrontendStatsTracer tracer(SF.getASTContext().Stats,
"source-file-populate-cache");
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false);
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false);
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false, false);
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false, false);
}
SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
@@ -499,11 +530,11 @@ SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
"module-populate-cache");
for (const FileUnit *file : M.getFiles()) {
auto *SF = cast<SourceFile>(file);
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false);
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false);
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false, false);
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false, false);
if (auto *SFU = file->getSynthesizedFile()) {
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false);
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false, false);
}
}
}
@@ -572,6 +603,11 @@ void SourceLookupCache::getOperatorDecls(
results.append(ops.second.begin(), ops.second.end());
}
llvm::SmallVector<AbstractFunctionDecl *, 0>
SourceLookupCache::getCustomDerivativeDecls() {
return CustomDerivatives;
}
void SourceLookupCache::lookupOperator(Identifier name, OperatorFixity fixity,
TinyPtrVector<OperatorDecl *> &results) {
auto ops = Operators.find(name);
@@ -4026,6 +4062,23 @@ bool IsNonUserModuleRequest::evaluate(Evaluator &evaluator, ModuleDecl *mod) con
return false;
}
evaluator::SideEffect CustomDerivativesRequest::evaluate(Evaluator &evaluator,
SourceFile *sf) const {
ModuleDecl *module = sf->getParentModule();
assert(isParsedModule(module));
llvm::SmallVector<AbstractFunctionDecl *, 0> decls =
module->getSourceLookupCache().getCustomDerivativeDecls();
for (const AbstractFunctionDecl *afd : decls) {
for (const auto *derAttr :
afd->getAttrs().getAttributes<DerivativeAttr>()) {
// Resolve derivative function configurations from `@derivative`
// attributes by type-checking them.
(void)derAttr->getOriginalFunction(sf->getASTContext());
}
}
return {};
}
version::Version ModuleDecl::getLanguageVersionBuiltWith() const {
for (auto *F : getFiles()) {
auto *LD = dyn_cast<LoadedFile>(F);