[Macros] Handle macro overloading.

Allow more than one macro plugin to introduce a macro with the same
name, and let the constraint solver figure out which one to call. Also
eliminates a potential use-after-free if we somehow find additional
compiler plugins to load after having expanded a macro.
This commit is contained in:
Doug Gregor
2022-11-13 16:29:48 -08:00
parent bdf7762f55
commit 19d1588d13
6 changed files with 126 additions and 19 deletions

View File

@@ -352,12 +352,6 @@ public:
llvm::SmallPtrSet<DerivativeAttr *, 1>>
DerivativeAttrs;
/// Cache of compiler plugins keyed by their name.
llvm::StringMap<CompilerPlugin> LoadedPlugins;
/// Cache of loaded symbols.
llvm::StringMap<void *> LoadedSymbols;
private:
/// The current generation number, which reflects the number of
/// times that external modules have been loaded.
@@ -1452,8 +1446,11 @@ public:
/// The declared interface type of Builtin.TheTupleType.
BuiltinTupleType *getBuiltinTupleType();
/// Finds the loaded compiler plugin given its name.
CompilerPlugin *getLoadedPlugin(StringRef name);
/// Finds the loaded compiler plugins with the given name.
TinyPtrVector<CompilerPlugin *> getLoadedPlugins(StringRef name);
/// Add a loaded plugin with the given name.
void addLoadedPlugin(StringRef name, CompilerPlugin *plugin);
/// Finds the address of the given symbol. If `libraryHandleHint` is non-null,
/// search within the library.

View File

@@ -515,6 +515,15 @@ struct ASTContext::Implementation {
llvm::StringMap<OptionSet<SearchPathKind>> SearchPathsSet;
/// Cache of compiler plugins keyed by their name.
///
/// Names can be overloaded, so there can be multiple plugins with the same
/// name.
llvm::StringMap<TinyPtrVector<CompilerPlugin*>> LoadedPlugins;
/// Cache of loaded symbols.
llvm::StringMap<void *> LoadedSymbols;
/// The permanent arena.
Arena Permanent;
@@ -579,6 +588,12 @@ ASTContext::Implementation::Implementation()
ASTContext::Implementation::~Implementation() {
for (auto &cleanup : Cleanups)
cleanup();
for (const auto &pluginsByName : LoadedPlugins) {
for (auto plugin : pluginsByName.second) {
delete plugin;
}
}
}
ConstraintCheckerArenaRAII::
@@ -6047,16 +6062,21 @@ BuiltinTupleType *ASTContext::getBuiltinTupleType() {
return result;
}
CompilerPlugin *ASTContext::getLoadedPlugin(StringRef name) {
auto lookup = LoadedPlugins.find(name);
if (lookup == LoadedPlugins.end())
return nullptr;
return &lookup->second;
TinyPtrVector<CompilerPlugin *> ASTContext::getLoadedPlugins(StringRef name) {
auto &loadedPlugins = getImpl().LoadedPlugins;
auto lookup = loadedPlugins.find(name);
if (lookup == loadedPlugins.end())
return { };
return lookup->second;
}
void ASTContext::addLoadedPlugin(StringRef name, CompilerPlugin *plugin) {
getImpl().LoadedPlugins[name].push_back(plugin);
}
void *ASTContext::getAddressOfSymbol(const char *name,
void *libraryHandleHint) {
auto lookup = LoadedSymbols.try_emplace(name, nullptr);
auto lookup = getImpl().LoadedSymbols.try_emplace(name, nullptr);
void *&address = lookup.first->getValue();
#if !defined(_WIN32)
if (lookup.second) {

View File

@@ -188,9 +188,9 @@ void ASTContext::loadCompilerPlugins() {
swift_ASTGen_getMacroTypes(getter, &metatypesAddress, &metatypeCount);
ArrayRef<const void *> metatypes(metatypesAddress, metatypeCount);
for (const void *metatype : metatypes) {
CompilerPlugin plugin(metatype, lib, *this);
auto name = plugin.getName();
LoadedPlugins.try_emplace(name, std::move(plugin));
auto plugin = new CompilerPlugin(metatype, lib, *this);
auto name = plugin->getName();
addLoadedPlugin(name, plugin);
}
free(const_cast<void *>((const void *)metatypes.data()));
#endif // SWIFT_SWIFT_PARSER

View File

@@ -317,7 +317,7 @@ ArrayRef<MacroDecl *> MacroLookupRequest::evaluate(
// Look for a loaded plugin based on the macro name.
// FIXME: This API needs to be able to return multiple plugins, because
// several plugins could export a macro with the same name.
if (auto *plugin = ctx.getLoadedPlugin(macroName.str())) {
for (auto plugin: ctx.getLoadedPlugins(macroName.str())) {
if (auto pluginMacro = createPluginMacro(mod, macroName, plugin)) {
macros.push_back(pluginMacro);
}

View File

@@ -180,10 +180,95 @@ struct ColorLiteralMacro: _CompilerPlugin {
}
}
struct HSVColorLiteralMacro: _CompilerPlugin {
static func _name() -> (UnsafePointer<UInt8>, count: Int) {
var name = "customColorLiteral"
return name.withUTF8 { buffer in
let result = UnsafeMutablePointer<UInt8>.allocate(capacity: buffer.count)
result.initialize(from: buffer.baseAddress!, count: buffer.count)
return (UnsafePointer(result), count: buffer.count)
}
}
static func _genericSignature() -> (UnsafePointer<UInt8>?, count: Int) {
var genSig = "<T>"
return genSig.withUTF8 { buffer in
let result = UnsafeMutablePointer<UInt8>.allocate(capacity: buffer.count)
result.initialize(from: buffer.baseAddress!, count: buffer.count)
return (UnsafePointer(result), count: buffer.count)
}
}
static func _typeSignature() -> (UnsafePointer<UInt8>, count: Int) {
var typeSig =
"""
(
hue hue: Float, saturation saturation: Float, value value: Float
) -> T
"""
return typeSig.withUTF8 { buffer in
let result = UnsafeMutablePointer<UInt8>.allocate(capacity: buffer.count)
result.initialize(from: buffer.baseAddress!, count: buffer.count)
return (UnsafePointer(result), count: buffer.count)
}
}
static func _owningModule() -> (UnsafePointer<UInt8>, count: Int) {
var swiftModule = "Swift"
return swiftModule.withUTF8 { buffer in
let result = UnsafeMutablePointer<UInt8>.allocate(capacity: buffer.count)
result.initialize(from: buffer.baseAddress!, count: buffer.count)
return (UnsafePointer(result), count: buffer.count)
}
}
static func _supplementalSignatureModules() -> (UnsafePointer<UInt8>, count: Int) {
var nothing = ""
return nothing.withUTF8 { buffer in
let result = UnsafeMutablePointer<UInt8>.allocate(capacity: buffer.count)
result.initialize(from: buffer.baseAddress!, count: buffer.count)
return (UnsafePointer(result), count: buffer.count)
}
}
static func _kind() -> _CompilerPluginKind {
.expressionMacro
}
static func _rewrite(
targetModuleName: UnsafePointer<UInt8>,
targetModuleNameCount: Int,
filePath: UnsafePointer<UInt8>,
filePathCount: Int,
sourceFileText: UnsafePointer<UInt8>,
sourceFileTextCount: Int,
localSourceText: UnsafePointer<UInt8>,
localSourceTextCount: Int
) -> (UnsafePointer<UInt8>?, count: Int) {
let meeTextBuffer = UnsafeBufferPointer(
start: localSourceText, count: localSourceTextCount)
let meeText = String(decoding: meeTextBuffer, as: UTF8.self)
let prefix = "#customColorLiteral(hue:"
guard meeText.starts(with: prefix), meeText.last == ")" else {
return (nil, 0)
}
let expr = meeText.dropFirst(prefix.count).dropLast()
var resultString = ".init(_colorLiteralHue:\(expr))"
return resultString.withUTF8 { buffer in
let result = UnsafeMutableBufferPointer<UInt8>.allocate(
capacity: buffer.count + 1)
_ = result.initialize(from: buffer)
result[buffer.count] = 0
return (UnsafePointer(result.baseAddress), buffer.count)
}
}
}
public var allMacros: [Any.Type] {
[
StringifyMacro.self,
ColorLiteralMacro.self
ColorLiteralMacro.self,
HSVColorLiteralMacro.self
]
}

View File

@@ -37,7 +37,12 @@ let _ = #customStringify(["a", "b", "c"] + ["d", "e", "f"])
struct MyColor: _ExpressibleByColorLiteral {
init(_colorLiteralRed red: Float, green: Float, blue: Float, alpha: Float) { }
init(_colorLiteralHue hue: Float, saturation: Float, value: Float) { }
}
// CHECK: (macro_expansion_expr type='MyColor' {{.*}} name=customColorLiteral
let _: MyColor = #customColorLiteral(red: 0.5, green: 0.5, blue: 0.2, alpha: 0.9)
// CHECK: (macro_expansion_expr type='MyColor' {{.*}} name=customColorLiteral
let _: MyColor = #customColorLiteral(hue: 0.5, saturation: 0.5, value: 0.2)