Files
swift-mirror/test/AutoDiff/validation-test/zero_tangent_vector_initializer.swift
Dan Zheng f9c5d7ae6c [AutoDiff] Derive Differentiable.zeroTangentVectorInitializer. (#31823)
`Differentiable` conformance derivation now supports
`Differentiable.zeroTangentVectorInitializer`.

Derivation handles two cases:
1. Memberwise derivation: used when `TangentVector` can be initialized memberwise.
2. `{ TangentVector.zero }` derivation: used as a fallback otherwise.

`zeroTangentVectorInitializer` is a closure that produces a zero tangent vector,
capturing only the minimal information from `self` needed to do so.

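Unlike the static property `AdditiveArithmetic.zero`, it is an instance property,
so the zero it produces can depend on `self` (for example, arrays whose element
counts match those of `self`). The differentiation transform should use it for
correctness.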

Remove `Differentiable.zeroTangentVectorInitializer` dummy default implementation.

Update stdlib `Differentiable` conformances and tests.
Clean up DerivedConformanceDifferentiable.cpp cruft.

Resolves TF-1007.
Progress towards TF-1008: differentiation correctness for projection operations.
2020-05-29 01:59:52 -07:00

60 lines
1.9 KiB
Swift

// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import _Differentiation
import StdlibUnittest
// Test suite for `Differentiable.zeroTangentVectorInitializer`: covers both the
// derived implementation and its effect on differentiation correctness.
// The suite is only registered into via `.test(...)` and never reassigned,
// so it is declared as a constant.
let ZeroTangentVectorTests = TestSuite("zeroTangentVectorInitializer")
/// A two-member generic container whose `Differentiable` conformance
/// (including its `TangentVector`) is derived by the compiler.
///
/// - `x`: first differentiable member.
/// - `y`: second differentiable member.
struct Generic<T, U>: Differentiable
where T: Differentiable, U: Differentiable {
  var x: T
  var y: U
}
/// Wraps a `Generic` value so that conformance derivation is exercised on a
/// struct whose only stored member is itself a derived-`Differentiable` type.
struct Nested<T, U>: Differentiable
where T: Differentiable, U: Differentiable {
  var generic: Generic<T, U>
}
// Checks that derived `zeroTangentVector` values are per-instance zeros: each
// array member's zero has the same shape as the corresponding member of `self`
// (e.g. the empty inner array in `y` stays empty), and derivation recurses
// through nested `Differentiable` members.
ZeroTangentVectorTests.test("Derivation") {
  typealias G = Generic<[Float], [[Float]]>
  let generic = G(x: [1, 2, 3], y: [[4, 5, 6], [], [2]])
  let genericZero = G.TangentVector(x: [0, 0, 0], y: [[0, 0, 0], [], [0]])
  // StdlibUnittest's `expectEqual` takes `(expected, actual)`; pass arguments
  // in that order so failure messages label the values correctly.
  expectEqual(genericZero, generic.zeroTangentVector)

  // The wrapper's zero must be built from the wrapped value's per-instance zero.
  let nested = Nested(generic: generic)
  let nestedZero = Nested.TangentVector(generic: genericZero)
  expectEqual(nestedZero, nested.zeroTangentVector)
}
// Test differentiation correctness involving projection operations and
// per-instance zeros.
//
// `test` differentiates through `concatenated` but returns only the `x`
// member of its result, so the cotangent reaching `concatenated` must contain
// a synthesized zero for the unused `y` member. The `withDerivative` closure
// observes that cotangent (`dresult`) before it flows further back.
//
// NOTE(review): `concatenated(s, s)` produces 6-element `x`/`y` arrays, yet the
// pullback below is seeded with the 3-element `[1, 1, 1]`, and the FIXME's
// "Expected" value uses ones (not zeros) for `y`. These look inconsistent
// with a zero `y` cotangent of matching shape — confirm the intended values
// before removing `expectCrash`.
ZeroTangentVectorTests.test("DifferentiationCorrectness") {
// Two array-valued members: a correct zero tangent has to match each array's
// element count. Presumably this is why `y` shows up as `[]` (a shape-less
// zero) in the FIXME's "Actual" value below.
struct Struct: Differentiable {
var x, y: [Float]
}
// Concatenates `lhs` and `rhs` member-wise using array `+`.
func concatenated(_ lhs: Struct, _ rhs: Struct) -> Struct {
return Struct(x: lhs.x + rhs.x, y: lhs.y + rhs.y)
}
// Function under test: projects `concatenated`'s result onto its `x` member.
func test(_ s: Struct) -> [Float] {
let result = concatenated(s, s).withDerivative { dresult in
// FIXME(TF-1008): Fix incorrect derivative values for
// "projection operation" operands when differentiation transform uses
// `Differentiable.zeroTangentVectorInitializer`.
// Actual: TangentVector(x: [1.0, 1.0, 1.0], y: [])
// Expected: TangentVector(x: [1.0, 1.0, 1.0], y: [1.0, 1.0, 1.0])
// NOTE(review): StdlibUnittest's `expectEqual` appears to take
// `(expected, actual)`; the arguments here look reversed, which would only
// affect failure-message labels — confirm.
expectEqual(dresult, Struct.TangentVector(x: [1, 1, 1], y: [1, 1, 1]))
}
return result.x
}
let s = Struct(x: [1, 2, 3], y: [1, 2, 3])
// Pullback of `test` at `s`: maps a cotangent for `test`'s `[Float]` result
// back to a `Struct.TangentVector`.
let pb = pullback(at: s, in: test)
// FIXME(TF-1008): Remove `expectCrash` when differentiation transform uses
// `Differentiable.zeroTangentVectorInitializer`.
// Invoking the pullback is currently expected to crash. The cause is
// presumably the mismatched `y` cotangent noted above; not verified here.
expectCrash {
_ = pb([1, 1, 1])
}
}
// Execute every registered StdlibUnittest test (here, those added to
// `ZeroTangentVectorTests` above).
runAllTests()