mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Merge pull request #76951 from kovdan01/issue60102
[AutoDiff] Enhance performance of custom derivatives lookup
This commit is contained in:
@@ -166,6 +166,7 @@ class swift::SourceLookupCache {
|
||||
ValueDeclMap TopLevelValues;
|
||||
ValueDeclMap ClassMembers;
|
||||
bool MemberCachePopulated = false;
|
||||
llvm::SmallVector<AbstractFunctionDecl *, 0> CustomDerivatives;
|
||||
DeclName UniqueMacroNamePlaceholder;
|
||||
|
||||
template<typename T>
|
||||
@@ -173,8 +174,9 @@ class swift::SourceLookupCache {
|
||||
OperatorMap<OperatorDecl> Operators;
|
||||
OperatorMap<PrecedenceGroupDecl> PrecedenceGroups;
|
||||
|
||||
template<typename Range>
|
||||
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators);
|
||||
template <typename Range>
|
||||
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators,
|
||||
bool onlyDerivatives);
|
||||
template<typename Range>
|
||||
void addToMemberCache(Range decls);
|
||||
|
||||
@@ -205,6 +207,10 @@ public:
|
||||
/// guaranteed to be meaningful.
|
||||
void getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl *> &results);
|
||||
|
||||
/// Retrieves all the function decls marked as @derivative. The order of the
|
||||
/// results is not guaranteed to be meaningful.
|
||||
llvm::SmallVector<AbstractFunctionDecl *, 0> getCustomDerivativeDecls();
|
||||
|
||||
/// Look up an operator declaration.
|
||||
///
|
||||
/// \param name The operator name ("+", ">>", etc.)
|
||||
@@ -249,9 +255,10 @@ static Decl *getAsDecl(Decl *decl) { return decl; }
|
||||
static Expr *getAsExpr(ASTNode node) { return node.dyn_cast<Expr *>(); }
|
||||
static Decl *getAsDecl(ASTNode node) { return node.dyn_cast<Decl *>(); }
|
||||
|
||||
template<typename Range>
|
||||
template <typename Range>
|
||||
void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
|
||||
bool onlyOperators) {
|
||||
bool onlyOperators,
|
||||
bool onlyDerivatives) {
|
||||
for (auto item : items) {
|
||||
// In script mode, we'll see macro expansion expressions for freestanding
|
||||
// macros.
|
||||
@@ -268,19 +275,36 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
|
||||
continue;
|
||||
|
||||
if (auto *VD = dyn_cast<ValueDecl>(D)) {
|
||||
if (onlyOperators ? VD->isOperator() : VD->hasName()) {
|
||||
// Cache the value under both its compound name and its full name.
|
||||
auto getDerivative = [onlyDerivatives, VD]() -> AbstractFunctionDecl * {
|
||||
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(VD))
|
||||
if (AFD->getAttrs().hasAttribute<DerivativeAttr>())
|
||||
return AFD;
|
||||
return nullptr;
|
||||
};
|
||||
if (onlyOperators && VD->isOperator())
|
||||
TopLevelValues.add(VD);
|
||||
|
||||
if (!onlyOperators && VD->getAttrs().hasAttribute<CustomAttr>()) {
|
||||
if (onlyDerivatives)
|
||||
if (AbstractFunctionDecl *AFD = getDerivative())
|
||||
CustomDerivatives.push_back(AFD);
|
||||
if (!onlyOperators && !onlyDerivatives && VD->hasName()) {
|
||||
TopLevelValues.add(VD);
|
||||
if (VD->getAttrs().hasAttribute<CustomAttr>())
|
||||
MayHaveAuxiliaryDecls.push_back(VD);
|
||||
}
|
||||
if (AbstractFunctionDecl *AFD = getDerivative())
|
||||
CustomDerivatives.push_back(AFD);
|
||||
}
|
||||
}
|
||||
|
||||
if (auto *NTD = dyn_cast<NominalTypeDecl>(D))
|
||||
if (!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations())
|
||||
addToUnqualifiedLookupCache(NTD->getMembers(), true);
|
||||
if (auto *NTD = dyn_cast<NominalTypeDecl>(D)) {
|
||||
bool onlyOperatorsArg =
|
||||
(!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations());
|
||||
bool onlyDerivativesArg =
|
||||
(!NTD->hasUnparsedMembers() || NTD->maybeHasDerivativeDeclarations());
|
||||
if (onlyOperatorsArg || onlyDerivativesArg) {
|
||||
addToUnqualifiedLookupCache(NTD->getMembers(), onlyOperatorsArg,
|
||||
onlyDerivativesArg);
|
||||
}
|
||||
}
|
||||
|
||||
if (auto *ED = dyn_cast<ExtensionDecl>(D)) {
|
||||
// Avoid populating the cache with the members of invalid extension
|
||||
@@ -292,8 +316,14 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
|
||||
MayHaveAuxiliaryDecls.push_back(ED);
|
||||
}
|
||||
|
||||
if (!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations())
|
||||
addToUnqualifiedLookupCache(ED->getMembers(), true);
|
||||
bool onlyOperatorsArg =
|
||||
(!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations());
|
||||
bool onlyDerivativesArg =
|
||||
(!ED->hasUnparsedMembers() || ED->maybeHasDerivativeDeclarations());
|
||||
if (onlyOperatorsArg || onlyDerivativesArg) {
|
||||
addToUnqualifiedLookupCache(ED->getMembers(), onlyOperatorsArg,
|
||||
onlyDerivativesArg);
|
||||
}
|
||||
}
|
||||
|
||||
if (auto *OD = dyn_cast<OperatorDecl>(D))
|
||||
@@ -307,7 +337,8 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
|
||||
MayHaveAuxiliaryDecls.push_back(MED);
|
||||
} else if (auto TLCD = dyn_cast<TopLevelCodeDecl>(D)) {
|
||||
if (auto body = TLCD->getBody()){
|
||||
addToUnqualifiedLookupCache(body->getElements(), onlyOperators);
|
||||
addToUnqualifiedLookupCache(body->getElements(), onlyOperators,
|
||||
onlyDerivatives);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -488,8 +519,8 @@ SourceLookupCache::SourceLookupCache(const SourceFile &SF)
|
||||
{
|
||||
FrontendStatsTracer tracer(SF.getASTContext().Stats,
|
||||
"source-file-populate-cache");
|
||||
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false);
|
||||
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false);
|
||||
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false, false);
|
||||
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false, false);
|
||||
}
|
||||
|
||||
SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
|
||||
@@ -499,11 +530,11 @@ SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
|
||||
"module-populate-cache");
|
||||
for (const FileUnit *file : M.getFiles()) {
|
||||
auto *SF = cast<SourceFile>(file);
|
||||
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false);
|
||||
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false);
|
||||
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false, false);
|
||||
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false, false);
|
||||
|
||||
if (auto *SFU = file->getSynthesizedFile()) {
|
||||
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false);
|
||||
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -572,6 +603,11 @@ void SourceLookupCache::getOperatorDecls(
|
||||
results.append(ops.second.begin(), ops.second.end());
|
||||
}
|
||||
|
||||
llvm::SmallVector<AbstractFunctionDecl *, 0>
|
||||
SourceLookupCache::getCustomDerivativeDecls() {
|
||||
return CustomDerivatives;
|
||||
}
|
||||
|
||||
void SourceLookupCache::lookupOperator(Identifier name, OperatorFixity fixity,
|
||||
TinyPtrVector<OperatorDecl *> &results) {
|
||||
auto ops = Operators.find(name);
|
||||
@@ -4026,6 +4062,23 @@ bool IsNonUserModuleRequest::evaluate(Evaluator &evaluator, ModuleDecl *mod) con
|
||||
return false;
|
||||
}
|
||||
|
||||
evaluator::SideEffect CustomDerivativesRequest::evaluate(Evaluator &evaluator,
|
||||
SourceFile *sf) const {
|
||||
ModuleDecl *module = sf->getParentModule();
|
||||
assert(isParsedModule(module));
|
||||
llvm::SmallVector<AbstractFunctionDecl *, 0> decls =
|
||||
module->getSourceLookupCache().getCustomDerivativeDecls();
|
||||
for (const AbstractFunctionDecl *afd : decls) {
|
||||
for (const auto *derAttr :
|
||||
afd->getAttrs().getAttributes<DerivativeAttr>()) {
|
||||
// Resolve derivative function configurations from `@derivative`
|
||||
// attributes by type-checking them.
|
||||
(void)derAttr->getOriginalFunction(sf->getASTContext());
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
version::Version ModuleDecl::getLanguageVersionBuiltWith() const {
|
||||
for (auto *F : getFiles()) {
|
||||
auto *LD = dyn_cast<LoadedFile>(F);
|
||||
|
||||
Reference in New Issue
Block a user