mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
These recently started failing when optimized. Disable while we investigate and fix. rdar://71642726
62 lines
2.0 KiB
Swift
62 lines
2.0 KiB
Swift
// RUN: %target-run-simple-swift
|
|
// REQUIRES: executable_test
|
|
// rdar://71642726 this test is crashing with optimizations.
|
|
// REQUIRES: swift_test_mode_optimize_none
|
|
|
|
import _Differentiation
|
|
import StdlibUnittest
|
|
|
|
// Suite that every test in this file registers with via `.test(_:)`.
var ZeroTangentVectorTests: TestSuite = TestSuite("zeroTangentVectorInitializer")
|
|
|
|
/// Two-field generic container whose `Differentiable` conformance (including
/// its memberwise `TangentVector`) is left entirely to compiler synthesis.
struct Generic<T, U>: Differentiable where T: Differentiable, U: Differentiable {
  var x: T
  var y: U
}
|
|
|
|
/// Wraps a `Generic<T, U>` in a single stored property so the synthesized
/// `Differentiable` conformance is exercised through one level of nesting.
struct Nested<T, U>: Differentiable where T: Differentiable, U: Differentiable {
  var generic: Generic<T, U>
}
|
|
|
|
ZeroTangentVectorTests.test("Derivation") {
  typealias Pair = Generic<[Float], [[Float]]>

  // The expected zero tangent mirrors the array shapes of the input value
  // (including the empty inner array), with every element set to zero.
  let pairValue = Pair(x: [1, 2, 3], y: [[4, 5, 6], [], [2]])
  let expectedPairZero = Pair.TangentVector(
    x: [0, 0, 0],
    y: [[0, 0, 0], [], [0]])
  expectEqual(pairValue.zeroTangentVector, expectedPairZero)

  // The same check one level deeper: the nested zero tangent wraps the
  // inner one.
  let wrapped = Nested(generic: pairValue)
  let expectedWrappedZero = Nested.TangentVector(generic: expectedPairZero)
  expectEqual(wrapped.zeroTangentVector, expectedWrappedZero)
}
|
|
|
|
// Test differentiation correctness involving projection operations (here,
|
|
// returning only the `.x` member of a struct) and per-instance zeros.
|
|
ZeroTangentVectorTests.test("DifferentiationCorrectness") {
|
|
  // Local differentiable struct with two array-valued members.
  struct Struct: Differentiable {
|
|
    var x, y: [Float]
|
|
  }
|
|
  // Member-wise array concatenation of two `Struct` values.
  func concatenated(_ lhs: Struct, _ rhs: Struct) -> Struct {
|
|
    return Struct(x: lhs.x + rhs.x, y: lhs.y + rhs.y)
|
|
  }
|
|
  // Function being differentiated: concatenates `s` with itself and projects
  // out only `.x`. The `withDerivative` closure checks the tangent flowing
  // back into the concatenated struct.
  func test(_ s: Struct) -> [Float] {
|
|
    let result = concatenated(s, s).withDerivative { dresult in
|
|
      // FIXME(TF-1008): Fix incorrect derivative values for
|
|
      // "projection operation" operands when differentiation transform uses
|
|
      // `Differentiable.zeroTangentVectorInitializer`.
|
|
      // Actual: TangentVector(x: [1.0, 1.0, 1.0], y: [])
|
|
      // Expected: TangentVector(x: [1.0, 1.0, 1.0], y: [1.0, 1.0, 1.0])
|
|
      expectEqual(dresult, Struct.TangentVector(x: [1, 1, 1], y: [1, 1, 1]))
|
|
    }
|
|
    return result.x
|
|
  }
|
|
  let s = Struct(x: [1, 2, 3], y: [1, 2, 3])
|
|
  let pb = pullback(at: s, in: test)
|
|
  // FIXME(TF-1008): Remove `expectCrash` when differentiation transform uses
|
|
  // `Differentiable.zeroTangentVectorInitializer`.
|
|
  // NOTE(review): the seed below has 3 elements, while `result.x` has 6
  // (`s.x` concatenated with itself); confirm this mismatch is intentional.
  expectCrash {
|
|
    _ = pb([1, 1, 1])
|
|
  }
|
|
}
|
|
|
|
// Run the registered StdlibUnittest suites, including ZeroTangentVectorTests.
runAllTests()
|