mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Allow attached macros to be applied to imported C declarations
The Clang importer maps arbitrary attributes spelled with `swift_attr("...")`
over to Swift attributes, using the Swift parser to process those attributes.
Extend this mechanism to allow `swift_attr` to refer to an attached macro,
expanding that macro as needed.
When a macro is applied to an imported declaration, that declaration is
pretty-printed (from the C++ AST) to provide to the macro implementation.
There are a few games we need to place to resolve the macro, and a few more
to lazily perform pretty-printing and adjust source locations to get the
right information to the macro, but this demonstrates that we could
take this path.
As an example, we use this mechanism to add an `async` version of a C
function that delivers its result via completion handler, using the
`@AddAsync` example macro implementation from the swift-syntax
repository.
This commit is contained in:
@@ -40,6 +40,7 @@
|
||||
#include "swift/Basic/SourceManager.h"
|
||||
#include "swift/Basic/Statistic.h"
|
||||
#include "swift/ClangImporter/ClangImporterRequests.h"
|
||||
#include "swift/ClangImporter/ClangModule.h"
|
||||
#include "swift/Parse/Lexer.h"
|
||||
#include "swift/Strings.h"
|
||||
#include "clang/AST/DeclObjC.h"
|
||||
@@ -1718,6 +1719,14 @@ SmallVector<MacroDecl *, 1> namelookup::lookupMacros(DeclContext *dc,
|
||||
ctx.evaluator, UnqualifiedLookupRequest{moduleLookupDesc}, {});
|
||||
auto foundTypeDecl = moduleLookup.getSingleTypeResult();
|
||||
auto *moduleDecl = dyn_cast_or_null<ModuleDecl>(foundTypeDecl);
|
||||
|
||||
// When resolving macro names for imported entities, we look for any
|
||||
// loaded module.
|
||||
if (!moduleDecl && isa<ClangModuleUnit>(moduleScopeDC)) {
|
||||
moduleDecl = ctx.getLoadedModule(moduleName.getBaseIdentifier());
|
||||
moduleScopeDC = moduleDecl;
|
||||
}
|
||||
|
||||
if (!moduleDecl)
|
||||
return {};
|
||||
|
||||
|
||||
@@ -4717,6 +4717,17 @@ bool ClangImporter::Implementation::lookupValue(SwiftLookupTable &table,
|
||||
}
|
||||
}
|
||||
|
||||
// Visit auxiliary declarations to check for name matches.
|
||||
decl->visitAuxiliaryDecls([&](Decl *aux) {
|
||||
if (auto auxValue = dyn_cast<ValueDecl>(aux)) {
|
||||
if (auxValue->getName().matchesRef(name) &&
|
||||
auxValue->getDeclContext()->isModuleScopeContext()) {
|
||||
consumer.foundDecl(auxValue, DeclVisibilityKind::VisibleAtTopLevel);
|
||||
anyMatching = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// If we have a declaration and nothing matched so far, try the names used
|
||||
// in other versions of Swift.
|
||||
if (auto clangDecl = entry.dyn_cast<clang::NamedDecl *>()) {
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "swift/AST/ASTContext.h"
|
||||
#include "swift/AST/ASTMangler.h"
|
||||
#include "swift/AST/ASTNode.h"
|
||||
#include "swift/AST/ASTPrinter.h"
|
||||
#include "swift/AST/DiagnosticsFrontend.h"
|
||||
#include "swift/AST/Expr.h"
|
||||
#include "swift/AST/FreestandingMacroExpansion.h"
|
||||
@@ -39,6 +40,7 @@
|
||||
#include "swift/Basic/Lazy.h"
|
||||
#include "swift/Basic/SourceManager.h"
|
||||
#include "swift/Basic/StringExtras.h"
|
||||
#include "swift/ClangImporter/ClangModule.h"
|
||||
#include "swift/Bridging/ASTGen.h"
|
||||
#include "swift/Bridging/Macros.h"
|
||||
#include "swift/Demangling/Demangler.h"
|
||||
@@ -1021,7 +1023,10 @@ createMacroSourceFile(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
auto macroSourceFile = new (ctx) SourceFile(
|
||||
*dc->getParentModule(), SourceFileKind::MacroExpansion, macroBufferID,
|
||||
/*parsingOpts=*/{}, /*isPrimary=*/false);
|
||||
macroSourceFile->setImports(dc->getParentSourceFile()->getImports());
|
||||
if (auto parentSourceFile = dc->getParentSourceFile())
|
||||
macroSourceFile->setImports(parentSourceFile->getImports());
|
||||
else if (isa<ClangModuleUnit>(dc->getModuleScopeContext()))
|
||||
macroSourceFile->setImports({});
|
||||
return macroSourceFile;
|
||||
}
|
||||
|
||||
@@ -1346,8 +1351,44 @@ static SourceFile *evaluateAttachedMacro(MacroDecl *macro, Decl *attachedTo,
|
||||
if (!attrSourceFile)
|
||||
return nullptr;
|
||||
|
||||
auto declSourceFile =
|
||||
SourceFile *declSourceFile =
|
||||
moduleDecl->getSourceFileContainingLocation(attachedTo->getStartLoc());
|
||||
if (!declSourceFile && isa<ClangModuleUnit>(dc->getModuleScopeContext())) {
|
||||
// Pretty-print the declaration into a buffer so we can macro-expand
|
||||
// it.
|
||||
// FIXME: Turn this into a request.
|
||||
llvm::SmallString<128> buffer;
|
||||
{
|
||||
llvm::raw_svector_ostream out(buffer);
|
||||
StreamPrinter printer(out);
|
||||
attachedTo->print(
|
||||
printer,
|
||||
PrintOptions::printForDiagnostics(
|
||||
AccessLevel::Public,
|
||||
ctx.TypeCheckerOpts.PrintFullConvention));
|
||||
}
|
||||
|
||||
// Create the buffer.
|
||||
SourceManager &sourceMgr = ctx.SourceMgr;
|
||||
auto bufferID = sourceMgr.addMemBufferCopy(buffer);
|
||||
auto memBufferStartLoc = sourceMgr.getLocForBufferStart(bufferID);
|
||||
sourceMgr.setGeneratedSourceInfo(
|
||||
bufferID,
|
||||
GeneratedSourceInfo{
|
||||
GeneratedSourceInfo::PrettyPrinted,
|
||||
CharSourceRange(),
|
||||
CharSourceRange(memBufferStartLoc, buffer.size()),
|
||||
ASTNode(const_cast<Decl *>(attachedTo)).getOpaqueValue(),
|
||||
nullptr
|
||||
}
|
||||
);
|
||||
|
||||
// Create a source file to go with it.
|
||||
declSourceFile = new (ctx)
|
||||
SourceFile(*moduleDecl, SourceFileKind::Library, bufferID);
|
||||
moduleDecl->addAuxiliaryFile(*declSourceFile);
|
||||
}
|
||||
|
||||
if (!declSourceFile)
|
||||
return nullptr;
|
||||
|
||||
@@ -1486,13 +1527,18 @@ static SourceFile *evaluateAttachedMacro(MacroDecl *macro, Decl *attachedTo,
|
||||
if (auto var = dyn_cast<VarDecl>(attachedTo))
|
||||
searchDecl = var->getParentPatternBinding();
|
||||
|
||||
auto startLoc = searchDecl->getStartLoc();
|
||||
if (startLoc.isInvalid() && isa<ClangModuleUnit>(dc->getModuleScopeContext())) {
|
||||
startLoc = ctx.SourceMgr.getLocForBufferStart(*declSourceFile->getBufferID());
|
||||
}
|
||||
|
||||
BridgedStringRef evaluatedSourceOut{nullptr, 0};
|
||||
assert(!externalDef.isError());
|
||||
swift_Macros_expandAttachedMacro(
|
||||
&ctx.Diags, externalDef.get(), discriminator->c_str(),
|
||||
extendedType.c_str(), conformanceList.c_str(), getRawMacroRole(role),
|
||||
astGenAttrSourceFile, attr->AtLoc.getOpaquePointerValue(),
|
||||
astGenDeclSourceFile, searchDecl->getStartLoc().getOpaquePointerValue(),
|
||||
astGenDeclSourceFile, startLoc.getOpaquePointerValue(),
|
||||
astGenParentDeclSourceFile, parentDeclLoc, &evaluatedSourceOut);
|
||||
if (!evaluatedSourceOut.unbridged().data())
|
||||
return nullptr;
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
void async_divide(double x, double y, void (* _Nonnull completionHandler)(double x))
|
||||
__attribute__((swift_attr("@ModuleUser.AddAsync")));
|
||||
@@ -153,3 +153,7 @@ module IncompleteTypes {
|
||||
header "IncompleteTypes.h"
|
||||
export *
|
||||
}
|
||||
|
||||
module CompletionHandlerGlobals {
|
||||
header "completion_handler_globals.h"
|
||||
}
|
||||
@@ -887,6 +887,163 @@ public enum LeftHandOperandFinderMacro: ExpressionMacro {
|
||||
}
|
||||
}
|
||||
|
||||
extension SyntaxCollection {
|
||||
mutating func removeLast() {
|
||||
self.remove(at: self.index(before: self.endIndex))
|
||||
}
|
||||
}
|
||||
|
||||
public struct AddAsyncMacro: PeerMacro {
|
||||
public static func expansion<
|
||||
Context: MacroExpansionContext,
|
||||
Declaration: DeclSyntaxProtocol
|
||||
>(
|
||||
of node: AttributeSyntax,
|
||||
providingPeersOf declaration: Declaration,
|
||||
in context: Context
|
||||
) throws -> [DeclSyntax] {
|
||||
|
||||
// Only on functions at the moment.
|
||||
guard var funcDecl = declaration.as(FunctionDeclSyntax.self) else {
|
||||
throw CustomError.message("@addAsync only works on functions")
|
||||
}
|
||||
|
||||
// This only makes sense for non async functions.
|
||||
if funcDecl.signature.effectSpecifiers?.asyncSpecifier != nil {
|
||||
throw CustomError.message(
|
||||
"@addAsync requires an non async function"
|
||||
)
|
||||
}
|
||||
|
||||
// This only makes sense void functions
|
||||
if let resultType = funcDecl.signature.returnClause?.type,
|
||||
resultType.as(IdentifierTypeSyntax.self)?.name.text != "Void" {
|
||||
throw CustomError.message(
|
||||
"@addAsync requires an function that returns void"
|
||||
)
|
||||
}
|
||||
|
||||
// Requires a completion handler block as last parameter
|
||||
let completionHandlerParameter = funcDecl
|
||||
.signature
|
||||
.parameterClause
|
||||
.parameters.last?
|
||||
.type.as(AttributedTypeSyntax.self)?
|
||||
.baseType.as(FunctionTypeSyntax.self)
|
||||
guard let completionHandlerParameter else {
|
||||
throw CustomError.message(
|
||||
"@AddAsync requires an function that has a completion handler as last parameter"
|
||||
)
|
||||
}
|
||||
|
||||
// Completion handler needs to return Void
|
||||
if completionHandlerParameter.returnClause.type.as(IdentifierTypeSyntax.self)?.name.text != "Void" {
|
||||
throw CustomError.message(
|
||||
"@AddAsync requires an function that has a completion handler that returns Void"
|
||||
)
|
||||
}
|
||||
|
||||
let returnType = completionHandlerParameter.parameters.first?.type
|
||||
|
||||
let isResultReturn = returnType?.children(viewMode: .all).first?.description == "Result"
|
||||
let successReturnType =
|
||||
if isResultReturn {
|
||||
returnType!.as(IdentifierTypeSyntax.self)!.genericArgumentClause?.arguments.first!.argument
|
||||
} else {
|
||||
returnType
|
||||
}
|
||||
|
||||
// Remove completionHandler and comma from the previous parameter
|
||||
var newParameterList = funcDecl.signature.parameterClause.parameters
|
||||
newParameterList.removeLast()
|
||||
var newParameterListLastParameter = newParameterList.last!
|
||||
newParameterList.removeLast()
|
||||
newParameterListLastParameter.trailingTrivia = []
|
||||
newParameterListLastParameter.trailingComma = nil
|
||||
newParameterList.append(newParameterListLastParameter)
|
||||
|
||||
// Drop the @AddAsync attribute from the new declaration.
|
||||
let newAttributeList = funcDecl.attributes.filter {
|
||||
guard case let .attribute(attribute) = $0,
|
||||
let attributeType = attribute.attributeName.as(IdentifierTypeSyntax.self),
|
||||
let nodeType = node.attributeName.as(IdentifierTypeSyntax.self)
|
||||
else {
|
||||
return true
|
||||
}
|
||||
|
||||
return attributeType.name.text != nodeType.name.text
|
||||
}
|
||||
|
||||
let callArguments: [String] = newParameterList.map { param in
|
||||
let argName = param.secondName ?? param.firstName
|
||||
|
||||
let paramName = param.firstName
|
||||
if paramName.text != "_" {
|
||||
return "\(paramName.text): \(argName.text)"
|
||||
}
|
||||
|
||||
return "\(argName.text)"
|
||||
}
|
||||
|
||||
let switchBody: ExprSyntax =
|
||||
"""
|
||||
switch returnValue {
|
||||
case .success(let value):
|
||||
continuation.resume(returning: value)
|
||||
case .failure(let error):
|
||||
continuation.resume(throwing: error)
|
||||
}
|
||||
"""
|
||||
|
||||
let newBody: ExprSyntax =
|
||||
"""
|
||||
|
||||
\(raw: isResultReturn ? "try await withCheckedThrowingContinuation { continuation in" : "await withCheckedContinuation { continuation in")
|
||||
\(raw: funcDecl.name)(\(raw: callArguments.joined(separator: ", "))) { \(raw: returnType != nil ? "returnValue in" : "")
|
||||
|
||||
\(raw: isResultReturn ? switchBody : "continuation.resume(returning: \(raw: returnType != nil ? "returnValue" : "()"))")
|
||||
}
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
// add async
|
||||
funcDecl.signature.effectSpecifiers = FunctionEffectSpecifiersSyntax(
|
||||
leadingTrivia: .space,
|
||||
asyncSpecifier: .keyword(.async),
|
||||
throwsClause: isResultReturn ? ThrowsClauseSyntax(throwsSpecifier: .keyword(.throws)) : nil
|
||||
)
|
||||
|
||||
// add result type
|
||||
if let successReturnType {
|
||||
funcDecl.signature.returnClause = ReturnClauseSyntax(
|
||||
leadingTrivia: .space,
|
||||
type: successReturnType.with(\.leadingTrivia, .space)
|
||||
)
|
||||
} else {
|
||||
funcDecl.signature.returnClause = nil
|
||||
}
|
||||
|
||||
// drop completion handler
|
||||
funcDecl.signature.parameterClause.parameters = newParameterList
|
||||
funcDecl.signature.parameterClause.trailingTrivia = []
|
||||
|
||||
funcDecl.body = CodeBlockSyntax(
|
||||
leftBrace: .leftBraceToken(leadingTrivia: .space),
|
||||
statements: CodeBlockItemListSyntax(
|
||||
[CodeBlockItemSyntax(item: .expr(newBody))]
|
||||
),
|
||||
rightBrace: .rightBraceToken(leadingTrivia: .newline)
|
||||
)
|
||||
|
||||
funcDecl.attributes = newAttributeList
|
||||
|
||||
funcDecl.leadingTrivia = .newlines(2)
|
||||
|
||||
return [DeclSyntax(funcDecl)]
|
||||
}
|
||||
}
|
||||
|
||||
public struct AddCompletionHandler: PeerMacro {
|
||||
public static func expansion(
|
||||
of node: AttributeSyntax,
|
||||
|
||||
26
test/Macros/expand_on_imported.swift
Normal file
26
test/Macros/expand_on_imported.swift
Normal file
@@ -0,0 +1,26 @@
|
||||
// REQUIRES: swift_swift_parser, executable_test
|
||||
|
||||
// RUN: %empty-directory(%t)
|
||||
// RUN: %host-build-swift -swift-version 5 -emit-library -o %t/%target-library-name(MacroDefinition) -module-name=MacroDefinition %S/Inputs/syntax_macro_definitions.swift -g -no-toolchain-stdlib-rpath -swift-version 5
|
||||
|
||||
// Diagnostics testing
|
||||
// RUN: %target-swift-frontend(mock-sdk: %clang-importer-sdk) -typecheck -verify -swift-version 5 -enable-experimental-feature CodeItemMacros -load-plugin-library %t/%target-library-name(MacroDefinition) -module-name ModuleUser %s
|
||||
|
||||
@attached(peer, names: overloaded)
|
||||
public macro AddAsync() = #externalMacro(module: "MacroDefinition", type: "AddAsyncMacro")
|
||||
|
||||
import CompletionHandlerGlobals
|
||||
|
||||
// Make sure that @AddAsync works at all.
|
||||
@AddAsync
|
||||
@available(SwiftStdlib 5.1, *)
|
||||
func asyncTest(_ value: Int, completionHandler: @escaping (String) -> Void) {
|
||||
completionHandler(String(value))
|
||||
}
|
||||
|
||||
@available(SwiftStdlib 5.1, *)
|
||||
func testAll(x: Double, y: Double) async {
|
||||
_ = await asyncTest(17)
|
||||
|
||||
let _: Double = await async_divide(1.0, 2.0)
|
||||
}
|
||||
Reference in New Issue
Block a user