Files
swift-mirror/test/AutoDiff/stdlib/differentiable_protocol.swift
Dan Zheng 8aac6f9a1a [AutoDiff upstream] Conform floating-point types to Differentiable. (#28718)
Add `Differentiable` conformances for floating-point types to the
`_Differentiation` module. The `TangentVector` associated type for
floating-point types is `Self`.

This design adheres to the differentiable programming manifesto:
docs/DifferentiableProgramming.md.

Partially resolves TF-1052.
2019-12-11 18:43:09 -08:00

47 lines
1.2 KiB
Swift

// RUN: %target-typecheck-verify-swift
// REQUIRES: differentiable_programming
import _Differentiation
// Test `Differentiable` protocol conformances.
/// A minimal wrapper around a single `Float`, used as a fixture for the
/// conformance checks in this file.
struct FloatWrapper {
  /// The wrapped scalar.
  var value: Float
}

// `AdditiveArithmetic` conformance: every operation forwards to the wrapped
// `Float` and re-wraps the result.
extension FloatWrapper: AdditiveArithmetic {
  /// The additive identity: a wrapper around `Float.zero`.
  static var zero: Self {
    FloatWrapper(value: Float.zero)
  }

  /// Returns a wrapper holding the sum of the two wrapped values.
  static func + (lhs: Self, rhs: Self) -> Self {
    FloatWrapper(value: lhs.value + rhs.value)
  }

  /// Returns a wrapper holding `lhs.value` minus `rhs.value`.
  ///
  /// This previously added the operands (a copy of `+`), which made
  /// `x - x` differ from `zero`.
  static func - (lhs: Self, rhs: Self) -> Self {
    FloatWrapper(value: lhs.value - rhs.value)
  }
}
// `Differentiable` conformance for `FloatWrapper`.
// NOTE(review): no `move(along:)` is declared here; presumably the library
// supplies one when `TangentVector` is the conforming type itself — confirm.
extension FloatWrapper: Differentiable {
  /// Tangent vectors of a `FloatWrapper` are `FloatWrapper` values.
  public typealias TangentVector = FloatWrapper
}
/// A generic wrapper around a single value of type `T`, used as a fixture for
/// the conditional-conformance checks in this file.
struct Wrapper<T> {
  /// The wrapped value.
  var value: T
}

// `Equatable` is declared explicitly and conditionally; its implementation is
// left to the compiler (the extension body is empty).
extension Wrapper: Equatable where T: Equatable {}

// `AdditiveArithmetic` conformance, available when the wrapped type is itself
// `AdditiveArithmetic`. Every operation forwards to `T` and re-wraps.
extension Wrapper: AdditiveArithmetic where T: AdditiveArithmetic {
  /// The additive identity: a wrapper around `T.zero`.
  static var zero: Self {
    Wrapper(value: T.zero)
  }

  /// Returns a wrapper holding the sum of the two wrapped values.
  static func + (lhs: Self, rhs: Self) -> Self {
    Wrapper(value: lhs.value + rhs.value)
  }

  /// Returns a wrapper holding `lhs.value` minus `rhs.value`.
  ///
  /// This previously added the operands (a copy of `+`), which made
  /// `x - x` differ from `zero`.
  static func - (lhs: Self, rhs: Self) -> Self {
    Wrapper(value: lhs.value - rhs.value)
  }
}
// `Differentiable` conformance, available when the wrapped type is itself
// `Differentiable`.
extension Wrapper: Differentiable where T: Differentiable {
  /// The tangent of a `Wrapper<T>` is a `Wrapper` around `T`'s tangent vector.
  typealias TangentVector = Wrapper<T.TangentVector>

  /// Moves the wrapped value along the `value` component of the given tangent.
  mutating func move(along tangent: TangentVector) {
    let wrappedTangent = tangent.value
    value.move(along: wrappedTangent)
  }
}