diff --git a/lib/ASTGen/Sources/ASTGen/SourceFile.swift b/lib/ASTGen/Sources/ASTGen/SourceFile.swift index f90dfd6bd2a..d604f3d27d8 100644 --- a/lib/ASTGen/Sources/ASTGen/SourceFile.swift +++ b/lib/ASTGen/Sources/ASTGen/SourceFile.swift @@ -249,13 +249,11 @@ public func emitParserDiagnostics( } } -/// Retrieve a syntax node in the given source file, with the given type. -public func findSyntaxNodeInSourceFile( - sourceFilePtr: UnsafeRawPointer, - sourceLocationPtr: UnsafePointer?, - type: Node.Type, - wantOutermost: Bool = false -) -> Node? { +/// Find a token in the given source file at the given location. +func findToken( + in sourceFilePtr: UnsafeRawPointer, + at sourceLocationPtr: UnsafePointer? +) -> TokenSyntax? { guard let sourceLocationPtr = sourceLocationPtr else { return nil } @@ -277,6 +275,20 @@ public func findSyntaxNodeInSourceFile( return nil } + return token +} + +/// Retrieve a syntax node in the given source file, with the given type. +public func findSyntaxNodeInSourceFile( + sourceFilePtr: UnsafeRawPointer, + sourceLocationPtr: UnsafePointer?, + type: Node.Type, + wantOutermost: Bool = false +) -> Node? { + guard let token = findToken(in: sourceFilePtr, at: sourceLocationPtr) else { + return nil + } + var currentSyntax = Syntax(token) var resultSyntax: Node? = nil while let parentSyntax = currentSyntax.parent { @@ -309,6 +321,28 @@ public func findSyntaxNodeInSourceFile( return resultSyntax } +/// Retrieve a syntax node in the given source file that satisfies the +/// given predicate. +public func findSyntaxNodeInSourceFile( + sourceFilePtr: UnsafeRawPointer, + sourceLocationPtr: UnsafePointer?, + where predicate: (Syntax) -> Bool +) -> Syntax? { + guard let token = findToken(in: sourceFilePtr, at: sourceLocationPtr) else { + return nil + } + + var currentSyntax = Syntax(token) + while let parentSyntax = currentSyntax.parent { + currentSyntax = parentSyntax + if predicate(currentSyntax) { + return currentSyntax + } + } + + return nil +} + @_cdecl("swift_ASTGen_virtualFiles") @usableFromInline func getVirtualFiles( diff --git a/lib/ASTGen/Sources/MacroEvaluation/Macros.swift b/lib/ASTGen/Sources/MacroEvaluation/Macros.swift index 291ace95a03..21c42fbbcbc 100644 --- a/lib/ASTGen/Sources/MacroEvaluation/Macros.swift +++ b/lib/ASTGen/Sources/MacroEvaluation/Macros.swift @@ -583,22 +583,15 @@ func expandAttachedMacro( return 1 } - func findNode(type: T.Type) -> T? { - findSyntaxNodeInSourceFile( - sourceFilePtr: declarationSourceFilePtr, - sourceLocationPtr: declarationSourceLocPointer, - type: T.self - ) - } - // Dig out the node for the closure or declaration to which the custom // attribute is attached. - let node: Syntax - if let closureNode = findNode(type: ClosureExprSyntax.self) { - node = Syntax(closureNode) - } else if let declNode = findNode(type: DeclSyntax.self) { - node = Syntax(declNode) - } else { + let node = findSyntaxNodeInSourceFile( + sourceFilePtr: declarationSourceFilePtr, + sourceLocationPtr: declarationSourceLocPointer, + where: { $0.is(DeclSyntax.self) || $0.is(ClosureExprSyntax.self) } + ) + + guard let node else { return 1 }