Mirror of https://github.com/apple/swift.git (synced 2026-03-08 18:24:30 +01:00).
`Differentiable` conformance derivation now supports
`Differentiable.zeroTangentVectorInitializer`.
Derivation handles two cases:
1. Memberwise derivation: done when `TangentVector` can be initialized memberwise.
2. `{ TangentVector.zero }` derivation: done as a fallback.
`zeroTangentVectorInitializer` is a closure that produces a zero tangent vector,
capturing minimal necessary information from `self`.
It is an instance property, unlike the static property `AdditiveArithmetic.zero`,
and should be used by the differentiation transform for correctness.
Remove the dummy default implementation of `Differentiable.zeroTangentVectorInitializer`.
Update stdlib `Differentiable` conformances and tests.
Clean up DerivedConformanceDifferentiable.cpp cruft.
Resolves TF-1007.
Progress towards TF-1008: differentiation correctness for projection operations.
60 lines · 1.9 KiB · Swift
// RUN: %target-run-simple-swift
|
|
// REQUIRES: executable_test
|
|
|
|
import _Differentiation
|
|
import StdlibUnittest
|
|
|
|
/// Suite grouping the `zeroTangentVectorInitializer` derivation and
/// differentiation-correctness tests declared in this file.
var ZeroTangentVectorTests: TestSuite = TestSuite("zeroTangentVectorInitializer")
|
|
|
|
/// Two-member generic fixture whose `Differentiable` conformance (and hence
/// its `TangentVector` and `zeroTangentVectorInitializer`) is compiler-derived.
struct Generic<T, U>: Differentiable
where T: Differentiable, U: Differentiable {
  // Declaration order (x, then y) fixes the memberwise initializer labels.
  var x: T, y: U
}
|
|
|
|
/// Fixture that wraps a `Generic` value so derivation is exercised through a
/// stored property whose own conformance is derived.
struct Nested<T, U>: Differentiable
where T: Differentiable, U: Differentiable {
  var generic: Generic<T, U>
}
|
|
|
|
ZeroTangentVectorTests.test("Derivation") {
  typealias Pair = Generic<[Float], [[Float]]>
  typealias Wrapper = Nested<[Float], [[Float]]>

  // The expected zero has the same element counts as each array member of
  // the instance, including the empty inner array.
  let pair = Pair(x: [1, 2, 3], y: [[4, 5, 6], [], [2]])
  let expectedPairZero = Pair.TangentVector(x: [0, 0, 0], y: [[0, 0, 0], [], [0]])
  expectEqual(pair.zeroTangentVector, expectedPairZero)

  // Wrapping the same instance: the expected zero is the wrapped zero above.
  let wrapper = Wrapper(generic: pair)
  let expectedWrapperZero = Wrapper.TangentVector(generic: expectedPairZero)
  expectEqual(wrapper.zeroTangentVector, expectedWrapperZero)
}
|
|
|
|
// Test differentiation correctness involving projection operations and
// per-instance zeros.
ZeroTangentVectorTests.test("DifferentiationCorrectness") {
  struct Struct: Differentiable {
    var x, y: [Float]
  }

  /// Member-wise array concatenation of two `Struct` values.
  func concatenated(_ first: Struct, _ second: Struct) -> Struct {
    return Struct(x: first.x + second.x, y: first.y + second.y)
  }

  /// Concatenates `input` with itself and projects out `x`. The `withDerivative`
  /// hook compares the derivative reaching the concatenated value against the
  /// expected one.
  func concatenatedX(_ input: Struct) -> [Float] {
    let doubled = concatenated(input, input).withDerivative { cotangent in
      // FIXME(TF-1008): Fix incorrect derivative values for
      // "projection operation" operands when differentiation transform uses
      // `Differentiable.zeroTangentVectorInitializer`.
      // Actual: TangentVector(x: [1.0, 1.0, 1.0], y: [])
      // Expected: TangentVector(x: [1.0, 1.0, 1.0], y: [1.0, 1.0, 1.0])
      expectEqual(cotangent, Struct.TangentVector(x: [1, 1, 1], y: [1, 1, 1]))
    }
    return doubled.x
  }

  let point = Struct(x: [1, 2, 3], y: [1, 2, 3])
  let pullbackOfX = pullback(at: point, in: concatenatedX)
  // NOTE(review): `doubled.x` has 6 elements (3 from each operand), but the
  // seed below and the "Expected" value above have 3 — confirm this is
  // intended once TF-1008 is fixed.
  // FIXME(TF-1008): Remove `expectCrash` when differentiation transform uses
  // `Differentiable.zeroTangentVectorInitializer`.
  expectCrash {
    _ = pullbackOfX([1, 1, 1])
  }
}
|
|
|
|
// Entry point: hand control to StdlibUnittest so the tests declared above on
// `ZeroTangentVectorTests` are executed.
runAllTests()
|