mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
[Swiftify] Add return pointer support (#78571)
* Import __counted_by for function return values Instead of simply passing a parameter index to _SwiftifyInfo, the _SwiftifyExpr enum is introduced. It currently has two cases: - .param(index: Int), corresponding to the previous parameter index - .return, corresponding to the function's return value. ClangImporter is also updated to pass this new information along to _SwiftifyImport, allowing overloads with buffer pointer return types to be generated. The swiftified return values currently return Span when the return value is marked as nonescaping, despite this not being sound. This is a bug that will be fixed in the next commit, as the issue is greater than just for return values. * Fix Span variant selection There was an assumption that all converted pointers were either converted to Span-family pointers, or UnsafeBufferPointer-family pointers. This was not consistently handled, resulting in violating the `assert(nonescaping)` assert when the two were mixed. This patch removes the Variant struct, and instead each swiftified pointer separately tracks whether it should map to Span or UnsafeBufferPointer. This also fixes return pointers being incorrectly mapped to Span when marked as nonescaping.
This commit is contained in:
@@ -4,20 +4,56 @@ import SwiftSyntax
|
||||
import SwiftSyntaxBuilder
|
||||
import SwiftSyntaxMacros
|
||||
|
||||
// avoids depending on SwiftifyImport.swift
|
||||
// all instances are reparsed and reinstantiated by the macro anyways,
|
||||
// so linking is irrelevant
|
||||
enum SwiftifyExpr {
|
||||
case param(_ index: Int)
|
||||
case `return`
|
||||
}
|
||||
|
||||
extension SwiftifyExpr: CustomStringConvertible {
|
||||
var description: String {
|
||||
switch self {
|
||||
case .param(let index): return ".param(\(index))"
|
||||
case .return: return ".return"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protocol ParamInfo: CustomStringConvertible {
|
||||
var description: String { get }
|
||||
var original: SyntaxProtocol { get }
|
||||
var pointerIndex: Int { get }
|
||||
var pointerIndex: SwiftifyExpr { get }
|
||||
var nonescaping: Bool { get set }
|
||||
|
||||
func getBoundsCheckedThunkBuilder(
|
||||
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax,
|
||||
_ variant: Variant
|
||||
_ skipTrivialCount: Bool
|
||||
) -> BoundsCheckedThunkBuilder
|
||||
}
|
||||
|
||||
func tryGetParamName(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> TokenSyntax? {
|
||||
switch expr {
|
||||
case .param(let i):
|
||||
let funcParam = getParam(funcDecl, i - 1)
|
||||
return funcParam.secondName ?? funcParam.firstName
|
||||
default: return nil
|
||||
}
|
||||
}
|
||||
|
||||
func getSwiftifyExprType(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> TypeSyntax {
|
||||
switch expr {
|
||||
case .param(let i):
|
||||
let funcParam = getParam(funcDecl, i - 1)
|
||||
return funcParam.type
|
||||
case .return:
|
||||
return funcDecl.signature.returnClause!.type
|
||||
}
|
||||
}
|
||||
|
||||
struct CxxSpan: ParamInfo {
|
||||
var pointerIndex: Int
|
||||
var pointerIndex: SwiftifyExpr
|
||||
var nonescaping: Bool
|
||||
var original: SyntaxProtocol
|
||||
var typeMappings: [String: String]
|
||||
@@ -28,15 +64,22 @@ struct CxxSpan: ParamInfo {
|
||||
|
||||
func getBoundsCheckedThunkBuilder(
|
||||
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax,
|
||||
_ variant: Variant
|
||||
_ skipTrivialCount: Bool
|
||||
) -> BoundsCheckedThunkBuilder {
|
||||
CxxSpanThunkBuilder(base: base, index: pointerIndex - 1, signature: funcDecl.signature,
|
||||
typeMappings: typeMappings, node: original)
|
||||
switch pointerIndex {
|
||||
case .param(let i):
|
||||
return CxxSpanThunkBuilder(base: base, index: i - 1, signature: funcDecl.signature,
|
||||
typeMappings: typeMappings, node: original, nonescaping: nonescaping)
|
||||
case .return:
|
||||
// TODO: actually implement std::span in return position
|
||||
return CxxSpanThunkBuilder(base: base, index: -1, signature: funcDecl.signature,
|
||||
typeMappings: typeMappings, node: original, nonescaping: nonescaping)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct CountedBy: ParamInfo {
|
||||
var pointerIndex: Int
|
||||
var pointerIndex: SwiftifyExpr
|
||||
var count: ExprSyntax
|
||||
var sizedBy: Bool
|
||||
var nonescaping: Bool
|
||||
@@ -51,39 +94,20 @@ struct CountedBy: ParamInfo {
|
||||
|
||||
func getBoundsCheckedThunkBuilder(
|
||||
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax,
|
||||
_ variant: Variant
|
||||
_ skipTrivialCount: Bool
|
||||
) -> BoundsCheckedThunkBuilder {
|
||||
let funcParam = getParam(funcDecl, pointerIndex - 1)
|
||||
let paramName = funcParam.secondName ?? funcParam.firstName
|
||||
let isNullable = funcParam.type.is(OptionalTypeSyntax.self)
|
||||
return CountedOrSizedPointerThunkBuilder(
|
||||
base: base, index: pointerIndex - 1, countExpr: count,
|
||||
name: paramName, nullable: isNullable, signature: funcDecl.signature,
|
||||
nonescaping: nonescaping, isSizedBy: sizedBy)
|
||||
}
|
||||
}
|
||||
|
||||
struct EndedBy: ParamInfo {
|
||||
var pointerIndex: Int
|
||||
var endIndex: Int
|
||||
var nonescaping: Bool
|
||||
var original: SyntaxProtocol
|
||||
|
||||
var description: String {
|
||||
return ".endedBy(start: \(pointerIndex), end: \(endIndex), nonescaping: \(nonescaping))"
|
||||
}
|
||||
|
||||
func getBoundsCheckedThunkBuilder(
|
||||
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax,
|
||||
_ variant: Variant
|
||||
) -> BoundsCheckedThunkBuilder {
|
||||
let funcParam = getParam(funcDecl, pointerIndex - 1)
|
||||
let paramName = funcParam.secondName ?? funcParam.firstName
|
||||
let isNullable = funcParam.type.is(OptionalTypeSyntax.self)
|
||||
return EndedByPointerThunkBuilder(
|
||||
base: base, startIndex: pointerIndex - 1, endIndex: endIndex - 1,
|
||||
name: paramName, nullable: isNullable, signature: funcDecl.signature, nonescaping: nonescaping
|
||||
)
|
||||
switch pointerIndex {
|
||||
case .param(let i):
|
||||
return CountedOrSizedPointerThunkBuilder(
|
||||
base: base, index: i-1, countExpr: count,
|
||||
signature: funcDecl.signature,
|
||||
nonescaping: nonescaping, isSizedBy: sizedBy, skipTrivialCount: skipTrivialCount)
|
||||
case .return:
|
||||
return CountedOrSizedReturnPointerThunkBuilder(
|
||||
base: base, countExpr: count,
|
||||
signature: funcDecl.signature,
|
||||
nonescaping: nonescaping, isSizedBy: sizedBy)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,13 +217,13 @@ func getSafePointerName(mut: Mutability, generateSpan: Bool, isRaw: Bool) -> Tok
|
||||
}
|
||||
}
|
||||
|
||||
func transformType(_ prev: TypeSyntax, _ variant: Variant, _ isSizedBy: Bool) throws -> TypeSyntax {
|
||||
func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool) throws -> TypeSyntax {
|
||||
if let optType = prev.as(OptionalTypeSyntax.self) {
|
||||
return TypeSyntax(
|
||||
optType.with(\.wrappedType, try transformType(optType.wrappedType, variant, isSizedBy)))
|
||||
optType.with(\.wrappedType, try transformType(optType.wrappedType, generateSpan, isSizedBy)))
|
||||
}
|
||||
if let impOptType = prev.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) {
|
||||
return try transformType(impOptType.wrappedType, variant, isSizedBy)
|
||||
return try transformType(impOptType.wrappedType, generateSpan, isSizedBy)
|
||||
}
|
||||
let name = try getTypeName(prev)
|
||||
let text = name.text
|
||||
@@ -216,22 +240,17 @@ func transformType(_ prev: TypeSyntax, _ variant: Variant, _ isSizedBy: Bool) th
|
||||
"expected Unsafe[Mutable][Raw]Pointer type for type \(prev)"
|
||||
+ " - first type token is '\(text)'", node: name)
|
||||
}
|
||||
let token = getSafePointerName(mut: kind, generateSpan: variant.generateSpan, isRaw: isSizedBy)
|
||||
let token = getSafePointerName(mut: kind, generateSpan: generateSpan, isRaw: isSizedBy)
|
||||
if isSizedBy {
|
||||
return TypeSyntax(IdentifierTypeSyntax(name: token))
|
||||
}
|
||||
return replaceTypeName(prev, token)
|
||||
}
|
||||
|
||||
struct Variant {
|
||||
public let generateSpan: Bool
|
||||
public let skipTrivialCount: Bool
|
||||
}
|
||||
|
||||
protocol BoundsCheckedThunkBuilder {
|
||||
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax], _ variant: Variant) throws -> ExprSyntax
|
||||
func buildBoundsChecks(_ variant: Variant) throws -> [CodeBlockItemSyntax.Item]
|
||||
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ variant: Variant) throws
|
||||
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax
|
||||
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item]
|
||||
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
|
||||
-> FunctionSignatureSyntax
|
||||
}
|
||||
|
||||
@@ -254,13 +273,12 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
|
||||
base = function
|
||||
}
|
||||
|
||||
func buildBoundsChecks(_ variant: Variant) throws -> [CodeBlockItemSyntax.Item] {
|
||||
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
|
||||
return []
|
||||
}
|
||||
|
||||
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ variant: Variant) throws
|
||||
-> FunctionSignatureSyntax
|
||||
{
|
||||
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
|
||||
-> FunctionSignatureSyntax {
|
||||
var newParams = base.signature.parameterClause.parameters.enumerated().filter {
|
||||
let type = argTypes[$0.offset]
|
||||
// filter out deleted parameters, i.e. ones where argTypes[i] _contains_ nil
|
||||
@@ -271,10 +289,14 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
|
||||
let last = newParams.popLast()!
|
||||
newParams.append(last.with(\.trailingComma, nil))
|
||||
|
||||
return base.signature.with(\.parameterClause.parameters, FunctionParameterListSyntax(newParams))
|
||||
var sig = base.signature.with(\.parameterClause.parameters, FunctionParameterListSyntax(newParams))
|
||||
if returnType != nil {
|
||||
sig = sig.with(\.returnClause!.type, returnType!)
|
||||
}
|
||||
return sig
|
||||
}
|
||||
|
||||
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax], _: Variant) throws -> ExprSyntax {
|
||||
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax {
|
||||
let functionRef = DeclReferenceExprSyntax(baseName: base.name)
|
||||
let args: [ExprSyntax] = base.signature.parameterClause.parameters.enumerated()
|
||||
.map { (i: Int, param: FunctionParameterSyntax) in
|
||||
@@ -305,22 +327,23 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
struct CxxSpanThunkBuilder: BoundsCheckedThunkBuilder {
|
||||
struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
|
||||
public let base: BoundsCheckedThunkBuilder
|
||||
public let index: Int
|
||||
public let signature: FunctionSignatureSyntax
|
||||
public let typeMappings: [String: String]
|
||||
public let typeMappings: [String: String]
|
||||
public let node: SyntaxProtocol
|
||||
public let nonescaping: Bool
|
||||
let isSizedBy: Bool = false
|
||||
|
||||
func buildBoundsChecks(_ variant: Variant) throws -> [CodeBlockItemSyntax.Item] {
|
||||
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
|
||||
return []
|
||||
}
|
||||
|
||||
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ variant: Variant) throws
|
||||
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
|
||||
-> FunctionSignatureSyntax {
|
||||
var types = argTypes
|
||||
let param = getParam(signature, index)
|
||||
let typeName = try getTypeName(param.type).text;
|
||||
let typeName = try getTypeName(oldType).text
|
||||
guard let desugaredType = typeMappings[typeName] else {
|
||||
throw DiagnosticError(
|
||||
"unable to desugar type with name '\(typeName)'", node: node)
|
||||
@@ -330,53 +353,112 @@ struct CxxSpanThunkBuilder: BoundsCheckedThunkBuilder {
|
||||
let genericArg = TypeSyntax(parsedDesugaredType.as(IdentifierTypeSyntax.self)!
|
||||
.genericArgumentClause!.arguments.first!.argument)!
|
||||
types[index] = TypeSyntax("Span<\(raw: try getTypeName(genericArg).text)>")
|
||||
return try base.buildFunctionSignature(types, variant)
|
||||
return try base.buildFunctionSignature(types, returnType)
|
||||
}
|
||||
|
||||
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax], _ variant: Variant) throws -> ExprSyntax {
|
||||
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax {
|
||||
var args = pointerArgs
|
||||
let param = getParam(signature, index)
|
||||
let typeName = try getTypeName(param.type).text;
|
||||
let typeName = try getTypeName(oldType).text
|
||||
assert(args[index] == nil)
|
||||
args[index] = ExprSyntax("\(raw: typeName)(\(raw: param.secondName ?? param.firstName))")
|
||||
return try base.buildFunctionCall(args, variant)
|
||||
args[index] = ExprSyntax("\(raw: typeName)(\(raw: name))")
|
||||
return try base.buildFunctionCall(args)
|
||||
}
|
||||
}
|
||||
|
||||
protocol PointerBoundsThunkBuilder: BoundsCheckedThunkBuilder {
|
||||
var name: TokenSyntax { get }
|
||||
var oldType: TypeSyntax { get }
|
||||
var newType: TypeSyntax { get throws }
|
||||
var nullable: Bool { get }
|
||||
var signature: FunctionSignatureSyntax { get }
|
||||
var nonescaping: Bool { get }
|
||||
var isSizedBy: Bool { get }
|
||||
var generateSpan: Bool { get }
|
||||
}
|
||||
|
||||
struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
|
||||
extension PointerBoundsThunkBuilder {
|
||||
var nullable: Bool { return oldType.is(OptionalTypeSyntax.self) }
|
||||
|
||||
var newType: TypeSyntax { get throws {
|
||||
return try transformType(oldType, generateSpan, isSizedBy) }
|
||||
}
|
||||
}
|
||||
|
||||
protocol ParamPointerBoundsThunkBuilder: PointerBoundsThunkBuilder {
|
||||
var index: Int { get }
|
||||
}
|
||||
|
||||
extension ParamPointerBoundsThunkBuilder {
|
||||
var generateSpan: Bool { nonescaping }
|
||||
|
||||
var param: FunctionParameterSyntax {
|
||||
return getParam(signature, index)
|
||||
}
|
||||
|
||||
var oldType: TypeSyntax {
|
||||
return param.type
|
||||
}
|
||||
|
||||
var name: TokenSyntax {
|
||||
return param.secondName ?? param.firstName
|
||||
}
|
||||
}
|
||||
|
||||
struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
|
||||
public let base: BoundsCheckedThunkBuilder
|
||||
public let index: Int
|
||||
public let countExpr: ExprSyntax
|
||||
public let name: TokenSyntax
|
||||
public let nullable: Bool
|
||||
public let signature: FunctionSignatureSyntax
|
||||
public let nonescaping: Bool
|
||||
public let isSizedBy: Bool
|
||||
|
||||
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ variant: Variant) throws
|
||||
-> FunctionSignatureSyntax
|
||||
{
|
||||
var generateSpan: Bool = false // needs more lifetime information
|
||||
|
||||
var oldType: TypeSyntax {
|
||||
return signature.returnClause!.type
|
||||
}
|
||||
|
||||
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
|
||||
-> FunctionSignatureSyntax {
|
||||
assert(returnType == nil)
|
||||
return try base.buildFunctionSignature(argTypes, newType)
|
||||
}
|
||||
|
||||
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
|
||||
return []
|
||||
}
|
||||
|
||||
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax {
|
||||
let call = try base.buildFunctionCall(pointerArgs)
|
||||
return
|
||||
"""
|
||||
\(raw: try newType)(start: \(call), count: Int(\(countExpr)))
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
struct CountedOrSizedPointerThunkBuilder: ParamPointerBoundsThunkBuilder {
|
||||
public let base: BoundsCheckedThunkBuilder
|
||||
public let index: Int
|
||||
public let countExpr: ExprSyntax
|
||||
public let signature: FunctionSignatureSyntax
|
||||
public let nonescaping: Bool
|
||||
public let isSizedBy: Bool
|
||||
public let skipTrivialCount: Bool
|
||||
|
||||
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
|
||||
-> FunctionSignatureSyntax {
|
||||
var types = argTypes
|
||||
let param = getParam(signature, index)
|
||||
types[index] = try transformType(param.type, variant, isSizedBy)
|
||||
if variant.skipTrivialCount {
|
||||
types[index] = try newType
|
||||
if skipTrivialCount {
|
||||
if let countVar = countExpr.as(DeclReferenceExprSyntax.self) {
|
||||
let i = try getParameterIndexForDeclRef(signature.parameterClause.parameters, countVar)
|
||||
types[i] = nil as TypeSyntax?
|
||||
}
|
||||
}
|
||||
return try base.buildFunctionSignature(types, variant)
|
||||
return try base.buildFunctionSignature(types, returnType)
|
||||
}
|
||||
|
||||
func buildBoundsChecks(_ variant: Variant) throws -> [CodeBlockItemSyntax.Item] {
|
||||
var res = try base.buildBoundsChecks(variant)
|
||||
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
|
||||
var res = try base.buildBoundsChecks()
|
||||
let countName: TokenSyntax = "_\(raw: name)Count"
|
||||
let count: VariableDeclSyntax = try VariableDeclSyntax(
|
||||
"let \(countName): some BinaryInteger = \(countExpr)")
|
||||
@@ -384,7 +466,7 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
|
||||
|
||||
let countCheck = ExprSyntax(
|
||||
"""
|
||||
if \(getCount(variant)) < \(countName) || \(countName) < 0 {
|
||||
if \(getCount()) < \(countName) || \(countName) < 0 {
|
||||
fatalError("bounds check failure when calling unsafe function")
|
||||
}
|
||||
""")
|
||||
@@ -413,13 +495,13 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
|
||||
return ExprSyntax("\(type)(exactly: \(expr))!")
|
||||
}
|
||||
|
||||
func buildUnwrapCall(_ argOverrides: [Int: ExprSyntax], _ variant: Variant) throws -> ExprSyntax {
|
||||
func buildUnwrapCall(_ argOverrides: [Int: ExprSyntax]) throws -> ExprSyntax {
|
||||
let unwrappedName = TokenSyntax("_\(name)Ptr")
|
||||
var args = argOverrides
|
||||
let argExpr = ExprSyntax("\(unwrappedName).baseAddress")
|
||||
assert(args[index] == nil)
|
||||
args[index] = try castPointerToOpaquePointer(unwrapIfNonnullable(argExpr))
|
||||
let call = try base.buildFunctionCall(args, variant)
|
||||
let call = try base.buildFunctionCall(args)
|
||||
let ptrRef = unwrapIfNullable(ExprSyntax(DeclReferenceExprSyntax(baseName: name)))
|
||||
|
||||
let funcName = isSizedBy ? "withUnsafeBytes" : "withUnsafeBufferPointer"
|
||||
@@ -432,8 +514,8 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
|
||||
return unwrappedCall
|
||||
}
|
||||
|
||||
func getCount(_ variant: Variant) -> ExprSyntax {
|
||||
let countName = isSizedBy && variant.generateSpan ? "byteCount" : "count"
|
||||
func getCount() -> ExprSyntax {
|
||||
let countName = isSizedBy && generateSpan ? "byteCount" : "count"
|
||||
if nullable {
|
||||
return ExprSyntax("\(name)?.\(raw: countName) ?? 0")
|
||||
}
|
||||
@@ -466,29 +548,28 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
|
||||
return ExprSyntax("\(name).baseAddress!")
|
||||
}
|
||||
|
||||
func buildFunctionCall(_ argOverrides: [Int: ExprSyntax], _ variant: Variant) throws -> ExprSyntax
|
||||
{
|
||||
func buildFunctionCall(_ argOverrides: [Int: ExprSyntax]) throws -> ExprSyntax {
|
||||
var args = argOverrides
|
||||
if variant.skipTrivialCount {
|
||||
if skipTrivialCount {
|
||||
assert(
|
||||
countExpr.is(DeclReferenceExprSyntax.self) || countExpr.is(IntegerLiteralExprSyntax.self))
|
||||
if let countVar = countExpr.as(DeclReferenceExprSyntax.self) {
|
||||
let i = try getParameterIndexForDeclRef(signature.parameterClause.parameters, countVar)
|
||||
assert(args[i] == nil)
|
||||
args[i] = castIntToTargetType(expr: getCount(variant), type: getParam(signature, i).type)
|
||||
args[i] = castIntToTargetType(expr: getCount(), type: getParam(signature, i).type)
|
||||
}
|
||||
}
|
||||
assert(args[index] == nil)
|
||||
if variant.generateSpan {
|
||||
if generateSpan {
|
||||
assert(nonescaping)
|
||||
let unwrappedCall = try buildUnwrapCall(args, variant)
|
||||
let unwrappedCall = try buildUnwrapCall(args)
|
||||
if nullable {
|
||||
var nullArgs = args
|
||||
nullArgs[index] = ExprSyntax(NilLiteralExprSyntax(nilKeyword: .keyword(.nil)))
|
||||
return ExprSyntax(
|
||||
"""
|
||||
if \(name) == nil {
|
||||
\(try base.buildFunctionCall(nullArgs, variant))
|
||||
\(try base.buildFunctionCall(nullArgs))
|
||||
} else {
|
||||
\(unwrappedCall)
|
||||
}
|
||||
@@ -498,32 +579,7 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
|
||||
}
|
||||
|
||||
args[index] = try castPointerToOpaquePointer(getPointerArg())
|
||||
return try base.buildFunctionCall(args, variant)
|
||||
}
|
||||
}
|
||||
|
||||
struct EndedByPointerThunkBuilder: PointerBoundsThunkBuilder {
|
||||
public let base: BoundsCheckedThunkBuilder
|
||||
public let startIndex: Int
|
||||
public let endIndex: Int
|
||||
public let name: TokenSyntax
|
||||
public let nullable: Bool
|
||||
public let signature: FunctionSignatureSyntax
|
||||
public let nonescaping: Bool
|
||||
|
||||
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ variant: Variant) throws
|
||||
-> FunctionSignatureSyntax
|
||||
{
|
||||
throw RuntimeError("endedBy support not yet implemented")
|
||||
}
|
||||
|
||||
func buildBoundsChecks(_ variant: Variant) throws -> [CodeBlockItemSyntax.Item] {
|
||||
throw RuntimeError("endedBy support not yet implemented")
|
||||
}
|
||||
|
||||
func buildFunctionCall(_ argOverrides: [Int: ExprSyntax], _ variant: Variant) throws -> ExprSyntax
|
||||
{
|
||||
throw RuntimeError("endedBy support not yet implemented")
|
||||
return try base.buildFunctionCall(args)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -575,14 +631,28 @@ func getParameterIndexForDeclRef(
|
||||
/// appropriately. Moreover, it can wrap C++ APIs using unsafe C++ types like
|
||||
/// std::span with APIs that use their safer Swift equivalents.
|
||||
public struct SwiftifyImportMacro: PeerMacro {
|
||||
static func parseEnumName(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> String {
|
||||
guard let calledExpr = enumConstructorExpr.calledExpression.as(MemberAccessExprSyntax.self)
|
||||
static func parseEnumName(_ expr: ExprSyntax) throws -> String {
|
||||
var exprLocal = expr
|
||||
if let callExpr = expr.as(FunctionCallExprSyntax.self) {
|
||||
exprLocal = callExpr.calledExpression
|
||||
}
|
||||
guard let dotExpr = exprLocal.as(MemberAccessExprSyntax.self)
|
||||
else {
|
||||
throw DiagnosticError(
|
||||
"expected _SwiftifyInfo enum literal as argument, got '\(enumConstructorExpr)'",
|
||||
node: enumConstructorExpr)
|
||||
"expected enum literal as argument, got '\(expr)'",
|
||||
node: expr)
|
||||
}
|
||||
return calledExpr.declName.baseName.text
|
||||
return dotExpr.declName.baseName.text
|
||||
}
|
||||
|
||||
static func parseEnumArgs(_ expr: ExprSyntax) throws -> LabeledExprListSyntax {
|
||||
guard let callExpr = expr.as(FunctionCallExprSyntax.self)
|
||||
else {
|
||||
throw DiagnosticError(
|
||||
"expected call to enum constructor, got '\(expr)'",
|
||||
node: expr)
|
||||
}
|
||||
return callExpr.arguments
|
||||
}
|
||||
|
||||
static func getIntLiteralValue(_ expr: ExprSyntax) throws -> Int {
|
||||
@@ -609,12 +679,33 @@ public struct SwiftifyImportMacro: PeerMacro {
|
||||
}
|
||||
}
|
||||
|
||||
static func parseSwiftifyExpr(_ expr: ExprSyntax) throws -> SwiftifyExpr {
|
||||
let enumName = try parseEnumName(expr)
|
||||
switch enumName {
|
||||
case "param":
|
||||
let argumentList = try parseEnumArgs(expr)
|
||||
if argumentList.count != 1 {
|
||||
throw DiagnosticError(
|
||||
"expected single argument to _SwiftifyExpr.param, got \(argumentList.count) arguments",
|
||||
node: expr)
|
||||
}
|
||||
let pointerParamIndexArg = argumentList[argumentList.startIndex]
|
||||
let pointerParamIndex: Int = try getIntLiteralValue(pointerParamIndexArg.expression)
|
||||
return .param(pointerParamIndex)
|
||||
case "return": return .return
|
||||
default:
|
||||
throw DiagnosticError(
|
||||
"expected 'param' or 'return', got '\(enumName)'",
|
||||
node: expr)
|
||||
}
|
||||
}
|
||||
|
||||
static func parseCountedByEnum(
|
||||
_ enumConstructorExpr: FunctionCallExprSyntax, _ signature: FunctionSignatureSyntax
|
||||
) throws -> ParamInfo {
|
||||
let argumentList = enumConstructorExpr.arguments
|
||||
let pointerParamIndexArg = try getArgumentByName(argumentList, "pointer")
|
||||
let pointerParamIndex: Int = try getIntLiteralValue(pointerParamIndexArg)
|
||||
let pointerExprArg = try getArgumentByName(argumentList, "pointer")
|
||||
let pointerExpr: SwiftifyExpr = try parseSwiftifyExpr(pointerExprArg)
|
||||
let countExprArg = try getArgumentByName(argumentList, "count")
|
||||
guard let countExprStringLit = countExprArg.as(StringLiteralExprSyntax.self) else {
|
||||
throw DiagnosticError(
|
||||
@@ -631,14 +722,14 @@ public struct SwiftifyImportMacro: PeerMacro {
|
||||
}
|
||||
}
|
||||
return CountedBy(
|
||||
pointerIndex: pointerParamIndex, count: unwrappedCountExpr, sizedBy: false,
|
||||
pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: false,
|
||||
nonescaping: false, original: ExprSyntax(enumConstructorExpr))
|
||||
}
|
||||
|
||||
static func parseSizedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo {
|
||||
let argumentList = enumConstructorExpr.arguments
|
||||
let pointerParamIndexArg = try getArgumentByName(argumentList, "pointer")
|
||||
let pointerParamIndex: Int = try getIntLiteralValue(pointerParamIndexArg)
|
||||
let pointerExprArg = try getArgumentByName(argumentList, "pointer")
|
||||
let pointerExpr: SwiftifyExpr = try parseSwiftifyExpr(pointerExprArg)
|
||||
let sizeExprArg = try getArgumentByName(argumentList, "size")
|
||||
guard let sizeExprStringLit = sizeExprArg.as(StringLiteralExprSyntax.self) else {
|
||||
throw DiagnosticError(
|
||||
@@ -646,27 +737,24 @@ public struct SwiftifyImportMacro: PeerMacro {
|
||||
}
|
||||
let unwrappedCountExpr = ExprSyntax(stringLiteral: sizeExprStringLit.representedLiteralValue!)
|
||||
return CountedBy(
|
||||
pointerIndex: pointerParamIndex, count: unwrappedCountExpr, sizedBy: true, nonescaping: false,
|
||||
pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: true, nonescaping: false,
|
||||
original: ExprSyntax(enumConstructorExpr))
|
||||
}
|
||||
|
||||
static func parseEndedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo {
|
||||
let argumentList = enumConstructorExpr.arguments
|
||||
let startParamIndexArg = try getArgumentByName(argumentList, "start")
|
||||
let startParamIndex: Int = try getIntLiteralValue(startParamIndexArg)
|
||||
let endParamIndexArg = try getArgumentByName(argumentList, "end")
|
||||
let endParamIndex: Int = try getIntLiteralValue(endParamIndexArg)
|
||||
let nonescapingExprArg = getOptionalArgumentByName(argumentList, "nonescaping")
|
||||
let nonescaping = try nonescapingExprArg != nil && getBoolLiteralValue(nonescapingExprArg!)
|
||||
return EndedBy(
|
||||
pointerIndex: startParamIndex, endIndex: endParamIndex, nonescaping: nonescaping,
|
||||
original: ExprSyntax(enumConstructorExpr))
|
||||
let startPointerExprArg = try getArgumentByName(argumentList, "start")
|
||||
let _: SwiftifyExpr = try parseSwiftifyExpr(startPointerExprArg)
|
||||
let endPointerExprArg = try getArgumentByName(argumentList, "end")
|
||||
let _: SwiftifyExpr = try parseSwiftifyExpr(endPointerExprArg)
|
||||
throw RuntimeError("endedBy support not yet implemented")
|
||||
}
|
||||
|
||||
static func parseNonEscaping(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> Int {
|
||||
let argumentList = enumConstructorExpr.arguments
|
||||
let pointerParamIndexArg = try getArgumentByName(argumentList, "pointer")
|
||||
let pointerParamIndex: Int = try getIntLiteralValue(pointerParamIndexArg)
|
||||
let pointerExprArg = try getArgumentByName(argumentList, "pointer")
|
||||
let pointerExpr: SwiftifyExpr = try parseSwiftifyExpr(pointerExprArg)
|
||||
let pointerParamIndex: Int = paramOrReturnIndex(pointerExpr)
|
||||
return pointerParamIndex
|
||||
}
|
||||
|
||||
@@ -711,7 +799,7 @@ public struct SwiftifyImportMacro: PeerMacro {
|
||||
if let desugaredType = typeMappings[typeName] {
|
||||
if let unqualifiedDesugaredType = getUnqualifiedStdName(desugaredType) {
|
||||
if unqualifiedDesugaredType.starts(with: "span<") {
|
||||
result.append(CxxSpan(pointerIndex: idx + 1, nonescaping: false,
|
||||
result.append(CxxSpan(pointerIndex: .param(idx + 1), nonescaping: false,
|
||||
original: param, typeMappings: typeMappings))
|
||||
}
|
||||
}
|
||||
@@ -729,7 +817,7 @@ public struct SwiftifyImportMacro: PeerMacro {
|
||||
throw DiagnosticError(
|
||||
"expected _SwiftifyInfo enum literal as argument, got '\(paramExpr)'", node: paramExpr)
|
||||
}
|
||||
let enumName = try parseEnumName(enumConstructorExpr)
|
||||
let enumName = try parseEnumName(paramExpr)
|
||||
switch enumName {
|
||||
case "countedBy": return try parseCountedByEnum(enumConstructorExpr, signature)
|
||||
case "sizedBy": return try parseSizedByEnum(enumConstructorExpr)
|
||||
@@ -745,10 +833,6 @@ public struct SwiftifyImportMacro: PeerMacro {
|
||||
}
|
||||
}
|
||||
|
||||
static func hasSafeVariants(_ parsedArgs: [ParamInfo]) -> Bool {
|
||||
return parsedArgs.contains { $0.nonescaping }
|
||||
}
|
||||
|
||||
static func hasTrivialCountVariants(_ parsedArgs: [ParamInfo]) -> Bool {
|
||||
let countExprs = parsedArgs.compactMap {
|
||||
switch $0 {
|
||||
@@ -774,16 +858,18 @@ public struct SwiftifyImportMacro: PeerMacro {
|
||||
|
||||
static func checkArgs(_ args: [ParamInfo], _ funcDecl: FunctionDeclSyntax) throws {
|
||||
var argByIndex: [Int: ParamInfo] = [:]
|
||||
var ret: ParamInfo? = nil
|
||||
let paramCount = funcDecl.signature.parameterClause.parameters.count
|
||||
try args.forEach { pointerArg in
|
||||
let i = pointerArg.pointerIndex
|
||||
try args.forEach { pointerInfo in
|
||||
switch pointerInfo.pointerIndex {
|
||||
case .param(let i):
|
||||
if i < 1 || i > paramCount {
|
||||
let noteMessage =
|
||||
paramCount > 0
|
||||
? "function \(funcDecl.name) has parameter indices 1..\(paramCount)"
|
||||
: "function \(funcDecl.name) has no parameters"
|
||||
throw DiagnosticError(
|
||||
"pointer index out of bounds", node: pointerArg.original,
|
||||
"pointer index out of bounds", node: pointerInfo.original,
|
||||
notes: [
|
||||
Note(node: Syntax(funcDecl.name), message: MacroExpansionNoteMessage(noteMessage))
|
||||
])
|
||||
@@ -791,14 +877,28 @@ public struct SwiftifyImportMacro: PeerMacro {
|
||||
if argByIndex[i] != nil {
|
||||
throw DiagnosticError(
|
||||
"multiple _SwiftifyInfos referring to parameter with index "
|
||||
+ "\(i): \(pointerArg) and \(argByIndex[i]!)", node: pointerArg.original)
|
||||
+ "\(i): \(pointerInfo) and \(argByIndex[i]!)", node: pointerInfo.original)
|
||||
}
|
||||
argByIndex[i] = pointerArg
|
||||
argByIndex[i] = pointerInfo
|
||||
case .return:
|
||||
if ret != nil {
|
||||
throw DiagnosticError(
|
||||
"multiple _SwiftifyInfos referring to return value: \(pointerInfo) and \(ret!)", node: pointerInfo.original)
|
||||
}
|
||||
ret = pointerInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static func paramOrReturnIndex(_ expr: SwiftifyExpr) -> Int {
|
||||
switch expr {
|
||||
case .param(let i): return i
|
||||
case .return: return -1
|
||||
}
|
||||
}
|
||||
|
||||
static func setNonescapingPointers(_ args: inout [ParamInfo], _ nonescapingPointers: Set<Int>) {
|
||||
for i in 0...args.count - 1 where nonescapingPointers.contains(args[i].pointerIndex) {
|
||||
for i in 0...args.count - 1 where nonescapingPointers.contains(paramOrReturnIndex(args[i].pointerIndex)) {
|
||||
args[i].nonescaping = true
|
||||
}
|
||||
}
|
||||
@@ -831,27 +931,25 @@ public struct SwiftifyImportMacro: PeerMacro {
|
||||
try checkArgs(parsedArgs, funcDecl)
|
||||
let baseBuilder = FunctionCallBuilder(funcDecl)
|
||||
|
||||
let variant = Variant(
|
||||
generateSpan: hasSafeVariants(parsedArgs),
|
||||
skipTrivialCount: hasTrivialCountVariants(parsedArgs))
|
||||
let skipTrivialCount = hasTrivialCountVariants(parsedArgs)
|
||||
|
||||
let builder: BoundsCheckedThunkBuilder = parsedArgs.reduce(
|
||||
baseBuilder,
|
||||
{ (prev, parsedArg) in
|
||||
parsedArg.getBoundsCheckedThunkBuilder(prev, funcDecl, variant)
|
||||
parsedArg.getBoundsCheckedThunkBuilder(prev, funcDecl, skipTrivialCount)
|
||||
})
|
||||
let newSignature = try builder.buildFunctionSignature([:], variant)
|
||||
let newSignature = try builder.buildFunctionSignature([:], nil)
|
||||
let checks =
|
||||
variant.skipTrivialCount
|
||||
skipTrivialCount
|
||||
? [] as [CodeBlockItemSyntax]
|
||||
: try builder.buildBoundsChecks(variant).map { e in
|
||||
: try builder.buildBoundsChecks().map { e in
|
||||
CodeBlockItemSyntax(leadingTrivia: "\n", item: e)
|
||||
}
|
||||
let call = CodeBlockItemSyntax(
|
||||
item: CodeBlockItemSyntax.Item(
|
||||
ReturnStmtSyntax(
|
||||
returnKeyword: .keyword(.return, trailingTrivia: " "),
|
||||
expression: try builder.buildFunctionCall([:], variant))))
|
||||
expression: try builder.buildFunctionCall([:]))))
|
||||
let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call]))
|
||||
let newFunc =
|
||||
funcDecl
|
||||
|
||||
Reference in New Issue
Block a user