[Swiftify] Add MutableSpan support for std::span, and disable it (#80315)

__counted_by already had MutableSpan support, so add it for std::span
for parity. But since MutableSpan hasn't landed in the standard library
yet, disable emitting it to prevent compilation errors in expansions.

rdar://147882736
This commit is contained in:
Henrik G. Olsson
2025-03-27 16:37:31 -07:00
committed by GitHub
parent 98ce87c977
commit d1737b9c20
4 changed files with 149 additions and 63 deletions

View File

@@ -4,6 +4,9 @@ import SwiftSyntax
import SwiftSyntaxBuilder import SwiftSyntaxBuilder
import SwiftSyntaxMacros import SwiftSyntaxMacros
// Disable emitting 'MutableSpan' until it has landed
let enableMutableSpan = false
// avoids depending on SwiftifyImport.swift // avoids depending on SwiftifyImport.swift
// all instances are reparsed and reinstantiated by the macro anyways, // all instances are reparsed and reinstantiated by the macro anyways,
// so linking is irrelevant // so linking is irrelevant
@@ -213,22 +216,26 @@ func replaceBaseType(_ type: TypeSyntax, _ base: TypeSyntax) -> TypeSyntax {
// C++ type qualifiers, `const T` and `volatile T`, are encoded as fake generic // C++ type qualifiers, `const T` and `volatile T`, are encoded as fake generic
// types, `__cxxConst<T>` and `__cxxVolatile<T>` respectively. Remove those. // types, `__cxxConst<T>` and `__cxxVolatile<T>` respectively. Remove those.
func dropQualifierGenerics(_ type: TypeSyntax) -> TypeSyntax { // Second return value is true if __cxxConst was stripped.
guard let identifier = type.as(IdentifierTypeSyntax.self) else { return type } func dropQualifierGenerics(_ type: TypeSyntax) -> (TypeSyntax, Bool) {
guard let generic = identifier.genericArgumentClause else { return type } guard let identifier = type.as(IdentifierTypeSyntax.self) else { return (type, false) }
guard let genericArg = generic.arguments.first else { return type } guard let generic = identifier.genericArgumentClause else { return (type, false) }
guard case .type(let argType) = genericArg.argument else { return type } guard let genericArg = generic.arguments.first else { return (type, false) }
guard case .type(let argType) = genericArg.argument else { return (type, false) }
switch identifier.name.text { switch identifier.name.text {
case "__cxxConst", "__cxxVolatile": case "__cxxConst":
let (retType, _) = dropQualifierGenerics(argType)
return (retType, true)
case "__cxxVolatile":
return dropQualifierGenerics(argType) return dropQualifierGenerics(argType)
default: default:
return type return (type, false)
} }
} }
// The generated type names for template instantiations sometimes contain // The generated type names for template instantiations sometimes contain
// encoded qualifiers for disambiguation purposes. We need to remove those. // encoded qualifiers for disambiguation purposes. We need to remove those.
func dropCxxQualifiers(_ type: TypeSyntax) -> TypeSyntax { func dropCxxQualifiers(_ type: TypeSyntax) -> (TypeSyntax, Bool) {
if let attributed = type.as(AttributedTypeSyntax.self) { if let attributed = type.as(AttributedTypeSyntax.self) {
return dropCxxQualifiers(attributed.baseType) return dropCxxQualifiers(attributed.baseType)
} }
@@ -272,12 +279,20 @@ func getUnqualifiedStdName(_ type: String) -> String? {
func getSafePointerName(mut: Mutability, generateSpan: Bool, isRaw: Bool) -> TokenSyntax { func getSafePointerName(mut: Mutability, generateSpan: Bool, isRaw: Bool) -> TokenSyntax {
switch (mut, generateSpan, isRaw) { switch (mut, generateSpan, isRaw) {
case (.Immutable, true, true): return "RawSpan" case (.Immutable, true, true): return "RawSpan"
case (.Mutable, true, true): return "MutableRawSpan" case (.Mutable, true, true): return if enableMutableSpan {
"MutableRawSpan"
} else {
"RawSpan"
}
case (.Immutable, false, true): return "UnsafeRawBufferPointer" case (.Immutable, false, true): return "UnsafeRawBufferPointer"
case (.Mutable, false, true): return "UnsafeMutableRawBufferPointer" case (.Mutable, false, true): return "UnsafeMutableRawBufferPointer"
case (.Immutable, true, false): return "Span" case (.Immutable, true, false): return "Span"
case (.Mutable, true, false): return "MutableSpan" case (.Mutable, true, false): return if enableMutableSpan {
"MutableSpan"
} else {
"Span"
}
case (.Immutable, false, false): return "UnsafeBufferPointer" case (.Immutable, false, false): return "UnsafeBufferPointer"
case (.Mutable, false, false): return "UnsafeMutableBufferPointer" case (.Mutable, false, false): return "UnsafeMutableBufferPointer"
} }
@@ -317,6 +332,28 @@ func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool)
return try replaceTypeName(prev, token) return try replaceTypeName(prev, token)
} }
func isMutablePointerType(_ type: TypeSyntax) -> Bool {
if let optType = type.as(OptionalTypeSyntax.self) {
return isMutablePointerType(optType.wrappedType)
}
if let impOptType = type.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) {
return isMutablePointerType(impOptType.wrappedType)
}
if let attrType = type.as(AttributedTypeSyntax.self) {
return isMutablePointerType(attrType.baseType)
}
do {
let name = try getTypeName(type)
let text = name.text
guard let kind: Mutability = getPointerMutability(text: text) else {
return false
}
return kind == .Mutable
} catch _ {
return false
}
}
protocol BoundsCheckedThunkBuilder { protocol BoundsCheckedThunkBuilder {
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item]
@@ -401,7 +438,7 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
} }
} }
struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder { struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
public let base: BoundsCheckedThunkBuilder public let base: BoundsCheckedThunkBuilder
public let index: Int public let index: Int
public let signature: FunctionSignatureSyntax public let signature: FunctionSignatureSyntax
@@ -417,17 +454,7 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
-> (FunctionSignatureSyntax, Bool) { -> (FunctionSignatureSyntax, Bool) {
var types = argTypes var types = argTypes
let typeName = getUnattributedType(oldType).description types[index] = try newType
guard let desugaredType = typeMappings[typeName] else {
throw DiagnosticError(
"unable to desugar type with name '\(typeName)'", node: node)
}
let parsedDesugaredType = TypeSyntax("\(raw: getUnqualifiedStdName(desugaredType)!)")
let genericArg = TypeSyntax(parsedDesugaredType.as(IdentifierTypeSyntax.self)!
.genericArgumentClause!.arguments.first!.argument)!
types[index] = replaceBaseType(param.type,
TypeSyntax("Span<\(raw: dropCxxQualifiers(genericArg))>"))
return try base.buildFunctionSignature(types, returnType) return try base.buildFunctionSignature(types, returnType)
} }
@@ -440,12 +467,16 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
} }
} }
struct CxxSpanReturnThunkBuilder: BoundsCheckedThunkBuilder { struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
public let base: BoundsCheckedThunkBuilder public let base: BoundsCheckedThunkBuilder
public let signature: FunctionSignatureSyntax public let signature: FunctionSignatureSyntax
public let typeMappings: [String: String] public let typeMappings: [String: String]
public let node: SyntaxProtocol public let node: SyntaxProtocol
var oldType: TypeSyntax {
return signature.returnClause!.type
}
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] { func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
return try base.buildBoundsChecks() return try base.buildBoundsChecks()
} }
@@ -453,31 +484,83 @@ struct CxxSpanReturnThunkBuilder: BoundsCheckedThunkBuilder {
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
-> (FunctionSignatureSyntax, Bool) { -> (FunctionSignatureSyntax, Bool) {
assert(returnType == nil) assert(returnType == nil)
let typeName = getUnattributedType(signature.returnClause!.type).description
guard let desugaredType = typeMappings[typeName] else {
throw DiagnosticError(
"unable to desugar type with name '\(typeName)'", node: node)
}
let parsedDesugaredType = TypeSyntax("\(raw: getUnqualifiedStdName(desugaredType)!)")
let genericArg = TypeSyntax(parsedDesugaredType.as(IdentifierTypeSyntax.self)!
.genericArgumentClause!.arguments.first!.argument)!
let newType = replaceBaseType(signature.returnClause!.type,
TypeSyntax("Span<\(raw: dropCxxQualifiers(genericArg))>"))
return try base.buildFunctionSignature(argTypes, newType) return try base.buildFunctionSignature(argTypes, newType)
} }
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax { func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax {
let call = try base.buildFunctionCall(pointerArgs) let call = try base.buildFunctionCall(pointerArgs)
return "_cxxOverrideLifetime(Span(_unsafeCxxSpan: \(call)), copying: ())" let (_, isConst) = dropCxxQualifiers(try genericArg)
let cast = if isConst || !enableMutableSpan {
"Span"
} else {
"MutableSpan"
}
return "_cxxOverrideLifetime(\(raw: cast)(_unsafeCxxSpan: \(call)), copying: ())"
} }
} }
protocol PointerBoundsThunkBuilder: BoundsCheckedThunkBuilder { protocol BoundsThunkBuilder: BoundsCheckedThunkBuilder {
var oldType: TypeSyntax { get } var oldType: TypeSyntax { get }
var newType: TypeSyntax { get throws } var newType: TypeSyntax { get throws }
var nullable: Bool { get }
var signature: FunctionSignatureSyntax { get } var signature: FunctionSignatureSyntax { get }
var nonescaping: Bool { get } }
protocol SpanBoundsThunkBuilder: BoundsThunkBuilder {
var typeMappings: [String: String] { get }
var node: SyntaxProtocol { get }
}
extension SpanBoundsThunkBuilder {
var desugaredType: TypeSyntax {
get throws {
let typeName = try getUnattributedType(oldType).description
guard let desugaredTypeName = typeMappings[typeName] else {
throw DiagnosticError(
"unable to desugar type with name '\(typeName)'", node: node)
}
return TypeSyntax("\(raw: getUnqualifiedStdName(desugaredTypeName)!)")
}
}
var genericArg: TypeSyntax {
get throws {
guard let idType = try desugaredType.as(IdentifierTypeSyntax.self) else {
throw DiagnosticError(
"unexpected non-identifier type '\(try desugaredType)', expected a std::span type",
node: try desugaredType)
}
guard let genericArgumentClause = idType.genericArgumentClause else {
throw DiagnosticError(
"missing generic type argument clause expected after \(idType)", node: idType)
}
guard let firstArg = genericArgumentClause.arguments.first else {
throw DiagnosticError(
"expected at least 1 generic type argument for std::span type '\(idType)', found '\(genericArgumentClause)'",
node: genericArgumentClause.arguments)
}
guard let arg = TypeSyntax(firstArg.argument) else {
throw DiagnosticError(
"invalid generic type argument '\(firstArg.argument)'",
node: firstArg.argument)
}
return arg
}
}
var newType: TypeSyntax {
get throws {
let (strippedArg, isConst) = dropCxxQualifiers(try genericArg)
let mutablePrefix = if isConst || !enableMutableSpan {
""
} else {
"Mutable"
}
return replaceBaseType(
oldType,
TypeSyntax("\(raw: mutablePrefix)Span<\(raw: strippedArg)>"))
}
}
}
protocol PointerBoundsThunkBuilder: BoundsThunkBuilder {
var nullable: Bool { get }
var isSizedBy: Bool { get } var isSizedBy: Bool { get }
var generateSpan: Bool { get } var generateSpan: Bool { get }
} }
@@ -490,13 +573,12 @@ extension PointerBoundsThunkBuilder {
} }
} }
protocol ParamPointerBoundsThunkBuilder: PointerBoundsThunkBuilder { protocol ParamBoundsThunkBuilder: BoundsThunkBuilder {
var index: Int { get } var index: Int { get }
var nonescaping: Bool { get }
} }
extension ParamPointerBoundsThunkBuilder { extension ParamBoundsThunkBuilder {
var generateSpan: Bool { nonescaping }
var param: FunctionParameterSyntax { var param: FunctionParameterSyntax {
return getParam(signature, index) return getParam(signature, index)
} }
@@ -518,7 +600,7 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
public let isSizedBy: Bool public let isSizedBy: Bool
public let dependencies: [LifetimeDependence] public let dependencies: [LifetimeDependence]
var generateSpan: Bool { !dependencies.isEmpty } var generateSpan: Bool { !dependencies.isEmpty && (!isMutablePointerType(oldType) || enableMutableSpan)}
var oldType: TypeSyntax { var oldType: TypeSyntax {
return signature.returnClause!.type return signature.returnClause!.type
@@ -531,7 +613,7 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
} }
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] { func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
return [] return try base.buildBoundsChecks()
} }
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax { func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax {
@@ -548,7 +630,8 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
} }
} }
struct CountedOrSizedPointerThunkBuilder: ParamPointerBoundsThunkBuilder {
struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBoundsThunkBuilder {
public let base: BoundsCheckedThunkBuilder public let base: BoundsCheckedThunkBuilder
public let index: Int public let index: Int
public let countExpr: ExprSyntax public let countExpr: ExprSyntax
@@ -557,6 +640,8 @@ struct CountedOrSizedPointerThunkBuilder: ParamPointerBoundsThunkBuilder {
public let isSizedBy: Bool public let isSizedBy: Bool
public let skipTrivialCount: Bool public let skipTrivialCount: Bool
var generateSpan: Bool { nonescaping && (!isMutablePointerType(oldType) || enableMutableSpan) }
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
-> (FunctionSignatureSyntax, Bool) { -> (FunctionSignatureSyntax, Bool) {
var types = argTypes var types = argTypes

View File

@@ -1,6 +1,8 @@
// REQUIRES: swift_feature_SafeInteropWrappers // REQUIRES: swift_feature_SafeInteropWrappers
// REQUIRES: swift_feature_LifetimeDependence // REQUIRES: swift_feature_LifetimeDependence
// This emits UnsafeMutableBufferPointer until MutableSpan has landed
// RUN: %target-swift-ide-test -print-module -module-to-print=CountedByNoEscapeClang -plugin-path %swift-plugin-dir -I %S/Inputs -source-filename=x -enable-experimental-feature SafeInteropWrappers -enable-experimental-feature LifetimeDependence | %FileCheck %s // RUN: %target-swift-ide-test -print-module -module-to-print=CountedByNoEscapeClang -plugin-path %swift-plugin-dir -I %S/Inputs -source-filename=x -enable-experimental-feature SafeInteropWrappers -enable-experimental-feature LifetimeDependence | %FileCheck %s
// swift-ide-test doesn't currently typecheck the macro expansions, so run the compiler as well // swift-ide-test doesn't currently typecheck the macro expansions, so run the compiler as well
@@ -11,15 +13,15 @@
import CountedByNoEscapeClang import CountedByNoEscapeClang
// CHECK: @_alwaysEmitIntoClient public func complexExpr(_ len: Int{{.*}}, _ offset: Int{{.*}}, _ p: MutableSpan<Int{{.*}}>) // CHECK: @_alwaysEmitIntoClient public func complexExpr(_ len: Int{{.*}}, _ offset: Int{{.*}}, _ p: UnsafeMutableBufferPointer<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func nonnull(_ p: MutableSpan<Int{{.*}}>) // CHECK-NEXT: @_alwaysEmitIntoClient public func nonnull(_ p: UnsafeMutableBufferPointer<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func nullUnspecified(_ p: MutableSpan<Int{{.*}}>) // CHECK-NEXT: @_alwaysEmitIntoClient public func nullUnspecified(_ p: UnsafeMutableBufferPointer<Int{{.*}}>)
// CHECK-NEXT: @lifetime(copy p) // CHECK-NEXT: @lifetime(copy p)
// CHECK-NEXT: @_alwaysEmitIntoClient public func returnLifetimeBound(_ len1: Int32, _ p: MutableSpan<Int32>) -> MutableSpan<Int32> // CHECK-NEXT: @_alwaysEmitIntoClient public func returnLifetimeBound(_ len1: Int32, _ p: UnsafeMutableBufferPointer<Int32>) -> UnsafeMutableBufferPointer<Int32>
// CHECK-NEXT: @_alwaysEmitIntoClient @_disfavoredOverload public func returnPointer(_ len: Int{{.*}}) -> UnsafeMutableBufferPointer<Int{{.*}}> // CHECK-NEXT: @_alwaysEmitIntoClient @_disfavoredOverload public func returnPointer(_ len: Int{{.*}}) -> UnsafeMutableBufferPointer<Int{{.*}}>
// CHECK-NEXT: @_alwaysEmitIntoClient public func shared(_ len: Int{{.*}}, _ p1: MutableSpan<Int{{.*}}>, _ p2: MutableSpan<Int{{.*}}>) // CHECK-NEXT: @_alwaysEmitIntoClient public func shared(_ len: Int{{.*}}, _ p1: UnsafeMutableBufferPointer<Int{{.*}}>, _ p2: UnsafeMutableBufferPointer<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func simple(_ p: MutableSpan<Int{{.*}}>) // CHECK-NEXT: @_alwaysEmitIntoClient public func simple(_ p: UnsafeMutableBufferPointer<Int{{.*}}>)
// CHECK-NEXT: @_alwaysEmitIntoClient public func swiftAttr(_ p: MutableSpan<Int{{.*}}>) // CHECK-NEXT: @_alwaysEmitIntoClient public func swiftAttr(_ p: UnsafeMutableBufferPointer<Int{{.*}}>)
@inlinable @inlinable
public func callReturnPointer() { public func callReturnPointer() {

View File

@@ -1,16 +1,15 @@
// REQUIRES: swift_swift_parser // REQUIRES: swift_swift_parser
// RUN: not %target-swift-frontend %s -swift-version 5 -module-name main -disable-availability-checking -typecheck -plugin-path %swift-plugin-dir -strict-memory-safety -warnings-as-errors -dump-macro-expansions > %t.log 2>&1 // RUN: %target-swift-frontend %s -swift-version 5 -module-name main -disable-availability-checking -typecheck -plugin-path %swift-plugin-dir -strict-memory-safety -warnings-as-errors -dump-macro-expansions > %t.log 2>&1
// RUN: %FileCheck --match-full-lines %s < %t.log // RUN: %FileCheck --match-full-lines %s < %t.log
@_SwiftifyImport(.countedBy(pointer: .param(1), count: "len"), .nonescaping(pointer: .param(1))) @_SwiftifyImport(.countedBy(pointer: .param(1), count: "len"), .nonescaping(pointer: .param(1)))
func myFunc(_ ptr: UnsafeMutablePointer<CInt>, _ len: CInt) { func myFunc(_ ptr: UnsafeMutablePointer<CInt>, _ len: CInt) {
} }
// CHECK: @_alwaysEmitIntoClient // Emits UnsafeMutableBufferPointer until MutableSpan has landed
// CHECK-NEXT: func myFunc(_ ptr: MutableSpan<CInt>) {
// CHECK-NEXT: return unsafe ptr.withUnsafeBufferPointer { _ptrPtr in
// CHECK-NEXT: return unsafe myFunc(_ptrPtr.baseAddress!, CInt(exactly: ptr.count)!)
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK: @_alwaysEmitIntoClient
// CHECK-NEXT: func myFunc(_ ptr: UnsafeMutableBufferPointer<CInt>) {
// CHECK-NEXT: return unsafe myFunc(ptr.baseAddress!, CInt(exactly: ptr.count)!)
// CHECK-NEXT: }

View File

@@ -1,15 +1,15 @@
// REQUIRES: swift_swift_parser // REQUIRES: swift_swift_parser
// RUN: not %target-swift-frontend %s -swift-version 5 -module-name main -disable-availability-checking -typecheck -plugin-path %swift-plugin-dir -strict-memory-safety -warnings-as-errors -dump-macro-expansions > %t.log 2>&1 // RUN: %target-swift-frontend %s -swift-version 5 -module-name main -disable-availability-checking -typecheck -plugin-path %swift-plugin-dir -strict-memory-safety -warnings-as-errors -dump-macro-expansions > %t.log 2>&1
// RUN: %FileCheck --match-full-lines %s < %t.log // RUN: %FileCheck --match-full-lines %s < %t.log
@_SwiftifyImport(.sizedBy(pointer: .param(1), size: "size"), .nonescaping(pointer: .param(1))) @_SwiftifyImport(.sizedBy(pointer: .param(1), size: "size"), .nonescaping(pointer: .param(1)))
func myFunc(_ ptr: UnsafeMutableRawPointer, _ size: CInt) { func myFunc(_ ptr: UnsafeMutableRawPointer, _ size: CInt) {
} }
// Emits UnsafeMutableRawBufferPointer until MutableRawSpan has landed
// CHECK: @_alwaysEmitIntoClient // CHECK: @_alwaysEmitIntoClient
// CHECK-NEXT: func myFunc(_ ptr: MutableRawSpan) { // CHECK-NEXT: func myFunc(_ ptr: UnsafeMutableRawBufferPointer) {
// CHECK-NEXT: return unsafe ptr.withUnsafeBytes { _ptrPtr in // CHECK-NEXT: return unsafe myFunc(ptr.baseAddress!, CInt(exactly: ptr.count)!)
// CHECK-NEXT: return unsafe myFunc(_ptrPtr.baseAddress!, CInt(exactly: ptr.byteCount)!)
// CHECK-NEXT: }
// CHECK-NEXT: } // CHECK-NEXT: }