mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
This makes it possible to mark a pointer with __sized_by when the pointee type definition is not included. The wrapper function has the same interface as if the parameter were a void pointer, since the stdlib has no `OpaqueBufferPointer` type. * use swift-ide-test for checking interop signatures * add xfail test for Span + Optional combo (Optional requires Escapable)
917 lines
33 KiB
Swift
917 lines
33 KiB
Swift
import SwiftDiagnostics
|
|
import SwiftParser
|
|
import SwiftSyntax
|
|
import SwiftSyntaxBuilder
|
|
import SwiftSyntaxMacros
|
|
|
|
protocol ParamInfo: CustomStringConvertible {
|
|
var description: String { get }
|
|
var original: SyntaxProtocol { get }
|
|
var pointerIndex: Int { get }
|
|
var nonescaping: Bool { get set }
|
|
|
|
func getBoundsCheckedThunkBuilder(
|
|
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax,
|
|
_ variant: Variant
|
|
) -> BoundsCheckedThunkBuilder
|
|
}
|
|
|
|
struct CxxSpan: ParamInfo {
|
|
var pointerIndex: Int
|
|
var nonescaping: Bool
|
|
var original: SyntaxProtocol
|
|
var typeMappings: [String: String]
|
|
|
|
var description: String {
|
|
return "std::span(pointer: \(pointerIndex), nonescaping: \(nonescaping))"
|
|
}
|
|
|
|
func getBoundsCheckedThunkBuilder(
|
|
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax,
|
|
_ variant: Variant
|
|
) -> BoundsCheckedThunkBuilder {
|
|
CxxSpanThunkBuilder(base: base, index: pointerIndex - 1, signature: funcDecl.signature,
|
|
typeMappings: typeMappings, node: original)
|
|
}
|
|
}
|
|
|
|
struct CountedBy: ParamInfo {
|
|
var pointerIndex: Int
|
|
var count: ExprSyntax
|
|
var sizedBy: Bool
|
|
var nonescaping: Bool
|
|
var original: SyntaxProtocol
|
|
|
|
var description: String {
|
|
if sizedBy {
|
|
return ".sizedBy(pointer: \(pointerIndex), size: \"\(count)\", nonescaping: \(nonescaping))"
|
|
}
|
|
return ".countedBy(pointer: \(pointerIndex), count: \"\(count)\", nonescaping: \(nonescaping))"
|
|
}
|
|
|
|
func getBoundsCheckedThunkBuilder(
|
|
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax,
|
|
_ variant: Variant
|
|
) -> BoundsCheckedThunkBuilder {
|
|
let funcParam = getParam(funcDecl, pointerIndex - 1)
|
|
let paramName = funcParam.secondName ?? funcParam.firstName
|
|
let isNullable = funcParam.type.is(OptionalTypeSyntax.self)
|
|
return CountedOrSizedPointerThunkBuilder(
|
|
base: base, index: pointerIndex - 1, countExpr: count,
|
|
name: paramName, nullable: isNullable, signature: funcDecl.signature,
|
|
nonescaping: nonescaping, isSizedBy: sizedBy)
|
|
}
|
|
}
|
|
|
|
struct EndedBy: ParamInfo {
|
|
var pointerIndex: Int
|
|
var endIndex: Int
|
|
var nonescaping: Bool
|
|
var original: SyntaxProtocol
|
|
|
|
var description: String {
|
|
return ".endedBy(start: \(pointerIndex), end: \(endIndex), nonescaping: \(nonescaping))"
|
|
}
|
|
|
|
func getBoundsCheckedThunkBuilder(
|
|
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax,
|
|
_ variant: Variant
|
|
) -> BoundsCheckedThunkBuilder {
|
|
let funcParam = getParam(funcDecl, pointerIndex - 1)
|
|
let paramName = funcParam.secondName ?? funcParam.firstName
|
|
let isNullable = funcParam.type.is(OptionalTypeSyntax.self)
|
|
return EndedByPointerThunkBuilder(
|
|
base: base, startIndex: pointerIndex - 1, endIndex: endIndex - 1,
|
|
name: paramName, nullable: isNullable, signature: funcDecl.signature, nonescaping: nonescaping
|
|
)
|
|
}
|
|
}
|
|
|
|
struct RuntimeError: Error {
|
|
let description: String
|
|
|
|
init(_ description: String) {
|
|
self.description = description
|
|
}
|
|
|
|
var errorDescription: String? {
|
|
description
|
|
}
|
|
}
|
|
|
|
struct DiagnosticError: Error {
|
|
let description: String
|
|
let node: SyntaxProtocol
|
|
let notes: [Note]
|
|
|
|
init(_ description: String, node: SyntaxProtocol, notes: [Note] = []) {
|
|
self.description = description
|
|
self.node = node
|
|
self.notes = notes
|
|
}
|
|
|
|
var errorDescription: String? {
|
|
description
|
|
}
|
|
}
|
|
|
|
enum Mutability {
|
|
case Immutable
|
|
case Mutable
|
|
}
|
|
|
|
func getTypeName(_ type: TypeSyntax) throws -> TokenSyntax {
|
|
switch type.kind {
|
|
case .memberType:
|
|
let memberType = type.as(MemberTypeSyntax.self)!
|
|
if !memberType.baseType.isSwiftCoreModule {
|
|
throw DiagnosticError(
|
|
"expected pointer type in Swift core module, got type \(type) with base type \(memberType.baseType)",
|
|
node: type)
|
|
}
|
|
return memberType.name
|
|
case .identifierType:
|
|
return type.as(IdentifierTypeSyntax.self)!.name
|
|
default:
|
|
throw DiagnosticError("expected pointer type, got \(type) with kind \(type.kind)", node: type)
|
|
}
|
|
}
|
|
|
|
func replaceTypeName(_ type: TypeSyntax, _ name: TokenSyntax) -> TypeSyntax {
|
|
if let memberType = type.as(MemberTypeSyntax.self) {
|
|
return TypeSyntax(memberType.with(\.name, name))
|
|
}
|
|
let idType = type.as(IdentifierTypeSyntax.self)!
|
|
return TypeSyntax(idType.with(\.name, name))
|
|
}
|
|
|
|
func getPointerMutability(text: String) -> Mutability? {
|
|
switch text {
|
|
case "UnsafePointer": return .Immutable
|
|
case "UnsafeMutablePointer": return .Mutable
|
|
case "UnsafeRawPointer": return .Immutable
|
|
case "UnsafeMutableRawPointer": return .Mutable
|
|
case "OpaquePointer": return .Immutable
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func isRawPointerType(text: String) -> Bool {
|
|
switch text {
|
|
case "UnsafeRawPointer": return true
|
|
case "UnsafeMutableRawPointer": return true
|
|
case "OpaquePointer": return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func getSafePointerName(mut: Mutability, generateSpan: Bool, isRaw: Bool) -> TokenSyntax {
|
|
switch (mut, generateSpan, isRaw) {
|
|
case (.Immutable, true, true): return "RawSpan"
|
|
case (.Mutable, true, true): return "MutableRawSpan"
|
|
case (.Immutable, false, true): return "UnsafeRawBufferPointer"
|
|
case (.Mutable, false, true): return "UnsafeMutableRawBufferPointer"
|
|
|
|
case (.Immutable, true, false): return "Span"
|
|
case (.Mutable, true, false): return "MutableSpan"
|
|
case (.Immutable, false, false): return "UnsafeBufferPointer"
|
|
case (.Mutable, false, false): return "UnsafeMutableBufferPointer"
|
|
}
|
|
}
|
|
|
|
func transformType(_ prev: TypeSyntax, _ variant: Variant, _ isSizedBy: Bool) throws -> TypeSyntax {
|
|
if let optType = prev.as(OptionalTypeSyntax.self) {
|
|
return TypeSyntax(
|
|
optType.with(\.wrappedType, try transformType(optType.wrappedType, variant, isSizedBy)))
|
|
}
|
|
if let impOptType = prev.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) {
|
|
return try transformType(impOptType.wrappedType, variant, isSizedBy)
|
|
}
|
|
let name = try getTypeName(prev)
|
|
let text = name.text
|
|
let isRaw = isRawPointerType(text: text)
|
|
if isRaw && !isSizedBy {
|
|
throw DiagnosticError("raw pointers only supported for SizedBy", node: name)
|
|
}
|
|
if !isRaw && isSizedBy {
|
|
throw DiagnosticError("SizedBy only supported for raw pointers", node: name)
|
|
}
|
|
|
|
guard let kind: Mutability = getPointerMutability(text: text) else {
|
|
throw DiagnosticError(
|
|
"expected Unsafe[Mutable][Raw]Pointer type for type \(prev)"
|
|
+ " - first type token is '\(text)'", node: name)
|
|
}
|
|
let token = getSafePointerName(mut: kind, generateSpan: variant.generateSpan, isRaw: isSizedBy)
|
|
if isSizedBy {
|
|
return TypeSyntax(IdentifierTypeSyntax(name: token))
|
|
}
|
|
return replaceTypeName(prev, token)
|
|
}
|
|
|
|
struct Variant {
|
|
public let generateSpan: Bool
|
|
public let skipTrivialCount: Bool
|
|
}
|
|
|
|
protocol BoundsCheckedThunkBuilder {
|
|
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax], _ variant: Variant) throws -> ExprSyntax
|
|
func buildBoundsChecks(_ variant: Variant) throws -> [CodeBlockItemSyntax.Item]
|
|
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ variant: Variant) throws
|
|
-> FunctionSignatureSyntax
|
|
}
|
|
|
|
func getParam(_ signature: FunctionSignatureSyntax, _ paramIndex: Int) -> FunctionParameterSyntax {
|
|
let params = signature.parameterClause.parameters
|
|
if paramIndex > 0 {
|
|
return params[params.index(params.startIndex, offsetBy: paramIndex)]
|
|
} else {
|
|
return params[params.startIndex]
|
|
}
|
|
}
|
|
|
|
func getParam(_ funcDecl: FunctionDeclSyntax, _ paramIndex: Int) -> FunctionParameterSyntax {
|
|
return getParam(funcDecl.signature, paramIndex)
|
|
}
|
|
|
|
struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
|
|
let base: FunctionDeclSyntax
|
|
init(_ function: FunctionDeclSyntax) {
|
|
base = function
|
|
}
|
|
|
|
func buildBoundsChecks(_ variant: Variant) throws -> [CodeBlockItemSyntax.Item] {
|
|
return []
|
|
}
|
|
|
|
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ variant: Variant) throws
|
|
-> FunctionSignatureSyntax
|
|
{
|
|
var newParams = base.signature.parameterClause.parameters.enumerated().filter {
|
|
let type = argTypes[$0.offset]
|
|
// filter out deleted parameters, i.e. ones where argTypes[i] _contains_ nil
|
|
return type == nil || type! != nil
|
|
}.map { (i: Int, e: FunctionParameterSyntax) in
|
|
e.with(\.type, (argTypes[i] ?? e.type)!)
|
|
}
|
|
let last = newParams.popLast()!
|
|
newParams.append(last.with(\.trailingComma, nil))
|
|
|
|
return base.signature.with(\.parameterClause.parameters, FunctionParameterListSyntax(newParams))
|
|
}
|
|
|
|
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax], _: Variant) throws -> ExprSyntax {
|
|
let functionRef = DeclReferenceExprSyntax(baseName: base.name)
|
|
let args: [ExprSyntax] = base.signature.parameterClause.parameters.enumerated()
|
|
.map { (i: Int, param: FunctionParameterSyntax) in
|
|
let name = param.secondName ?? param.firstName
|
|
let declref = DeclReferenceExprSyntax(baseName: name)
|
|
return pointerArgs[i] ?? ExprSyntax(declref)
|
|
}
|
|
let labels: [TokenSyntax?] = base.signature.parameterClause.parameters.map { param in
|
|
let firstName = param.firstName.trimmed
|
|
if firstName.text == "_" {
|
|
return nil
|
|
}
|
|
return firstName
|
|
}
|
|
let labeledArgs: [LabeledExprSyntax] = zip(labels, args).enumerated().map { (i, e) in
|
|
let (label, arg) = e
|
|
var comma: TokenSyntax? = nil
|
|
if i < args.count - 1 {
|
|
comma = .commaToken()
|
|
}
|
|
let colon: TokenSyntax? = label != nil ? .colonToken() : nil
|
|
return LabeledExprSyntax(label: label, colon: colon, expression: arg, trailingComma: comma)
|
|
}
|
|
return ExprSyntax(
|
|
FunctionCallExprSyntax(
|
|
calledExpression: functionRef, leftParen: .leftParenToken(),
|
|
arguments: LabeledExprListSyntax(labeledArgs), rightParen: .rightParenToken()))
|
|
}
|
|
}
|
|
|
|
struct CxxSpanThunkBuilder: BoundsCheckedThunkBuilder {
|
|
public let base: BoundsCheckedThunkBuilder
|
|
public let index: Int
|
|
public let signature: FunctionSignatureSyntax
|
|
public let typeMappings: [String: String]
|
|
public let node: SyntaxProtocol
|
|
|
|
func buildBoundsChecks(_ variant: Variant) throws -> [CodeBlockItemSyntax.Item] {
|
|
return []
|
|
}
|
|
|
|
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ variant: Variant) throws
|
|
-> FunctionSignatureSyntax {
|
|
var types = argTypes
|
|
let param = getParam(signature, index)
|
|
let typeName = try getTypeName(param.type).text;
|
|
guard let desugaredType = typeMappings[typeName] else {
|
|
throw DiagnosticError(
|
|
"unable to desugar type with name '\(typeName)'", node: node)
|
|
}
|
|
|
|
let parsedDesugaredType = try TypeSyntax("\(raw: desugaredType)")
|
|
types[index] = TypeSyntax(IdentifierTypeSyntax(name: "Span",
|
|
genericArgumentClause: parsedDesugaredType.as(IdentifierTypeSyntax.self)!.genericArgumentClause))
|
|
return try base.buildFunctionSignature(types, variant)
|
|
}
|
|
|
|
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax], _ variant: Variant) throws -> ExprSyntax {
|
|
var args = pointerArgs
|
|
let param = getParam(signature, index)
|
|
let typeName = try getTypeName(param.type).text;
|
|
assert(args[index] == nil)
|
|
args[index] = ExprSyntax("\(raw: typeName)(\(raw: param.secondName ?? param.firstName))")
|
|
return try base.buildFunctionCall(args, variant)
|
|
}
|
|
}
|
|
|
|
protocol PointerBoundsThunkBuilder: BoundsCheckedThunkBuilder {
|
|
var name: TokenSyntax { get }
|
|
var nullable: Bool { get }
|
|
var signature: FunctionSignatureSyntax { get }
|
|
var nonescaping: Bool { get }
|
|
}
|
|
|
|
struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
|
|
public let base: BoundsCheckedThunkBuilder
|
|
public let index: Int
|
|
public let countExpr: ExprSyntax
|
|
public let name: TokenSyntax
|
|
public let nullable: Bool
|
|
public let signature: FunctionSignatureSyntax
|
|
public let nonescaping: Bool
|
|
public let isSizedBy: Bool
|
|
|
|
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ variant: Variant) throws
|
|
-> FunctionSignatureSyntax
|
|
{
|
|
var types = argTypes
|
|
let param = getParam(signature, index)
|
|
types[index] = try transformType(param.type, variant, isSizedBy)
|
|
if variant.skipTrivialCount {
|
|
if let countVar = countExpr.as(DeclReferenceExprSyntax.self) {
|
|
let i = try getParameterIndexForDeclRef(signature.parameterClause.parameters, countVar)
|
|
types[i] = nil as TypeSyntax?
|
|
}
|
|
}
|
|
return try base.buildFunctionSignature(types, variant)
|
|
}
|
|
|
|
func buildBoundsChecks(_ variant: Variant) throws -> [CodeBlockItemSyntax.Item] {
|
|
var res = try base.buildBoundsChecks(variant)
|
|
let countName: TokenSyntax = "_\(raw: name)Count"
|
|
let count: VariableDeclSyntax = try VariableDeclSyntax(
|
|
"let \(countName): some BinaryInteger = \(countExpr)")
|
|
res.append(CodeBlockItemSyntax.Item(count))
|
|
|
|
let countCheck = ExprSyntax(
|
|
"""
|
|
if \(getCount(variant)) < \(countName) || \(countName) < 0 {
|
|
fatalError("bounds check failure when calling unsafe function")
|
|
}
|
|
""")
|
|
res.append(CodeBlockItemSyntax.Item(countCheck))
|
|
return res
|
|
}
|
|
|
|
func unwrapIfNullable(_ expr: ExprSyntax) -> ExprSyntax {
|
|
if nullable {
|
|
return ExprSyntax(ForceUnwrapExprSyntax(expression: expr))
|
|
}
|
|
return expr
|
|
}
|
|
|
|
func unwrapIfNonnullable(_ expr: ExprSyntax) -> ExprSyntax {
|
|
if !nullable {
|
|
return ExprSyntax(ForceUnwrapExprSyntax(expression: expr))
|
|
}
|
|
return expr
|
|
}
|
|
|
|
func castIntToTargetType(expr: ExprSyntax, type: TypeSyntax) -> ExprSyntax {
|
|
if type.canRepresentBasicType(type: Int.self) {
|
|
return expr
|
|
}
|
|
return ExprSyntax("\(type)(exactly: \(expr))!")
|
|
}
|
|
|
|
func buildUnwrapCall(_ argOverrides: [Int: ExprSyntax], _ variant: Variant) throws -> ExprSyntax {
|
|
let unwrappedName = TokenSyntax("_\(name)Ptr")
|
|
var args = argOverrides
|
|
let argExpr = ExprSyntax("\(unwrappedName).baseAddress")
|
|
assert(args[index] == nil)
|
|
args[index] = try castPointerToOpaquePointer(unwrapIfNonnullable(argExpr))
|
|
let call = try base.buildFunctionCall(args, variant)
|
|
let ptrRef = unwrapIfNullable(ExprSyntax(DeclReferenceExprSyntax(baseName: name)))
|
|
|
|
let funcName = isSizedBy ? "withUnsafeBytes" : "withUnsafeBufferPointer"
|
|
let unwrappedCall = ExprSyntax(
|
|
"""
|
|
\(ptrRef).\(raw: funcName) { \(unwrappedName) in
|
|
return \(call)
|
|
}
|
|
""")
|
|
return unwrappedCall
|
|
}
|
|
|
|
func getCount(_ variant: Variant) -> ExprSyntax {
|
|
let countName = isSizedBy && variant.generateSpan ? "byteCount" : "count"
|
|
if nullable {
|
|
return ExprSyntax("\(name)?.\(raw: countName) ?? 0")
|
|
}
|
|
return ExprSyntax("\(name).\(raw: countName)")
|
|
}
|
|
|
|
func peelOptionalType(_ type: TypeSyntax) -> TypeSyntax {
|
|
if let optType = type.as(OptionalTypeSyntax.self) {
|
|
return optType.wrappedType
|
|
}
|
|
if let impOptType = type.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) {
|
|
return impOptType.wrappedType
|
|
}
|
|
return type
|
|
}
|
|
|
|
func castPointerToOpaquePointer(_ baseAddress: ExprSyntax) throws -> ExprSyntax {
|
|
let i = try getParameterIndexForParamName(signature.parameterClause.parameters, name)
|
|
let type = peelOptionalType(getParam(signature, i).type)
|
|
if type.canRepresentBasicType(type: OpaquePointer.self) {
|
|
return ExprSyntax("OpaquePointer(\(baseAddress))")
|
|
}
|
|
return baseAddress
|
|
}
|
|
|
|
func getPointerArg() throws -> ExprSyntax {
|
|
if nullable {
|
|
return ExprSyntax("\(name)?.baseAddress")
|
|
}
|
|
return ExprSyntax("\(name).baseAddress!")
|
|
}
|
|
|
|
func buildFunctionCall(_ argOverrides: [Int: ExprSyntax], _ variant: Variant) throws -> ExprSyntax
|
|
{
|
|
var args = argOverrides
|
|
if variant.skipTrivialCount {
|
|
assert(
|
|
countExpr.is(DeclReferenceExprSyntax.self) || countExpr.is(IntegerLiteralExprSyntax.self))
|
|
if let countVar = countExpr.as(DeclReferenceExprSyntax.self) {
|
|
let i = try getParameterIndexForDeclRef(signature.parameterClause.parameters, countVar)
|
|
assert(args[i] == nil)
|
|
args[i] = castIntToTargetType(expr: getCount(variant), type: getParam(signature, i).type)
|
|
}
|
|
}
|
|
assert(args[index] == nil)
|
|
if variant.generateSpan {
|
|
assert(nonescaping)
|
|
let unwrappedCall = try buildUnwrapCall(args, variant)
|
|
if nullable {
|
|
var nullArgs = args
|
|
nullArgs[index] = ExprSyntax(NilLiteralExprSyntax(nilKeyword: .keyword(.nil)))
|
|
return ExprSyntax(
|
|
"""
|
|
if \(name) == nil {
|
|
\(try base.buildFunctionCall(nullArgs, variant))
|
|
} else {
|
|
\(unwrappedCall)
|
|
}
|
|
""")
|
|
}
|
|
return unwrappedCall
|
|
}
|
|
|
|
args[index] = try castPointerToOpaquePointer(getPointerArg())
|
|
return try base.buildFunctionCall(args, variant)
|
|
}
|
|
}
|
|
|
|
struct EndedByPointerThunkBuilder: PointerBoundsThunkBuilder {
|
|
public let base: BoundsCheckedThunkBuilder
|
|
public let startIndex: Int
|
|
public let endIndex: Int
|
|
public let name: TokenSyntax
|
|
public let nullable: Bool
|
|
public let signature: FunctionSignatureSyntax
|
|
public let nonescaping: Bool
|
|
|
|
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ variant: Variant) throws
|
|
-> FunctionSignatureSyntax
|
|
{
|
|
throw RuntimeError("endedBy support not yet implemented")
|
|
}
|
|
|
|
func buildBoundsChecks(_ variant: Variant) throws -> [CodeBlockItemSyntax.Item] {
|
|
throw RuntimeError("endedBy support not yet implemented")
|
|
}
|
|
|
|
func buildFunctionCall(_ argOverrides: [Int: ExprSyntax], _ variant: Variant) throws -> ExprSyntax
|
|
{
|
|
throw RuntimeError("endedBy support not yet implemented")
|
|
}
|
|
}
|
|
|
|
func getArgumentByName(_ argumentList: LabeledExprListSyntax, _ name: String) throws -> ExprSyntax {
|
|
guard
|
|
let arg = argumentList.first(where: {
|
|
return $0.label?.text == name
|
|
})
|
|
else {
|
|
throw DiagnosticError(
|
|
"no argument with name '\(name)' in '\(argumentList)'", node: argumentList)
|
|
}
|
|
return arg.expression
|
|
}
|
|
|
|
func getOptionalArgumentByName(_ argumentList: LabeledExprListSyntax, _ name: String) -> ExprSyntax?
|
|
{
|
|
return argumentList.first(where: {
|
|
$0.label?.text == name
|
|
})?.expression
|
|
}
|
|
|
|
func getParameterIndexForParamName(
|
|
_ parameterList: FunctionParameterListSyntax, _ tok: TokenSyntax
|
|
) throws -> Int {
|
|
let name = tok.text
|
|
guard
|
|
let index = parameterList.enumerated().first(where: {
|
|
(_: Int, param: FunctionParameterSyntax) in
|
|
let paramenterName = param.secondName ?? param.firstName
|
|
return paramenterName.trimmed.text == name
|
|
})?.offset
|
|
else {
|
|
throw DiagnosticError("no parameter with name '\(name)' in '\(parameterList)'", node: tok)
|
|
}
|
|
return index
|
|
}
|
|
|
|
func getParameterIndexForDeclRef(
|
|
_ parameterList: FunctionParameterListSyntax, _ ref: DeclReferenceExprSyntax
|
|
) throws -> Int {
|
|
return try getParameterIndexForParamName((parameterList), ref.baseName)
|
|
}
|
|
|
|
/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
|
|
/// Depends on bounds, escapability and lifetime information for each pointer.
|
|
/// Intended to map to C attributes like __counted_by, __ended_by and __no_escape,
|
|
/// for automatic application by ClangImporter when the C declaration is annotated
|
|
/// appropriately. Moreover, it can wrap C++ APIs using unsafe C++ types like
|
|
/// std::span with APIs that use their safer Swift equivalents.
|
|
public struct SwiftifyImportMacro: PeerMacro {
|
|
static func parseEnumName(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> String {
|
|
guard let calledExpr = enumConstructorExpr.calledExpression.as(MemberAccessExprSyntax.self)
|
|
else {
|
|
throw DiagnosticError(
|
|
"expected _SwiftifyInfo enum literal as argument, got '\(enumConstructorExpr)'",
|
|
node: enumConstructorExpr)
|
|
}
|
|
return calledExpr.declName.baseName.text
|
|
}
|
|
|
|
static func getIntLiteralValue(_ expr: ExprSyntax) throws -> Int {
|
|
guard let intLiteral = expr.as(IntegerLiteralExprSyntax.self) else {
|
|
throw DiagnosticError("expected integer literal, got '\(expr)'", node: expr)
|
|
}
|
|
guard let res = intLiteral.representedLiteralValue else {
|
|
throw DiagnosticError("expected integer literal, got '\(expr)'", node: expr)
|
|
}
|
|
return res
|
|
}
|
|
|
|
static func getBoolLiteralValue(_ expr: ExprSyntax) throws -> Bool {
|
|
guard let boolLiteral = expr.as(BooleanLiteralExprSyntax.self) else {
|
|
throw DiagnosticError("expected boolean literal, got '\(expr)'", node: expr)
|
|
}
|
|
switch boolLiteral.literal.tokenKind {
|
|
case .keyword(.true):
|
|
return true
|
|
case .keyword(.false):
|
|
return false
|
|
default:
|
|
throw DiagnosticError("expected bool literal, got '\(expr)'", node: expr)
|
|
}
|
|
}
|
|
|
|
static func parseCountedByEnum(
|
|
_ enumConstructorExpr: FunctionCallExprSyntax, _ signature: FunctionSignatureSyntax
|
|
) throws -> ParamInfo {
|
|
let argumentList = enumConstructorExpr.arguments
|
|
let pointerParamIndexArg = try getArgumentByName(argumentList, "pointer")
|
|
let pointerParamIndex: Int = try getIntLiteralValue(pointerParamIndexArg)
|
|
let countExprArg = try getArgumentByName(argumentList, "count")
|
|
guard let countExprStringLit = countExprArg.as(StringLiteralExprSyntax.self) else {
|
|
throw DiagnosticError(
|
|
"expected string literal for 'count' parameter, got \(countExprArg)", node: countExprArg)
|
|
}
|
|
let unwrappedCountExpr = ExprSyntax(stringLiteral: countExprStringLit.representedLiteralValue!)
|
|
if let countVar = unwrappedCountExpr.as(DeclReferenceExprSyntax.self) {
|
|
// Perform this lookup here so we can override the position to point to the string literal
|
|
// instead of line 1, column 1
|
|
do {
|
|
_ = try getParameterIndexForDeclRef(signature.parameterClause.parameters, countVar)
|
|
} catch let error as DiagnosticError {
|
|
throw DiagnosticError(error.description, node: countExprStringLit, notes: error.notes)
|
|
}
|
|
}
|
|
return CountedBy(
|
|
pointerIndex: pointerParamIndex, count: unwrappedCountExpr, sizedBy: false,
|
|
nonescaping: false, original: ExprSyntax(enumConstructorExpr))
|
|
}
|
|
|
|
static func parseSizedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo {
|
|
let argumentList = enumConstructorExpr.arguments
|
|
let pointerParamIndexArg = try getArgumentByName(argumentList, "pointer")
|
|
let pointerParamIndex: Int = try getIntLiteralValue(pointerParamIndexArg)
|
|
let sizeExprArg = try getArgumentByName(argumentList, "size")
|
|
guard let sizeExprStringLit = sizeExprArg.as(StringLiteralExprSyntax.self) else {
|
|
throw DiagnosticError(
|
|
"expected string literal for 'size' parameter, got \(sizeExprArg)", node: sizeExprArg)
|
|
}
|
|
let unwrappedCountExpr = ExprSyntax(stringLiteral: sizeExprStringLit.representedLiteralValue!)
|
|
return CountedBy(
|
|
pointerIndex: pointerParamIndex, count: unwrappedCountExpr, sizedBy: true, nonescaping: false,
|
|
original: ExprSyntax(enumConstructorExpr))
|
|
}
|
|
|
|
static func parseEndedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo {
|
|
let argumentList = enumConstructorExpr.arguments
|
|
let startParamIndexArg = try getArgumentByName(argumentList, "start")
|
|
let startParamIndex: Int = try getIntLiteralValue(startParamIndexArg)
|
|
let endParamIndexArg = try getArgumentByName(argumentList, "end")
|
|
let endParamIndex: Int = try getIntLiteralValue(endParamIndexArg)
|
|
let nonescapingExprArg = getOptionalArgumentByName(argumentList, "nonescaping")
|
|
let nonescaping = try nonescapingExprArg != nil && getBoolLiteralValue(nonescapingExprArg!)
|
|
return EndedBy(
|
|
pointerIndex: startParamIndex, endIndex: endParamIndex, nonescaping: nonescaping,
|
|
original: ExprSyntax(enumConstructorExpr))
|
|
}
|
|
|
|
static func parseNonEscaping(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> Int {
|
|
let argumentList = enumConstructorExpr.arguments
|
|
let pointerParamIndexArg = try getArgumentByName(argumentList, "pointer")
|
|
let pointerParamIndex: Int = try getIntLiteralValue(pointerParamIndexArg)
|
|
return pointerParamIndex
|
|
}
|
|
|
|
static func parseTypeMappingParam(_ paramAST: LabeledExprSyntax?) throws -> [String: String]? {
|
|
guard let unwrappedParamAST = paramAST else {
|
|
return nil
|
|
}
|
|
let paramExpr = unwrappedParamAST.expression
|
|
guard let dictExpr = paramExpr.as(DictionaryExprSyntax.self) else {
|
|
return nil
|
|
}
|
|
var dict : [String: String] = [:]
|
|
switch dictExpr.content {
|
|
case .colon(_):
|
|
return dict
|
|
case .elements(let types):
|
|
for element in types {
|
|
guard let key = element.key.as(StringLiteralExprSyntax.self) else {
|
|
throw DiagnosticError("expected a string literal, got '\(element.key)'", node: element.key)
|
|
}
|
|
guard let value = element.value.as(StringLiteralExprSyntax.self) else {
|
|
throw DiagnosticError("expected a string literal, got '\(element.value)'", node: element.value)
|
|
}
|
|
dict[key.representedLiteralValue!] = value.representedLiteralValue!
|
|
}
|
|
default:
|
|
throw DiagnosticError("unknown dictionary literal", node: dictExpr)
|
|
}
|
|
return dict
|
|
}
|
|
|
|
static func parseCxxSpanParams(
|
|
_ signature: FunctionSignatureSyntax,
|
|
_ typeMappings: [String: String]?
|
|
) throws -> [ParamInfo] {
|
|
guard let typeMappings else {
|
|
return []
|
|
}
|
|
var result : [ParamInfo] = []
|
|
for (idx, param) in signature.parameterClause.parameters.enumerated() {
|
|
let typeName = try getTypeName(param.type).text;
|
|
if let desugaredType = typeMappings[typeName] {
|
|
if desugaredType.starts(with: "span") {
|
|
result.append(CxxSpan(pointerIndex: idx + 1, nonescaping: false,
|
|
original: param, typeMappings: typeMappings))
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
static func parseMacroParam(
|
|
_ paramAST: LabeledExprSyntax, _ signature: FunctionSignatureSyntax,
|
|
nonescapingPointers: inout Set<Int>
|
|
) throws -> ParamInfo? {
|
|
let paramExpr = paramAST.expression
|
|
guard let enumConstructorExpr = paramExpr.as(FunctionCallExprSyntax.self) else {
|
|
throw DiagnosticError(
|
|
"expected _SwiftifyInfo enum literal as argument, got '\(paramExpr)'", node: paramExpr)
|
|
}
|
|
let enumName = try parseEnumName(enumConstructorExpr)
|
|
switch enumName {
|
|
case "countedBy": return try parseCountedByEnum(enumConstructorExpr, signature)
|
|
case "sizedBy": return try parseSizedByEnum(enumConstructorExpr)
|
|
case "endedBy": return try parseEndedByEnum(enumConstructorExpr)
|
|
case "nonescaping":
|
|
let index = try parseNonEscaping(enumConstructorExpr)
|
|
nonescapingPointers.insert(index)
|
|
return nil
|
|
default:
|
|
throw DiagnosticError(
|
|
"expected 'countedBy', 'sizedBy', 'endedBy' or 'nonescaping', got '\(enumName)'",
|
|
node: enumConstructorExpr)
|
|
}
|
|
}
|
|
|
|
static func hasSafeVariants(_ parsedArgs: [ParamInfo]) -> Bool {
|
|
return parsedArgs.contains { $0.nonescaping }
|
|
}
|
|
|
|
static func hasTrivialCountVariants(_ parsedArgs: [ParamInfo]) -> Bool {
|
|
let countExprs = parsedArgs.compactMap {
|
|
switch $0 {
|
|
case let c as CountedBy: return c.count
|
|
default: return nil
|
|
}
|
|
}
|
|
let trivialCounts = countExprs.filter {
|
|
$0.is(DeclReferenceExprSyntax.self) || $0.is(IntegerLiteralExprSyntax.self)
|
|
}
|
|
// don't generate trivial count variants if there are any non-trivial counts
|
|
if trivialCounts.count < countExprs.count {
|
|
return false
|
|
}
|
|
let countVars = trivialCounts.filter { $0.is(DeclReferenceExprSyntax.self) }
|
|
let distinctCountVars = Set(
|
|
countVars.map {
|
|
return $0.as(DeclReferenceExprSyntax.self)!.baseName.text
|
|
})
|
|
// don't generate trivial count variants if two count expressions refer to the same parameter
|
|
return countVars.count == distinctCountVars.count
|
|
}
|
|
|
|
static func checkArgs(_ args: [ParamInfo], _ funcDecl: FunctionDeclSyntax) throws {
|
|
var argByIndex: [Int: ParamInfo] = [:]
|
|
let paramCount = funcDecl.signature.parameterClause.parameters.count
|
|
try args.forEach { pointerArg in
|
|
let i = pointerArg.pointerIndex
|
|
if i < 1 || i > paramCount {
|
|
let noteMessage =
|
|
paramCount > 0
|
|
? "function \(funcDecl.name) has parameter indices 1..\(paramCount)"
|
|
: "function \(funcDecl.name) has no parameters"
|
|
throw DiagnosticError(
|
|
"pointer index out of bounds", node: pointerArg.original,
|
|
notes: [
|
|
Note(node: Syntax(funcDecl.name), message: MacroExpansionNoteMessage(noteMessage))
|
|
])
|
|
}
|
|
if argByIndex[i] != nil {
|
|
throw DiagnosticError(
|
|
"multiple _SwiftifyInfos referring to parameter with index "
|
|
+ "\(i): \(pointerArg) and \(argByIndex[i]!)", node: pointerArg.original)
|
|
}
|
|
argByIndex[i] = pointerArg
|
|
}
|
|
}
|
|
|
|
static func setNonescapingPointers(_ args: inout [ParamInfo], _ nonescapingPointers: Set<Int>) {
|
|
for i in 0...args.count - 1 where nonescapingPointers.contains(args[i].pointerIndex) {
|
|
args[i].nonescaping = true
|
|
}
|
|
}
|
|
|
|
public static func expansion(
|
|
of node: AttributeSyntax,
|
|
providingPeersOf declaration: some DeclSyntaxProtocol,
|
|
in context: some MacroExpansionContext
|
|
) throws -> [DeclSyntax] {
|
|
do {
|
|
guard let funcDecl = declaration.as(FunctionDeclSyntax.self) else {
|
|
throw DiagnosticError("@_SwiftifyImport only works on functions", node: declaration)
|
|
}
|
|
|
|
let argumentList = node.arguments!.as(LabeledExprListSyntax.self)!
|
|
var arguments = Array<LabeledExprSyntax>(argumentList)
|
|
let typeMappings = try parseTypeMappingParam(arguments.last)
|
|
if typeMappings != nil {
|
|
arguments = arguments.dropLast()
|
|
}
|
|
var nonescapingPointers = Set<Int>()
|
|
var parsedArgs = try arguments.compactMap {
|
|
try parseMacroParam($0, funcDecl.signature, nonescapingPointers: &nonescapingPointers)
|
|
}
|
|
parsedArgs.append(contentsOf: try parseCxxSpanParams(funcDecl.signature, typeMappings))
|
|
setNonescapingPointers(&parsedArgs, nonescapingPointers)
|
|
parsedArgs = parsedArgs.filter {
|
|
!($0 is CxxSpan) || ($0 as! CxxSpan).nonescaping
|
|
}
|
|
try checkArgs(parsedArgs, funcDecl)
|
|
let baseBuilder = FunctionCallBuilder(funcDecl)
|
|
|
|
let variant = Variant(
|
|
generateSpan: hasSafeVariants(parsedArgs),
|
|
skipTrivialCount: hasTrivialCountVariants(parsedArgs))
|
|
|
|
let builder: BoundsCheckedThunkBuilder = parsedArgs.reduce(
|
|
baseBuilder,
|
|
{ (prev, parsedArg) in
|
|
parsedArg.getBoundsCheckedThunkBuilder(prev, funcDecl, variant)
|
|
})
|
|
let newSignature = try builder.buildFunctionSignature([:], variant)
|
|
let checks =
|
|
variant.skipTrivialCount
|
|
? [] as [CodeBlockItemSyntax]
|
|
: try builder.buildBoundsChecks(variant).map { e in
|
|
CodeBlockItemSyntax(leadingTrivia: "\n", item: e)
|
|
}
|
|
let call = CodeBlockItemSyntax(
|
|
item: CodeBlockItemSyntax.Item(
|
|
ReturnStmtSyntax(
|
|
returnKeyword: .keyword(.return, trailingTrivia: " "),
|
|
expression: try builder.buildFunctionCall([:], variant))))
|
|
let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call]))
|
|
let newFunc =
|
|
funcDecl
|
|
.with(\.signature, newSignature)
|
|
.with(\.body, body)
|
|
.with(
|
|
\.attributes,
|
|
funcDecl.attributes.filter { e in
|
|
switch e {
|
|
case .attribute(let attr):
|
|
// don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient
|
|
let name = attr.attributeName.as(IdentifierTypeSyntax.self)?.name.text
|
|
return name == nil || (name != "_SwiftifyImport" && name != "_alwaysEmitIntoClient")
|
|
default: return true
|
|
}
|
|
} + [
|
|
.attribute(
|
|
AttributeSyntax(
|
|
atSign: .atSignToken(),
|
|
attributeName: IdentifierTypeSyntax(name: "_alwaysEmitIntoClient")))
|
|
])
|
|
return [DeclSyntax(newFunc)]
|
|
} catch let error as DiagnosticError {
|
|
context.diagnose(
|
|
Diagnostic(
|
|
node: error.node, message: MacroExpansionErrorMessage(error.description),
|
|
notes: error.notes))
|
|
return []
|
|
}
|
|
}
|
|
}
|
|
|
|
// MARK: syntax utils
|
|
extension TypeSyntaxProtocol {
|
|
public var isSwiftCoreModule: Bool {
|
|
guard let identifierType = self.as(IdentifierTypeSyntax.self) else {
|
|
return false
|
|
}
|
|
return identifierType.name.text == "Swift"
|
|
}
|
|
|
|
/// Check if this syntax could resolve to the type passed. Only supports types where the canonical type
|
|
/// can be named using only IdentifierTypeSyntax and MemberTypeSyntax. A non-exhaustive list of unsupported
|
|
/// types includes:
|
|
/// * array types
|
|
/// * function types
|
|
/// * optional types
|
|
/// * tuple types (including Void!)
|
|
/// The type syntax is allowed to use any level of qualified name for the type, e.g. Swift.Int.self
|
|
/// will match against both "Swift.Int" and "Int".
|
|
///
|
|
/// - Parameter type: Type to check against. NB: if passing a type alias, the canonical type will be used.
|
|
/// - Returns: true if `self` spells out some suffix of the fully qualified name of `type`, otherwise false
|
|
public func canRepresentBasicType(type: Any.Type) -> Bool {
|
|
let qualifiedTypeName = String(reflecting: type)
|
|
var typeNames = qualifiedTypeName.split(separator: ".")
|
|
var currType: TypeSyntaxProtocol = self
|
|
|
|
while !typeNames.isEmpty {
|
|
let typeName = typeNames.popLast()!
|
|
if let identifierType = currType.as(IdentifierTypeSyntax.self) {
|
|
// It doesn't matter whether this is the final element of typeNames, because we don't know
|
|
// surrounding context - the Foo.Bar.Baz type can be referred to as `Baz` inside Foo.Bar
|
|
return identifierType.name.text == typeName
|
|
} else if let memberType = currType.as(MemberTypeSyntax.self) {
|
|
if memberType.name.text != typeName {
|
|
return false
|
|
}
|
|
currType = memberType.baseType
|
|
} else {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
}
|