mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
* spelling: abcdefghijklmnopqrstuvwxyz Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: clazz Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: collection Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: compressible Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: constituent Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: contiguous Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: convertibility Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: element Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: enforce Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: exhaustive Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: exhausts Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: existential Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: facilitate Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: ignored Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: incorporated Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: intersection Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: laziness Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: misaligned Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: overhaul Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: preamble Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: precondition Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: replacement Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: trailing Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: unambiguous Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> 
* spelling: uncompressible Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> * spelling: world Signed-off-by: Josh Soref <jsoref@users.noreply.github.com> Co-authored-by: Josh Soref <jsoref@users.noreply.github.com>
230 lines
7.5 KiB
Swift
230 lines
7.5 KiB
Swift
// RUN: %target-run-simple-swift
|
|
// REQUIRES: executable_test
|
|
|
|
import _Differentiation
|
|
import StdlibUnittest
|
|
|
|
// Test suite to which every `TypeErasureTests.test(...)` case in this file is added.
var TypeErasureTests = TestSuite("DifferentiableTypeErasure")
|
|
|
|
/// Two-component differentiable fixture used as a concrete payload for the
/// type-erased wrappers exercised in this file.
struct Vector: Differentiable, Equatable {
  var x: Float
  var y: Float
}
|
|
/// Generic differentiable fixture wrapping a single differentiable, equatable
/// value; used to exercise type erasure over generic payloads.
struct Generic<T>: Differentiable, Equatable
where T: Differentiable & Equatable {
  var x: T
}
|
|
|
|
extension AnyDerivative {
  // This exists only to facilitate testing.
  /// Returns a copy of `self` after applying `move(by:)` with `offset`;
  /// the receiver itself is left untouched.
  func moved(along offset: TangentVector) -> Self {
    var copy = self
    copy.move(by: offset)
    return copy
  }
}
|
|
|
|
TypeErasureTests.test("AnyDifferentiable operations") {
  // Moving a type-erased `Vector` by a type-erased tangent updates the
  // wrapped value.
  do {
    let unitTangent = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
    var erased = AnyDifferentiable(Vector(x: 1, y: 1))
    erased.move(by: unitTangent)
    let unwrapped = erased.base as? Vector
    expectEqual(Vector(x: 2, y: 2), unwrapped)
  }

  // Same check with a generic payload.
  do {
    let unitTangent = AnyDerivative(Generic<Float>.TangentVector(x: 1))
    var erased = AnyDifferentiable(Generic<Float>(x: 1))
    erased.move(by: unitTangent)
    let unwrapped = erased.base as? Generic<Float>
    expectEqual(Generic<Float>(x: 2), unwrapped)
  }
}
|
|
|
|
TypeErasureTests.test("AnyDerivative operations") {
  // Arithmetic on a type-erased `Vector.TangentVector`.
  do {
    var tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
    tan += tan
    expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
    expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan + tan)
    expectEqual(AnyDerivative(Vector.TangentVector(x: 0, y: 0)), tan - tan)
    expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan.moved(along: tan))
    // `moved(along:)` works on a copy, so `tan` must still hold its old value.
    expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan)
  }

  // Same arithmetic on a type-erased generic tangent vector.
  do {
    var tan = AnyDerivative(Generic<Float>.TangentVector(x: 1))
    tan += tan
    expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 2)), tan)
    expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan + tan)
    expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 0)), tan - tan)
    expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 4)), tan.moved(along: tan))
    // Mirror the `Vector` case: the non-mutating helper must leave `tan` intact.
    expectEqual(AnyDerivative(Generic<Float>.TangentVector(x: 2)), tan)
  }
}
|
|
|
|
TypeErasureTests.test("AnyDerivative.zero") {
  // The type-erased zero stays equal to itself under its own arithmetic.
  var zero = AnyDerivative.zero
  zero += zero
  zero -= zero
  let zeroSum = zero + zero
  let zeroDifference = zero - zero
  expectEqual(zero, zeroSum)
  expectEqual(zero, zeroDifference)
  expectEqual(zero, zero.moved(along: zero))

  var tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
  // A wrapped concrete zero is what `tan - tan` produces; it compares unequal
  // to the type-erased `AnyDerivative.zero`.
  let wrappedVectorZero = AnyDerivative(Vector.TangentVector.zero)
  expectEqual(zero, zero)
  expectEqual(wrappedVectorZero, tan - tan)
  expectNotEqual(wrappedVectorZero, zero)
  expectNotEqual(AnyDerivative.zero, tan - tan)

  // Combining a concrete tangent with the type-erased zero leaves it unchanged.
  tan += zero
  tan -= zero
  expectEqual(tan, tan + zero)
  expectEqual(tan, tan - zero)
  expectEqual(tan, tan.moved(along: zero))
  expectEqual(tan, zero.moved(along: tan))
  expectEqual(zero, zero)
  expectEqual(tan, tan)
}
|
|
|
|
TypeErasureTests.test("AnyDifferentiable casting") {
  // `base` can be cast back to the concrete wrapped type.
  let erasedVector = AnyDifferentiable(Vector(x: 1, y: 1))
  let recoveredVector = erasedVector.base as? Vector
  expectEqual(Vector(x: 1, y: 1), recoveredVector)

  // Generic payloads round-trip too, but only to the exact specialization.
  let erasedGeneric = AnyDifferentiable(Generic<Float>(x: 1))
  let recoveredGeneric = erasedGeneric.base as? Generic<Float>
  expectEqual(Generic<Float>(x: 1), recoveredGeneric)
  let mismatchedGeneric = erasedGeneric.base as? Generic<Double>
  expectEqual(nil, mismatchedGeneric)
}
|
|
|
|
TypeErasureTests.test("AnyDifferentiable reflection") {
  // Reflecting the type-erased wrapper is expected to expose the stored
  // properties of the wrapped `Vector` (labels `x` and `y`).
  let originalVector = Vector(x: 1, y: 1)
  let vector = AnyDifferentiable(originalVector)
  let mirror = Mirror(reflecting: vector)
  let children = Array(mirror.children)
  expectEqual(2, children.count)
  expectEqual(["x", "y"], children.map(\.label))
  // Use a conditional cast so a child of an unexpected type shows up as a
  // reported expectation failure (a `nil` element) instead of trapping the
  // whole test process on `as!`.
  let expectedValues: [Float?] = [originalVector.x, originalVector.y]
  expectEqual(expectedValues, children.map { $0.value as? Float })
}
|
|
|
|
TypeErasureTests.test("AnyDerivative casting") {
  // `base` can be cast back to the concrete wrapped tangent vector.
  let erasedTangent = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
  let recoveredTangent = erasedTangent.base as? Vector.TangentVector
  expectEqual(Vector.TangentVector(x: 1, y: 1), recoveredTangent)

  // Generic tangent vectors round-trip only to the exact specialization.
  let erasedGenericTangent = AnyDerivative(Generic<Float>.TangentVector(x: 1))
  let recoveredGenericTangent =
    erasedGenericTangent.base as? Generic<Float>.TangentVector
  expectEqual(Generic<Float>.TangentVector(x: 1), recoveredGenericTangent)
  let mismatchedGenericTangent =
    erasedGenericTangent.base as? Generic<Double>.TangentVector
  expectEqual(nil, mismatchedGenericTangent)

  // The base of the type-erased zero casts to none of the concrete types
  // used in this file.
  let erasedZero = AnyDerivative.zero
  expectEqual(nil, erasedZero.base as? Float)
  expectEqual(nil, erasedZero.base as? Vector.TangentVector)
  expectEqual(nil, erasedZero.base as? Generic<Float>.TangentVector)
}
|
|
|
|
TypeErasureTests.test("AnyDerivative reflection") {
  // Reflecting the type-erased derivative is expected to expose the stored
  // properties of the wrapped tangent vector (labels `x` and `y`).
  let originalTan = Vector.TangentVector(x: 1, y: 1)
  let tan = AnyDerivative(originalTan)
  let mirror = Mirror(reflecting: tan)
  let children = Array(mirror.children)
  expectEqual(2, children.count)
  expectEqual(["x", "y"], children.map(\.label))
  // Use a conditional cast so a child of an unexpected type shows up as a
  // reported expectation failure (a `nil` element) instead of trapping the
  // whole test process on `as!`.
  let expectedValues: [Float?] = [originalTan.x, originalTan.y]
  expectEqual(expectedValues, children.map { $0.value as? Float })
}
|
|
|
|
TypeErasureTests.test("AnyDifferentiable differentiation") {
  // Test `AnyDifferentiable` initializer.
  // In each case the pullback of the wrapping initializer is expected to hand
  // back the cotangent with its concrete type restored.

  // Scalar payload.
  do {
    let input: Float = 3
    let cotangent = AnyDerivative(Float(2))
    let wrapPullback = pullback(at: input, of: { AnyDifferentiable($0) })
    let gradient = wrapPullback(cotangent)
    let expectedVJP: Float = 2
    expectEqual(expectedVJP, gradient)
  }

  // Struct payload.
  do {
    let input = Vector(x: 4, y: 5)
    let cotangent = AnyDerivative(Vector.TangentVector(x: 2, y: 2))
    let wrapPullback = pullback(at: input, of: { AnyDifferentiable($0) })
    let gradient = wrapPullback(cotangent)
    let expectedVJP = Vector.TangentVector(x: 2, y: 2)
    expectEqual(expectedVJP, gradient)
  }

  // Generic struct payload.
  do {
    let input = Generic<Double>(x: 4)
    let cotangent = AnyDerivative(Generic<Double>.TangentVector(x: 2))
    let wrapPullback = pullback(at: input, of: { AnyDifferentiable($0) })
    let gradient = wrapPullback(cotangent)
    let expectedVJP = Generic<Double>.TangentVector(x: 2)
    expectEqual(expectedVJP, gradient)
  }
}
|
|
|
|
TypeErasureTests.test("AnyDerivative differentiation") {
  // Test `AnyDerivative` operations.

  /// Returns `(x + y)` added to itself three times, so each input contributes
  /// three times to the result.
  func tripleSum(_ x: AnyDerivative, _ y: AnyDerivative) -> AnyDerivative {
    let pairSum = x + y
    return pairSum + pairSum + pairSum
  }

  // Scalar tangents: a unit cotangent is expected to pull back to 3 for both inputs.
  do {
    let lhs = AnyDerivative(Float(4))
    let rhs = AnyDerivative(Float(-2))
    let cotangent = AnyDerivative(Float(1))
    let expectedVJP: Float = 3

    let gradients = pullback(at: lhs, rhs, of: tripleSum)(cotangent)
    expectEqual(expectedVJP, gradients.0.base as? Float)
    expectEqual(expectedVJP, gradients.1.base as? Float)
  }

  // Struct tangents.
  do {
    let lhs = AnyDerivative(Vector.TangentVector(x: 4, y: 5))
    let rhs = AnyDerivative(Vector.TangentVector(x: -2, y: -1))
    let cotangent = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
    let expectedVJP = Vector.TangentVector(x: 3, y: 3)

    let gradients = pullback(at: lhs, rhs, of: tripleSum)(cotangent)
    expectEqual(expectedVJP, gradients.0.base as? Vector.TangentVector)
    expectEqual(expectedVJP, gradients.1.base as? Vector.TangentVector)
  }

  // Generic struct tangents.
  do {
    let lhs = AnyDerivative(Generic<Double>.TangentVector(x: 4))
    let rhs = AnyDerivative(Generic<Double>.TangentVector(x: -2))
    let cotangent = AnyDerivative(Generic<Double>.TangentVector(x: 1))
    let expectedVJP = Generic<Double>.TangentVector(x: 3)

    let gradients = pullback(at: lhs, rhs, of: tripleSum)(cotangent)
    expectEqual(expectedVJP, gradients.0.base as? Generic<Double>.TangentVector)
    expectEqual(expectedVJP, gradients.1.base as? Generic<Double>.TangentVector)
  }

  // Test `AnyDerivative` initializer.

  /// Wraps `x` in an `AnyDerivative` and returns the wrapper added to itself,
  /// so the input contributes twice to the result.
  func typeErased<T: Differentiable>(_ x: T) -> AnyDerivative
  where T.TangentVector == T {
    let erased = AnyDerivative(x)
    return erased + erased
  }

  // Scalar input: a unit cotangent is expected to pull back to 2.
  do {
    let input: Float = 3
    let cotangent = AnyDerivative(Float(1))
    let erasePullback = pullback(at: input, of: { x in typeErased(x) })
    let gradient = erasePullback(cotangent)
    let expectedVJP: Float = 2
    expectEqual(expectedVJP, gradient)
  }

  // Struct tangent input.
  do {
    let input = Vector.TangentVector(x: 4, y: 5)
    let cotangent = AnyDerivative(Vector.TangentVector(x: 1, y: 1))
    let erasePullback = pullback(at: input, of: { x in typeErased(x) })
    let gradient = erasePullback(cotangent)
    let expectedVJP = Vector.TangentVector(x: 2, y: 2)
    expectEqual(expectedVJP, gradient)
  }

  // Generic struct tangent input.
  do {
    let input = Generic<Double>.TangentVector(x: 4)
    let cotangent = AnyDerivative(Generic<Double>.TangentVector(x: 1))
    let erasePullback = pullback(at: input, of: { x in typeErased(x) })
    let gradient = erasePullback(cotangent)
    let expectedVJP = Generic<Double>.TangentVector(x: 2)
    expectEqual(expectedVJP, gradient)
  }
}
|
|
|
|
// Entry point of this executable test (see the RUN/REQUIRES lines at the top):
// hands control to the test runner for the cases added above.
runAllTests()
|