Allow attached macros to be applied to imported C declarations

The Clang importer maps arbitrary attributes spelled with `swift_attr("...")`
over to Swift attributes, using the Swift parser to process those attributes.
Extend this mechanism to allow `swift_attr` to refer to an attached macro,
expanding that macro as needed.

When a macro is applied to an imported declaration, that declaration is
pretty-printed (from the C++ AST) to provide to the macro implementation.
There are a few games we need to place to resolve the macro, and a few more
to lazily perform pretty-printing and adjust source locations to get the
right information to the macro, but this demonstrates that we could
take this path.

As an example, we use this mechanism to add an `async` version of a C
function that delivers its result via completion handler, using the
`@AddAsync` example macro implementation from the swift-syntax
repository.
This commit is contained in:
Doug Gregor
2024-09-09 22:34:06 -07:00
parent 5fa12d31ae
commit cdcabd05bd
7 changed files with 258 additions and 3 deletions

View File

@@ -40,6 +40,7 @@
#include "swift/Basic/SourceManager.h"
#include "swift/Basic/Statistic.h"
#include "swift/ClangImporter/ClangImporterRequests.h"
#include "swift/ClangImporter/ClangModule.h"
#include "swift/Parse/Lexer.h"
#include "swift/Strings.h"
#include "clang/AST/DeclObjC.h"
@@ -1718,6 +1719,14 @@ SmallVector<MacroDecl *, 1> namelookup::lookupMacros(DeclContext *dc,
ctx.evaluator, UnqualifiedLookupRequest{moduleLookupDesc}, {});
auto foundTypeDecl = moduleLookup.getSingleTypeResult();
auto *moduleDecl = dyn_cast_or_null<ModuleDecl>(foundTypeDecl);
// When resolving macro names for imported entities, we look for any
// loaded module.
if (!moduleDecl && isa<ClangModuleUnit>(moduleScopeDC)) {
moduleDecl = ctx.getLoadedModule(moduleName.getBaseIdentifier());
moduleScopeDC = moduleDecl;
}
if (!moduleDecl)
return {};

View File

@@ -4717,6 +4717,17 @@ bool ClangImporter::Implementation::lookupValue(SwiftLookupTable &table,
}
}
// Visit auxiliary declarations to check for name matches.
decl->visitAuxiliaryDecls([&](Decl *aux) {
if (auto auxValue = dyn_cast<ValueDecl>(aux)) {
if (auxValue->getName().matchesRef(name) &&
auxValue->getDeclContext()->isModuleScopeContext()) {
consumer.foundDecl(auxValue, DeclVisibilityKind::VisibleAtTopLevel);
anyMatching = true;
}
}
});
// If we have a declaration and nothing matched so far, try the names used
// in other versions of Swift.
if (auto clangDecl = entry.dyn_cast<clang::NamedDecl *>()) {

View File

@@ -23,6 +23,7 @@
#include "swift/AST/ASTContext.h"
#include "swift/AST/ASTMangler.h"
#include "swift/AST/ASTNode.h"
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/DiagnosticsFrontend.h"
#include "swift/AST/Expr.h"
#include "swift/AST/FreestandingMacroExpansion.h"
@@ -39,6 +40,7 @@
#include "swift/Basic/Lazy.h"
#include "swift/Basic/SourceManager.h"
#include "swift/Basic/StringExtras.h"
#include "swift/ClangImporter/ClangModule.h"
#include "swift/Bridging/ASTGen.h"
#include "swift/Bridging/Macros.h"
#include "swift/Demangling/Demangler.h"
@@ -1021,7 +1023,10 @@ createMacroSourceFile(std::unique_ptr<llvm::MemoryBuffer> buffer,
auto macroSourceFile = new (ctx) SourceFile(
*dc->getParentModule(), SourceFileKind::MacroExpansion, macroBufferID,
/*parsingOpts=*/{}, /*isPrimary=*/false);
macroSourceFile->setImports(dc->getParentSourceFile()->getImports());
if (auto parentSourceFile = dc->getParentSourceFile())
macroSourceFile->setImports(parentSourceFile->getImports());
else if (isa<ClangModuleUnit>(dc->getModuleScopeContext()))
macroSourceFile->setImports({});
return macroSourceFile;
}
@@ -1346,8 +1351,44 @@ static SourceFile *evaluateAttachedMacro(MacroDecl *macro, Decl *attachedTo,
if (!attrSourceFile)
return nullptr;
auto declSourceFile =
SourceFile *declSourceFile =
moduleDecl->getSourceFileContainingLocation(attachedTo->getStartLoc());
if (!declSourceFile && isa<ClangModuleUnit>(dc->getModuleScopeContext())) {
// Pretty-print the declaration into a buffer so we can macro-expand
// it.
// FIXME: Turn this into a request.
llvm::SmallString<128> buffer;
{
llvm::raw_svector_ostream out(buffer);
StreamPrinter printer(out);
attachedTo->print(
printer,
PrintOptions::printForDiagnostics(
AccessLevel::Public,
ctx.TypeCheckerOpts.PrintFullConvention));
}
// Create the buffer.
SourceManager &sourceMgr = ctx.SourceMgr;
auto bufferID = sourceMgr.addMemBufferCopy(buffer);
auto memBufferStartLoc = sourceMgr.getLocForBufferStart(bufferID);
sourceMgr.setGeneratedSourceInfo(
bufferID,
GeneratedSourceInfo{
GeneratedSourceInfo::PrettyPrinted,
CharSourceRange(),
CharSourceRange(memBufferStartLoc, buffer.size()),
ASTNode(const_cast<Decl *>(attachedTo)).getOpaqueValue(),
nullptr
}
);
// Create a source file to go with it.
declSourceFile = new (ctx)
SourceFile(*moduleDecl, SourceFileKind::Library, bufferID);
moduleDecl->addAuxiliaryFile(*declSourceFile);
}
if (!declSourceFile)
return nullptr;
@@ -1486,13 +1527,18 @@ static SourceFile *evaluateAttachedMacro(MacroDecl *macro, Decl *attachedTo,
if (auto var = dyn_cast<VarDecl>(attachedTo))
searchDecl = var->getParentPatternBinding();
auto startLoc = searchDecl->getStartLoc();
if (startLoc.isInvalid() && isa<ClangModuleUnit>(dc->getModuleScopeContext())) {
startLoc = ctx.SourceMgr.getLocForBufferStart(*declSourceFile->getBufferID());
}
BridgedStringRef evaluatedSourceOut{nullptr, 0};
assert(!externalDef.isError());
swift_Macros_expandAttachedMacro(
&ctx.Diags, externalDef.get(), discriminator->c_str(),
extendedType.c_str(), conformanceList.c_str(), getRawMacroRole(role),
astGenAttrSourceFile, attr->AtLoc.getOpaquePointerValue(),
astGenDeclSourceFile, searchDecl->getStartLoc().getOpaquePointerValue(),
astGenDeclSourceFile, startLoc.getOpaquePointerValue(),
astGenParentDeclSourceFile, parentDeclLoc, &evaluatedSourceOut);
if (!evaluatedSourceOut.unbridged().data())
return nullptr;

View File

@@ -0,0 +1,2 @@
void async_divide(double x, double y, void (* _Nonnull completionHandler)(double x))
__attribute__((swift_attr("@ModuleUser.AddAsync")));

View File

@@ -153,3 +153,7 @@ module IncompleteTypes {
header "IncompleteTypes.h"
export *
}
module CompletionHandlerGlobals {
header "completion_handler_globals.h"
}

View File

@@ -887,6 +887,163 @@ public enum LeftHandOperandFinderMacro: ExpressionMacro {
}
}
extension SyntaxCollection {
mutating func removeLast() {
self.remove(at: self.index(before: self.endIndex))
}
}
public struct AddAsyncMacro: PeerMacro {
public static func expansion<
Context: MacroExpansionContext,
Declaration: DeclSyntaxProtocol
>(
of node: AttributeSyntax,
providingPeersOf declaration: Declaration,
in context: Context
) throws -> [DeclSyntax] {
// Only on functions at the moment.
guard var funcDecl = declaration.as(FunctionDeclSyntax.self) else {
throw CustomError.message("@addAsync only works on functions")
}
// This only makes sense for non async functions.
if funcDecl.signature.effectSpecifiers?.asyncSpecifier != nil {
throw CustomError.message(
"@addAsync requires an non async function"
)
}
// This only makes sense void functions
if let resultType = funcDecl.signature.returnClause?.type,
resultType.as(IdentifierTypeSyntax.self)?.name.text != "Void" {
throw CustomError.message(
"@addAsync requires an function that returns void"
)
}
// Requires a completion handler block as last parameter
let completionHandlerParameter = funcDecl
.signature
.parameterClause
.parameters.last?
.type.as(AttributedTypeSyntax.self)?
.baseType.as(FunctionTypeSyntax.self)
guard let completionHandlerParameter else {
throw CustomError.message(
"@AddAsync requires an function that has a completion handler as last parameter"
)
}
// Completion handler needs to return Void
if completionHandlerParameter.returnClause.type.as(IdentifierTypeSyntax.self)?.name.text != "Void" {
throw CustomError.message(
"@AddAsync requires an function that has a completion handler that returns Void"
)
}
let returnType = completionHandlerParameter.parameters.first?.type
let isResultReturn = returnType?.children(viewMode: .all).first?.description == "Result"
let successReturnType =
if isResultReturn {
returnType!.as(IdentifierTypeSyntax.self)!.genericArgumentClause?.arguments.first!.argument
} else {
returnType
}
// Remove completionHandler and comma from the previous parameter
var newParameterList = funcDecl.signature.parameterClause.parameters
newParameterList.removeLast()
var newParameterListLastParameter = newParameterList.last!
newParameterList.removeLast()
newParameterListLastParameter.trailingTrivia = []
newParameterListLastParameter.trailingComma = nil
newParameterList.append(newParameterListLastParameter)
// Drop the @AddAsync attribute from the new declaration.
let newAttributeList = funcDecl.attributes.filter {
guard case let .attribute(attribute) = $0,
let attributeType = attribute.attributeName.as(IdentifierTypeSyntax.self),
let nodeType = node.attributeName.as(IdentifierTypeSyntax.self)
else {
return true
}
return attributeType.name.text != nodeType.name.text
}
let callArguments: [String] = newParameterList.map { param in
let argName = param.secondName ?? param.firstName
let paramName = param.firstName
if paramName.text != "_" {
return "\(paramName.text): \(argName.text)"
}
return "\(argName.text)"
}
let switchBody: ExprSyntax =
"""
switch returnValue {
case .success(let value):
continuation.resume(returning: value)
case .failure(let error):
continuation.resume(throwing: error)
}
"""
let newBody: ExprSyntax =
"""
\(raw: isResultReturn ? "try await withCheckedThrowingContinuation { continuation in" : "await withCheckedContinuation { continuation in")
\(raw: funcDecl.name)(\(raw: callArguments.joined(separator: ", "))) { \(raw: returnType != nil ? "returnValue in" : "")
\(raw: isResultReturn ? switchBody : "continuation.resume(returning: \(raw: returnType != nil ? "returnValue" : "()"))")
}
}
"""
// add async
funcDecl.signature.effectSpecifiers = FunctionEffectSpecifiersSyntax(
leadingTrivia: .space,
asyncSpecifier: .keyword(.async),
throwsClause: isResultReturn ? ThrowsClauseSyntax(throwsSpecifier: .keyword(.throws)) : nil
)
// add result type
if let successReturnType {
funcDecl.signature.returnClause = ReturnClauseSyntax(
leadingTrivia: .space,
type: successReturnType.with(\.leadingTrivia, .space)
)
} else {
funcDecl.signature.returnClause = nil
}
// drop completion handler
funcDecl.signature.parameterClause.parameters = newParameterList
funcDecl.signature.parameterClause.trailingTrivia = []
funcDecl.body = CodeBlockSyntax(
leftBrace: .leftBraceToken(leadingTrivia: .space),
statements: CodeBlockItemListSyntax(
[CodeBlockItemSyntax(item: .expr(newBody))]
),
rightBrace: .rightBraceToken(leadingTrivia: .newline)
)
funcDecl.attributes = newAttributeList
funcDecl.leadingTrivia = .newlines(2)
return [DeclSyntax(funcDecl)]
}
}
public struct AddCompletionHandler: PeerMacro {
public static func expansion(
of node: AttributeSyntax,

View File

@@ -0,0 +1,26 @@
// REQUIRES: swift_swift_parser, executable_test
// RUN: %empty-directory(%t)
// RUN: %host-build-swift -swift-version 5 -emit-library -o %t/%target-library-name(MacroDefinition) -module-name=MacroDefinition %S/Inputs/syntax_macro_definitions.swift -g -no-toolchain-stdlib-rpath -swift-version 5
// Diagnostics testing
// RUN: %target-swift-frontend(mock-sdk: %clang-importer-sdk) -typecheck -verify -swift-version 5 -enable-experimental-feature CodeItemMacros -load-plugin-library %t/%target-library-name(MacroDefinition) -module-name ModuleUser %s
@attached(peer, names: overloaded)
public macro AddAsync() = #externalMacro(module: "MacroDefinition", type: "AddAsyncMacro")
import CompletionHandlerGlobals
// Make sure that @AddAsync works at all.
@AddAsync
@available(SwiftStdlib 5.1, *)
func asyncTest(_ value: Int, completionHandler: @escaping (String) -> Void) {
completionHandler(String(value))
}
@available(SwiftStdlib 5.1, *)
func testAll(x: Double, y: Double) async {
_ = await asyncTest(17)
let _: Double = await async_divide(1.0, 2.0)
}