mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Emit symbols for `@differentiable` and `@derivative` declaration attributes:
- Differentiability witness symbols.
- Derivative function (JVP/VJP) symbols.
- Linear map (differential/pullback) symbols.

Add TBDGen test.
57 lines
1.6 KiB
Swift
57 lines
1.6 KiB
Swift
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=all %s
|
|
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=all %s -O
|
|
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing %s -enable-testing
|
|
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing %s -enable-testing -O
|
|
|
|
import _Differentiation
|
|
|
|
/// Top-level function carrying the `@differentiable` attribute.
///
/// Returns its first argument unchanged; the second argument is ignored.
@differentiable
public func topLevelDifferentiable(_ x: Float, _ y: Float) -> Float {
  return x
}
|
|
|
|
/// Generic identity function over any `Differentiable` type.
///
/// Carries no differentiability attribute itself; returns `x` unchanged.
public func topLevelHasDerivative<T: Differentiable>(_ x: T) -> T {
  return x
}
|
|
|
|
/// Derivative registered for `topLevelHasDerivative` via `@derivative(of:)`.
///
/// The result tuple pairs the original `value` with a `pullback` closure over
/// `T.TangentVector`. The body traps unconditionally through `fatalError()`,
/// so it never produces a result.
@derivative(of: topLevelHasDerivative)
public func topLevelDerivative<T: Differentiable>(
  _ x: T
) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) {
  fatalError()
}
|
|
|
|
/// `Differentiable` struct whose members carry `@differentiable` and
/// `@derivative` attributes on a property, an initializer, a method, and a
/// subscript.
struct Struct: Differentiable {
  var stored: Float

  // Property case: computed property marked `@differentiable`.
  @differentiable
  public var property: Float {
    return stored
  }

  // Initializer case: `@differentiable` initializer that stores its argument.
  @differentiable
  public init(_ x: Float) {
    self.stored = x
  }

  // Method case: returns `x` and ignores `y`; its derivative is registered
  // separately by `jvpMethod`.
  public func method(x: Float, y: Float) -> Float {
    return x
  }

  // Derivative of `method` returning a `differential` closure. The closure
  // takes the struct's `TangentVector` plus one `Float` per method parameter.
  // The body traps, so it never produces a result.
  @derivative(of: method)
  public func jvpMethod(
    x: Float, y: Float
  ) -> (value: Float, differential: (TangentVector, Float, Float) -> Float) {
    fatalError()
  }

  // Subscript case: returns its index argument; its derivative is registered
  // separately by `vjpSubscript`.
  public subscript(x: Float) -> Float {
    return x
  }

  // Derivative of the subscript returning a `pullback` closure. The closure
  // maps a `Float` to a (`TangentVector`, `Float`) pair. The body traps, so
  // it never produces a result.
  @derivative(of: subscript)
  public func vjpSubscript(
    x: Float
  ) -> (value: Float, pullback: (Float) -> (TangentVector, Float)) {
    fatalError()
  }
}
|