// RUN: %target-run-simple-swift
// REQUIRES: executable_test

import StdlibUnittest
import DifferentiationUnittest

// Tests that a derivative registered with respect to a superset of parameters
// is picked up when differentiating with respect to only some of them.
//
// NOTE(review): generic parameter lists appear to have been stripped from this
// file (`T` was used without being declared). `Tracked` may have lost a generic
// argument in the same way -- confirm against its declaration in
// DifferentiationUnittest before relying on the bare spelling used below.

var SupersetVJPTests = TestSuite("SupersetVJP")

// MARK: - Functions with hand-written derivatives

/// Returns `x * y`, or `y` when `x > 1000`.
///
/// The branch exists only to stop automatic differentiation from synthesizing
/// a derivative, so that the custom derivative `dmulxy` is the one exercised.
@differentiable(reverse, wrt: (x, y))
func mulxy(_ x: Tracked, _ y: Tracked) -> Tracked {
  // use control flow to prevent AD; NB fix when control flow is supported
  if x > 1000 {
    return y
  }
  return x * y
}

/// Custom derivative of `mulxy` with respect to both `x` and `y`.
///
/// - Returns: The original value together with a pullback that maps a
///   cotangent `v` to `(y * v, x * v)`.
@derivative(of: mulxy)
func dmulxy(
  _ x: Tracked,
  _ y: Tracked
) -> (value: Tracked, pullback: (Tracked) -> (Tracked, Tracked)) {
  return (mulxy(x, y), { v in (y * v, x * v) })
}

/// Forwards to `mulxy`; used to differentiate through a nested call.
func calls_mulxy(_ x: Tracked, _ y: Tracked) -> Tracked {
  return mulxy(x, y)
}

/// Returns `x` on every path and ignores `y`.
///
/// Both branches return the same value; the branch is present so the function
/// contains control flow, mirroring `mulxy`.
@differentiable(reverse, wrt: (x, y))
func x_T<T: Differentiable>(_ x: Tracked, _ y: T) -> Tracked {
  if x > 1000 {
    return x
  }
  return x
}

/// Custom derivative of `x_T` with respect to both `x` and `y`.
///
/// The pullback returns `x * v` for `x` rather than `v`. The "IndirectResults"
/// test expects exactly that value, so a passing test shows this derivative
/// was the one used.
@derivative(of: x_T)
func dx_T<T: Differentiable>(
  _ x: Tracked,
  _ y: T
) -> (value: Tracked, pullback: (Tracked) -> (Tracked, T.TangentVector)) {
  return (x_T(x, y), { v in (x * v, .zero) })
}

// MARK: - Tests

// Differentiates `mulxy` with respect to `x` only; `dmulxy` yields y * 1 = 3.
SupersetVJPTests.testWithLeakChecking("Superset") {
  expectEqual(3, gradient(at: 2) { x in mulxy(x, 3) })
}

// Same as above, but through `calls_mulxy` and with respect to `y`:
// `dmulxy` yields x * 1 = 2.
SupersetVJPTests.testWithLeakChecking("SupersetNested") {
  expectEqual(2, gradient(at: 3) { y in calls_mulxy(2, y) })
}

// Differentiates a closure whose only operation, `+`, is not defined in this
// file.
SupersetVJPTests.testWithLeakChecking("CrossModuleClosure") {
  expectEqual(1, gradient(at: Tracked(1)) { x in x + 2 })
}

// `foo` is differentiable with respect to (x, z); the gradient is requested
// with respect to `x` alone.
SupersetVJPTests.testWithLeakChecking("SubsetOfSubset") {
  @differentiable(reverse, wrt: (x, z))
  func foo(_ x: Tracked, _ y: Tracked, _ z: Tracked) -> Tracked {
    withoutDerivative(at: 0)
  }
  expectEqual(0, gradient(at: 0, of: { x in foo(x, 0, 0) }))
}

SupersetVJPTests.test("ApplySubset") {
  // TF-914
  // `apply` is differentiable with respect to both of its arguments, while
  // `foo` is differentiated with respect to `x` only.
  @differentiable(reverse, wrt: x)
  func foo<T: Differentiable>(
    _ x: T,
    _ y: T,
    apply: @differentiable(reverse) (T, T) -> T
  ) -> T {
    return apply(x, y)
  }
  expectEqual(1, gradient(at: Tracked(0)) { x in foo(x, 0) { $0 + $1 } })
}

// Same shape as "CrossModuleClosure", using `Float` instead of `Tracked`.
SupersetVJPTests.test("CrossModule") {
  let grad = gradient(at: Float(1)) { $0 + 2 }
  expectEqual(Float(1), grad)
}

// Differentiates `x_T` with respect to `x` only; `dx_T` yields x * 1 = 2.
SupersetVJPTests.testWithLeakChecking("IndirectResults") {
  expectEqual(2, gradient(at: 2) { x in x_T(x, Tracked(3)) })
}

runAllTests()