Mirror of https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Add `Differentiable` conformances for floating-point types to the `_Differentiation` module. The `TangentVector` associated type for floating-point types is `Self`. This design adheres to the differentiable programming manifesto: docs/DifferentiableProgramming.md. Partially resolves TF-1052.
47 lines
1.2 KiB
Swift
47 lines
1.2 KiB
Swift
// RUN: %target-typecheck-verify-swift
|
|
// REQUIRES: differentiable_programming
|
|
|
|
import _Differentiation
|
|
|
|
// Test `Differentiable` protocol conformances.
|
|
|
|
/// A minimal value type holding a single `Float`.
///
/// Used in this test as a type whose `Differentiable.TangentVector` is `Self`.
struct FloatWrapper {
  /// The wrapped scalar.
  var value: Float
}
|
|
/// `AdditiveArithmetic` conformance that forwards every operation to the
/// wrapped `Float`.
extension FloatWrapper: AdditiveArithmetic {
  /// The additive identity: a wrapper around `Float.zero`.
  static var zero: Self {
    FloatWrapper(value: Float.zero)
  }

  /// Returns a wrapper around the sum of the two wrapped values.
  static func + (lhs: Self, rhs: Self) -> Self {
    return FloatWrapper(value: lhs.value + rhs.value)
  }

  /// Returns a wrapper around `lhs.value - rhs.value`.
  // Must genuinely subtract so that `x - x == .zero`, as the
  // `AdditiveArithmetic` contract requires (this previously added).
  static func - (lhs: Self, rhs: Self) -> Self {
    return FloatWrapper(value: lhs.value - rhs.value)
  }
}
|
|
/// Makes `FloatWrapper` differentiable with itself as its tangent vector.
///
/// No `move(along:)` is written here. It is presumably supplied by the
/// `_Differentiation` module's default for `Self == TangentVector` types.
extension FloatWrapper: Differentiable {
  public typealias TangentVector = Self
}
|
|
|
|
/// A generic container holding a single value of type `T`.
///
/// Its protocol conformances in this file are all conditional on `T`.
struct Wrapper<T> {
  /// The wrapped value.
  var value: T
}
|
|
// Equality is compiler-synthesized from the single stored property whenever
// `T` is `Equatable`. This is also required because `AdditiveArithmetic`
// refines `Equatable`.
extension Wrapper: Equatable where T: Equatable {
}
|
|
/// `AdditiveArithmetic` conformance that forwards every operation to the
/// wrapped `T`, available whenever `T` is itself `AdditiveArithmetic`.
extension Wrapper: AdditiveArithmetic where T: AdditiveArithmetic {
  /// The additive identity: a wrapper around `T.zero`.
  static var zero: Self {
    Wrapper(value: T.zero)
  }

  /// Returns a wrapper around the sum of the two wrapped values.
  static func + (lhs: Self, rhs: Self) -> Self {
    return Wrapper(value: lhs.value + rhs.value)
  }

  /// Returns a wrapper around `lhs.value - rhs.value`.
  // Must genuinely subtract so that `x - x == .zero`, as the
  // `AdditiveArithmetic` contract requires (this previously added).
  static func - (lhs: Self, rhs: Self) -> Self {
    return Wrapper(value: lhs.value - rhs.value)
  }
}
|
|
/// Makes `Wrapper` differentiable whenever `T` is.
///
/// The tangent space is `T`'s tangent space, wrapped the same way.
extension Wrapper: Differentiable where T: Differentiable {
  typealias TangentVector = Wrapper<T.TangentVector>

  /// Moves the wrapped value along the wrapped tangent `offset`.
  mutating func move(along offset: TangentVector) {
    self.value.move(along: offset.value)
  }
}
|