// RUN: %target-run-simple-swift
// REQUIRES: executable_test

import StdlibUnittest
import DifferentiationUnittest

var MethodTests = TestSuite("Method")

// ==== Tests with generated adjoint ====

/// A single-`Float` wrapper that serves as its own tangent vector.
///
/// Only the `x` accessor has hand-written derivatives (`vjpX` / `jvpX`); the
/// methods in the plain extension below declare no derivatives of their own,
/// so the tests exercise derivatives generated by the compiler.
struct Parameter : Equatable {
  private let storedX: Float

  @differentiable(reverse, wrt: (self))
  var x: Float {
    return storedX
  }

  init(x: Float) {
    storedX = x
  }

  /// Reverse-mode derivative of `x`: a scalar cotangent `dx` is pulled back
  /// to a `Parameter` holding `dx`.
  @derivative(of: x)
  func vjpX() -> (value: Float, pullback: (Float) -> Parameter) {
    return (x, { dx in Parameter(x: dx) } )
  }

  /// Forward-mode derivative of `x`: a tangent `Parameter` maps to its `x`.
  @derivative(of: x)
  func jvpX() -> (value: Float, differential: (Parameter) -> Float) {
    return (x, { $0.x })
  }
}

/// Methods under test. None of them has a custom derivative.
extension Parameter {
  func squared() -> Float {
    return x * x
  }

  static func squared(p: Parameter) -> Float {
    return p.x * p.x
  }

  func multiplied(with other: Float) -> Float {
    return x * other
  }

  static func * (_ a: Parameter, _ b: Parameter) -> Float {
    return a.x * b.x
  }
}

/// `Parameter` is its own `TangentVector`; addition, subtraction and `zero`
/// act component-wise on `x`.
extension Parameter : Differentiable, AdditiveArithmetic {
  typealias TangentVector = Parameter
  typealias Scalar = Float
  typealias Shape = ()

  init(repeating repeatedValue: Float, shape: ()) {
    self.init(x: repeatedValue)
  }

  static func + (lhs: Parameter, rhs: Parameter) -> Parameter {
    return Parameter(x: lhs.x + rhs.x)
  }

  static func - (lhs: Parameter, rhs: Parameter) -> Parameter {
    return Parameter(x: lhs.x - rhs.x)
  }

  static func * (lhs: Scalar, rhs: Parameter) -> Parameter {
    return Parameter(x: lhs * rhs.x)
  }

  static var zero: Parameter { return Parameter(x: 0) }
}

MethodTests.test(
  "instance method with generated adjoint, called from differentiated func"
) {
  func f(_ p: Parameter) -> Float {
    return 100 * p.squared()
  }
  expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
  expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
}

MethodTests.test(
  "instance method with generated adjoint, differentiated directly"
) {
  // This is our current syntax for taking gradients of instance methods
  // directly. If/when we develop nicer syntax for this, change this test.
  func g(p: Parameter) -> Float { p.squared() }
  expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: g))
  expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: g))
}

MethodTests.test("instance method with generated adjoint, wrt only self") {
  func f(_ p: Parameter) -> Float {
    return 100 * p.multiplied(with: 200)
  }
  expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
  expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
}

MethodTests.test("instance method with generated adjoint, wrt only non-self") {
  func f(_ other: Float) -> Float {
    return 100 * Parameter(x: 200).multiplied(with: other)
  }
  expectEqual(100 * 200, gradient(at: 1, of: f))
  expectEqual(100 * 200, gradient(at: 2, of: f))
}

MethodTests.test(
  "instance method with generated adjoint, wrt self and non-self"
) {
  expectEqual(
    (Parameter(x: 100), 200), gradient(at: Parameter(x: 200), 100) { $0.multiplied(with: $1) })
  expectEqual(
    (Parameter(x: 200), 100), gradient(at: Parameter(x: 100), 200) { $0.multiplied(with: $1) })
}

MethodTests.test(
  "static method with generated adjoint, called from differentiated func"
) {
  func f(_ p: Parameter) -> Float {
    return 100 * Parameter.squared(p: p)
  }
  expectEqual(Parameter(x: 4 * 100), gradient(at: Parameter(x: 2), of: f))
  expectEqual(Parameter(x: 40 * 100), gradient(at: Parameter(x: 20), of: f))
}

MethodTests.test(
  "static method with generated adjoint, differentiated directly"
) {
  expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), of: Parameter.squared))
  expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), of: Parameter.squared))
}

MethodTests.test("static method with generated adjoint, wrt only first param") {
  func f(_ p: Parameter) -> Float {
    return 100 * (p * Parameter(x: 200))
  }
  expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
  expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
}

MethodTests.test("static method with generated adjoint, wrt only second param") {
  func f(_ p: Parameter) -> Float {
    return 100 * (Parameter(x: 200) * p)
  }
  expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 1), of: f))
  expectEqual(Parameter(x: 100 * 200), gradient(at: Parameter(x: 2), of: f))
}

MethodTests.test("static method with generated adjoint, wrt all params") {
  func g(a: Parameter, b: Parameter) -> Float { a * b }
  expectEqual((Parameter(x: 100), Parameter(x: 200)),
              gradient(at: Parameter(x: 200), Parameter(x: 100), of: g))
  expectEqual((Parameter(x: 200), Parameter(x: 100)),
              gradient(at: Parameter(x: 100), Parameter(x: 200), of: g))
}

/*
Temporarily disabled until https://github.com/swiftlang/swift/issues/84840 is
fixed: we cannot use `Tracked<Float>` here.

NOTE(review): the test names below are identical to the `Parameter` tests
above, and both sets register in the "Method" suite. When re-enabling, delete
or rename one of the two sets so every test name stays unique.

struct ParameterTracked : Equatable {
  private let storedX: Tracked<Float>

  @differentiable(reverse, wrt: (self))
  var x: Tracked<Float> {
    return storedX
  }

  init(x: Tracked<Float>) {
    storedX = x
  }

  @derivative(of: x)
  func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> ParameterTracked) {
    return (x, { dx in ParameterTracked(x: dx) } )
  }

  @derivative(of: x)
  func jvpX() -> (value: Tracked<Float>, differential: (ParameterTracked) -> Tracked<Float>) {
    return (x, { $0.x })
  }
}

extension ParameterTracked {
  func squared() -> Tracked<Float> {
    return x * x
  }

  static func squared(p: ParameterTracked) -> Tracked<Float> {
    return p.x * p.x
  }

  func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
    return x * other
  }

  static func * (_ a: ParameterTracked, _ b: ParameterTracked) -> Tracked<Float> {
    return a.x * b.x
  }
}

extension ParameterTracked : Differentiable, AdditiveArithmetic {
  typealias TangentVector = ParameterTracked
  typealias Scalar = Tracked<Float>
  typealias Shape = ()

  init(repeating repeatedValue: Tracked<Float>, shape: ()) {
    self.init(x: repeatedValue)
  }

  static func + (lhs: ParameterTracked, rhs: ParameterTracked) -> ParameterTracked {
    return ParameterTracked(x: lhs.x + rhs.x)
  }

  static func - (lhs: ParameterTracked, rhs: ParameterTracked) -> ParameterTracked {
    return ParameterTracked(x: lhs.x - rhs.x)
  }

  static func * (lhs: Scalar, rhs: ParameterTracked) -> ParameterTracked {
    return ParameterTracked(x: lhs * rhs.x)
  }

  static var zero: ParameterTracked { return ParameterTracked(x: 0) }
}

MethodTests.testWithLeakChecking(
  "instance method with generated adjoint, called from differentiated func"
) {
  func f(_ p: ParameterTracked) -> Tracked<Float> {
    return 100 * p.squared()
  }
  expectEqual(ParameterTracked(x: 4 * 100), gradient(at: ParameterTracked(x: 2), of: f))
  expectEqual(ParameterTracked(x: 40 * 100), gradient(at: ParameterTracked(x: 20), of: f))
}

MethodTests.testWithLeakChecking(
  "instance method with generated adjoint, differentiated directly"
) {
  // This is our current syntax for taking gradients of instance methods
  // directly. If/when we develop nicer syntax for this, change this test.
  func g(p: ParameterTracked) -> Tracked<Float> { p.squared() }
  expectEqual(ParameterTracked(x: 4), gradient(at: ParameterTracked(x: 2), of: g))
  expectEqual(ParameterTracked(x: 40), gradient(at: ParameterTracked(x: 20), of: g))
}

MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only self") {
  func f(_ p: ParameterTracked) -> Tracked<Float> {
    return 100 * p.multiplied(with: 200)
  }
  expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
  expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
}

MethodTests.testWithLeakChecking("instance method with generated adjoint, wrt only non-self") {
  func f(_ other: Tracked<Float>) -> Tracked<Float> {
    return 100 * ParameterTracked(x: 200).multiplied(with: other)
  }
  expectEqual(100 * 200, gradient(at: 1, of: f))
  expectEqual(100 * 200, gradient(at: 2, of: f))
}

MethodTests.testWithLeakChecking(
  "instance method with generated adjoint, wrt self and non-self"
) {
  expectEqual(
    (ParameterTracked(x: 100), 200),
    gradient(at: ParameterTracked(x: 200), 100) { $0.multiplied(with: $1) })
  expectEqual(
    (ParameterTracked(x: 200), 100),
    gradient(at: ParameterTracked(x: 100), 200) { $0.multiplied(with: $1) })
}

MethodTests.testWithLeakChecking(
  "static method with generated adjoint, called from differentiated func"
) {
  func f(_ p: ParameterTracked) -> Tracked<Float> {
    return 100 * ParameterTracked.squared(p: p)
  }
  expectEqual(ParameterTracked(x: 4 * 100), gradient(at: ParameterTracked(x: 2), of: f))
  expectEqual(ParameterTracked(x: 40 * 100), gradient(at: ParameterTracked(x: 20), of: f))
}

MethodTests.testWithLeakChecking(
  "static method with generated adjoint, differentiated directly"
) {
  expectEqual(ParameterTracked(x: 4),
              gradient(at: ParameterTracked(x: 2), of: ParameterTracked.squared))
  expectEqual(ParameterTracked(x: 40),
              gradient(at: ParameterTracked(x: 20), of: ParameterTracked.squared))
}

MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only first param") {
  func f(_ p: ParameterTracked) -> Tracked<Float> {
    return 100 * (p * ParameterTracked(x: 200))
  }
  expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
  expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
}

MethodTests.testWithLeakChecking("static method with generated adjoint, wrt only second param") {
  func f(_ p: ParameterTracked) -> Tracked<Float> {
    return 100 * (ParameterTracked(x: 200) * p)
  }
  expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 1), of: f))
  expectEqual(ParameterTracked(x: 100 * 200), gradient(at: ParameterTracked(x: 2), of: f))
}

MethodTests.testWithLeakChecking("static method with generated adjoint, wrt all params") {
  func g(a: ParameterTracked, b: ParameterTracked) -> Tracked<Float> { a * b }
  expectEqual((ParameterTracked(x: 100), ParameterTracked(x: 200)),
              gradient(at: ParameterTracked(x: 200), ParameterTracked(x: 100), of: g))
  expectEqual((ParameterTracked(x: 200), ParameterTracked(x: 100)),
              gradient(at: ParameterTracked(x: 100), ParameterTracked(x: 200), of: g))
}
*/

// ==== Tests with custom adjoint ====

// Test self-reordering thunk for jvp/vjp methods.
/// A method differentiable with respect to `self` and both arguments, whose
/// custom JVP/VJP place `self`'s tangent first in the tangent tuples. It is
/// never called at runtime; compiling it exercises the self-reordering thunk.
struct DiffWrtSelf : Differentiable {
  @differentiable(reverse, wrt: (self, x, y))
  func call<T : Differentiable, U : Differentiable>(_ x: T, _ y: U) -> T {
    return x
  }

  @derivative(of: call)
  func _jvpCall<T : Differentiable, U : Differentiable>(_ x: T, _ y: U)
    -> (value: T, differential: (DiffWrtSelf.TangentVector, T.TangentVector, U.TangentVector) -> T.TangentVector) {
    return (x, { (dself, dx, dy) in dx })
  }

  @derivative(of: call)
  func _vjpCall<T : Differentiable, U : Differentiable>(_ x: T, _ y: U)
    -> (value: T, pullback: (T.TangentVector) -> (DiffWrtSelf.TangentVector, T.TangentVector, U.TangentVector)) {
    return (x, { (.zero, $0, .zero) })
  }
}

/// A single-`Float` wrapper whose methods (extension below) register custom
/// derivatives. Each custom derivative clamps its partial derivatives to
/// `-10...10`, so the tests below expect clamped values (e.g. 10 instead of 40)
/// when the custom derivative is the one being used.
struct CustomParameter : Equatable {
  let storedX: Float

  @differentiable(reverse, wrt: (self))
  var x: Float {
    return storedX
  }

  init(x: Float) {
    storedX = x
  }

  /// Reverse-mode derivative of `x`: a scalar cotangent `dx` is pulled back
  /// to a `CustomParameter` holding `dx`.
  @derivative(of: x)
  func vjpX() -> (value: Float, pullback: (Float) -> CustomParameter) {
    return (x, { dx in CustomParameter(x: dx) })
  }
}

/// `CustomParameter` is its own `TangentVector`; addition, subtraction and
/// `zero` act component-wise on `x`.
extension CustomParameter : Differentiable, AdditiveArithmetic {
  typealias TangentVector = CustomParameter
  typealias Scalar = Float
  typealias Shape = ()

  init(repeating repeatedValue: Float, shape: ()) {
    self.init(x: repeatedValue)
  }

  static func + (lhs: CustomParameter, rhs: CustomParameter) -> CustomParameter {
    return CustomParameter(x: lhs.x + rhs.x)
  }

  static func - (lhs: CustomParameter, rhs: CustomParameter) -> CustomParameter {
    return CustomParameter(x: lhs.x - rhs.x)
  }

  static func * (lhs: Scalar, rhs: CustomParameter) -> CustomParameter {
    return CustomParameter(x: lhs * rhs.x)
  }

  static var zero: CustomParameter { return CustomParameter(x: 0) }
}

extension Float {
  /// Returns `self` limited to the closed interval `limits`.
  func clamped(to limits: ClosedRange<Float>) -> Float {
    return min(max(self, limits.lowerBound), limits.upperBound)
  }
}

extension CustomParameter {
  @differentiable(reverse, wrt: (self))
  func squared() -> Float {
    return x * x
  }

  /// Custom VJP of `squared()`: the true derivative `2 * x`, clamped.
  @derivative(of: squared)
  func dSquared() -> (value: Float, pullback: (Float) -> CustomParameter) {
    return (squared(), { [x] v in CustomParameter(x: (2 * x).clamped(to: -10.0...10.0) * v) })
  }

  @differentiable(reverse)
  static func squared(p: CustomParameter) -> Float {
    return p.x * p.x
  }

  /// Custom VJP of `squared(p:)`: the true derivative `2 * p.x`, clamped.
  @derivative(of: squared)
  static func dSquared(
    _ p: CustomParameter
  ) -> (value: Float, pullback: (Float) -> CustomParameter) {
    return (p.x * p.x, { v in CustomParameter(x: (2 * p.x).clamped(to: -10.0...10.0) * v) })
  }

  // There is currently no way to define multiple custom VJPs wrt different
  // parameters on the same func, so we define a copy of this func per adjoint.
  @differentiable(reverse, wrt: (self, other))
  func multiplied(with other: Float) -> Float {
    return x * other
  }

  @differentiable(reverse, wrt: (other))
  func multiplied_constSelf(with other: Float) -> Float {
    return x * other
  }

  @differentiable(reverse, wrt: (self))
  func multiplied_constOther(with other: Float) -> Float {
    return x * other
  }

  /// Custom VJP of `multiplied(with:)` wrt both `self` and `other`; each
  /// partial derivative is the other factor, clamped.
  @derivative(of: multiplied)
  func dMultiplied_wrtAll(
    with other: Float
  ) -> (value: Float, pullback: (Float) -> (CustomParameter, Float)) {
    return (multiplied(with: other),
      { [x] v in (CustomParameter(x: other.clamped(to: -10.0...10.0) * v),
                  x.clamped(to: -10.0...10.0) * v) })
  }

  /// Custom VJP of `multiplied_constSelf(with:)`: the `other` component of
  /// `dMultiplied_wrtAll`.
  @derivative(of: multiplied_constSelf, wrt: other)
  func dMultiplied_wrtOther(
    with other: Float
  ) -> (value: Float, pullback: (Float) -> Float) {
    let (r, pb) = dMultiplied_wrtAll(with: other)
    return (r, { v in pb(v).1 })
  }

  /// Custom VJP of `multiplied_constOther(with:)`: the `self` component of
  /// `dMultiplied_wrtAll`.
  @derivative(of: multiplied_constOther, wrt: self)
  func dMultiplied_wrtSelf(
    with other: Float
  ) -> (value: Float, pullback: (Float) -> CustomParameter) {
    let (r, pb) = dMultiplied_wrtAll(with: other)
    return (r, { v in pb(v).0 })
  }

  @differentiable(reverse)
  static func multiply(_ lhs: CustomParameter, _ rhs: CustomParameter) -> Float {
    return lhs.x * rhs.x
  }

  /// `multiply` with `lhs` held constant (differentiable wrt `rhs` only).
  @differentiable(reverse, wrt: (rhs))
  static func multiply_constLhs(_ lhs: CustomParameter, _ rhs: CustomParameter) -> Float {
    return lhs.x * rhs.x
  }

  /// `multiply` with `rhs` held constant (differentiable wrt `lhs` only).
  @differentiable(reverse, wrt: (lhs))
  static func multiply_constRhs(_ lhs: CustomParameter, _ rhs: CustomParameter) -> Float {
    return lhs.x * rhs.x
  }

  /// Custom VJP of `multiply` wrt both operands; each partial derivative is
  /// the other operand, clamped.
  @derivative(of: multiply)
  static func dMultiply_wrtAll(_ lhs: CustomParameter, _ rhs: CustomParameter)
    -> (value: Float, pullback: (Float) -> (CustomParameter, CustomParameter)) {
    let result = multiply(lhs, rhs)
    return (result, { v in (CustomParameter(x: rhs.x.clamped(to: -10.0...10.0) * v),
                            CustomParameter(x: lhs.x.clamped(to: -10.0...10.0) * v)) })
  }

  /// Custom VJP of `multiply_constLhs`: the `rhs` component of
  /// `dMultiply_wrtAll`.
  @derivative(of: multiply_constLhs, wrt: rhs)
  static func dMultiply_wrtRhs(_ lhs: CustomParameter, _ rhs: CustomParameter)
    -> (value: Float, pullback: (Float) -> CustomParameter) {
    let (r, pb) = dMultiply_wrtAll(lhs, rhs)
    return (r, { v in pb(v).1 })
  }

  /// Custom VJP of `multiply_constRhs`: the `lhs` component of
  /// `dMultiply_wrtAll`.
  @derivative(of: multiply_constRhs, wrt: lhs)
  static func dMultiply_wrtLhs(_ lhs: CustomParameter, _ rhs: CustomParameter)
    -> (value: Float, pullback: (Float) -> CustomParameter) {
    let (r, pb) = dMultiply_wrtAll(lhs, rhs)
    return (r, { v in pb(v).0 })
  }
}

MethodTests.test(
  "instance method with custom adjoint, called from differentiated func"
) {
  func f(_ p: CustomParameter) -> Float {
    return 100 * p.squared()
  }
  expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
  expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
}

// This test exercises `CustomParameter` (custom adjoint). It previously reused
// the name of the generated-adjoint test in the same suite.
MethodTests.test("instance method with custom adjoint, differentiated directly") {
  // This is our current syntax for taking gradients of instance methods
  // directly. If/when we develop nicer syntax for this, change this test.
  func g(p: CustomParameter) -> Float { p.squared() }
  expectEqual(CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: g))
  expectEqual(CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: g))
}

MethodTests.test("static method with custom adjoint, called from differentiated func") {
  func f(_ p: CustomParameter) -> Float {
    return 100 * CustomParameter.squared(p: p)
  }
  expectEqual(CustomParameter(x: 4 * 100), gradient(at: CustomParameter(x: 2), of: f))
  expectEqual(CustomParameter(x: 10 * 100), gradient(at: CustomParameter(x: 20), of: f))
}

MethodTests.test("static method with custom adjoint, differentiated directly") {
  expectEqual(
    CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), of: CustomParameter.squared))
  expectEqual(
    CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), of: CustomParameter.squared))
}

MethodTests.test("instance method with custom adjoint, wrt only self") {
  func f(_ p: CustomParameter) -> Float {
    return 100 * p.multiplied_constOther(with: 200)
  }
  expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
  expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
}

MethodTests.test("instance method with custom adjoint, wrt only non-self") {
  func f(_ other: Float) -> Float {
    return 100 * CustomParameter(x: 200).multiplied_constSelf(with: other)
  }
  expectEqual(100 * 10, gradient(at: 1, of: f))
  expectEqual(100 * 10, gradient(at: 2, of: f))
}

MethodTests.test("instance method with custom adjoint, wrt self and non-self") {
  func g(p: CustomParameter, o: Float) -> Float { p.multiplied(with: o) }
  expectEqual((CustomParameter(x: 5), 10), gradient(at: CustomParameter(x: 100), 5, of: g))
  expectEqual((CustomParameter(x: 10), 5), gradient(at: CustomParameter(x: 5), 100, of: g))
}

// Differentiates wrt `lhs` via `multiply_constRhs`; previously this test was a
// copy of the "wrt only rhs" test and never differentiated wrt `lhs`.
MethodTests.test("static method with custom adjoint, wrt only lhs") {
  func f(_ p: CustomParameter) -> Float {
    return 100 * CustomParameter.multiply_constRhs(p, CustomParameter(x: 200))
  }
  expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
  expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
}

MethodTests.test("static method with custom adjoint, wrt only rhs") {
  func f(_ p: CustomParameter) -> Float {
    return 100 * CustomParameter.multiply_constLhs(CustomParameter(x: 200), p)
  }
  expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 1), of: f))
  expectEqual(CustomParameter(x: 100 * 10), gradient(at: CustomParameter(x: 2), of: f))
}

MethodTests.test("static method with custom adjoint, wrt all") {
  func f(_ a: CustomParameter, _ b: CustomParameter) -> Float {
    return CustomParameter.multiply(a, b)
  }
  expectEqual((CustomParameter(x: 5), CustomParameter(x: 10)),
              gradient(at: CustomParameter(x: 100), CustomParameter(x: 5), of: f))
  expectEqual((CustomParameter(x: 10), CustomParameter(x: 5)),
              gradient(at: CustomParameter(x: 5), CustomParameter(x: 100), of: f))
}

/*
Temporarily disabled until https://github.com/swiftlang/swift/issues/84840 is
fixed: we cannot use `Tracked<Float>` here.

NOTE(review): the test names below are identical to the `CustomParameter`
tests above, and both sets register in the "Method" suite. When re-enabling,
delete or rename one of the two sets so every test name stays unique.

struct CustomParameterTracked : Equatable {
  let storedX: Tracked<Float>

  @differentiable(reverse, wrt: (self))
  var x: Tracked<Float> {
    return storedX
  }

  init(x: Tracked<Float>) {
    storedX = x
  }

  @derivative(of: x)
  func vjpX() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
    return (x, { dx in CustomParameterTracked(x: dx) })
  }
}

extension CustomParameterTracked : Differentiable, AdditiveArithmetic {
  typealias TangentVector = CustomParameterTracked
  typealias Scalar = Tracked<Float>
  typealias Shape = ()

  init(repeating repeatedValue: Tracked<Float>, shape: ()) {
    self.init(x: repeatedValue)
  }

  static func + (lhs: CustomParameterTracked, rhs: CustomParameterTracked) -> CustomParameterTracked {
    return CustomParameterTracked(x: lhs.x + rhs.x)
  }

  static func - (lhs: CustomParameterTracked, rhs: CustomParameterTracked) -> CustomParameterTracked {
    return CustomParameterTracked(x: lhs.x - rhs.x)
  }

  static func * (lhs: Scalar, rhs: CustomParameterTracked) -> CustomParameterTracked {
    return CustomParameterTracked(x: lhs * rhs.x)
  }

  static var zero: CustomParameterTracked { return CustomParameterTracked(x: 0) }
}

extension Tracked where T : FloatingPoint {
  func clamped(to limits: ClosedRange<Tracked<T>>) -> Tracked<T> {
    return min(max(self, limits.lowerBound), limits.upperBound)
  }
}

extension CustomParameterTracked {
  @differentiable(reverse, wrt: (self))
  func squared() -> Tracked<Float> {
    return x * x
  }

  @derivative(of: squared)
  func dSquared() -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
    return (squared(), { [x] v in CustomParameterTracked(x: (2 * x).clamped(to: -10.0...10.0) * v) })
  }

  @differentiable(reverse)
  static func squared(p: CustomParameterTracked) -> Tracked<Float> {
    return p.x * p.x
  }

  @derivative(of: squared)
  static func dSquared(
    _ p: CustomParameterTracked
  ) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
    return (p.x * p.x, { v in CustomParameterTracked(x: (2 * p.x).clamped(to: -10.0...10.0) * v) })
  }

  // There is currently no way to define multiple custom VJPs wrt different
  // parameters on the same func, so we define a copy of this func per adjoint.
  @differentiable(reverse, wrt: (self, other))
  func multiplied(with other: Tracked<Float>) -> Tracked<Float> {
    return x * other
  }

  @differentiable(reverse, wrt: (other))
  func multiplied_constSelf(with other: Tracked<Float>) -> Tracked<Float> {
    return x * other
  }

  @differentiable(reverse, wrt: (self))
  func multiplied_constOther(with other: Tracked<Float>) -> Tracked<Float> {
    return x * other
  }

  @derivative(of: multiplied)
  func dMultiplied_wrtAll(
    with other: Tracked<Float>
  ) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameterTracked, Tracked<Float>)) {
    return (multiplied(with: other),
      { [x] v in (CustomParameterTracked(x: other.clamped(to: -10.0...10.0) * v),
                  x.clamped(to: -10.0...10.0) * v) })
  }

  @derivative(of: multiplied_constSelf, wrt: other)
  func dMultiplied_wrtOther(
    with other: Tracked<Float>
  ) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> Tracked<Float>) {
    let (r, pb) = dMultiplied_wrtAll(with: other)
    return (r, { v in pb(v).1 })
  }

  @derivative(of: multiplied_constOther, wrt: self)
  func dMultiplied_wrtSelf(
    with other: Tracked<Float>
  ) -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
    let (r, pb) = dMultiplied_wrtAll(with: other)
    return (r, { v in pb(v).0 })
  }

  @differentiable(reverse)
  static func multiply(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked) -> Tracked<Float> {
    return lhs.x * rhs.x
  }

  @differentiable(reverse, wrt: (rhs))
  static func multiply_constLhs(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked) -> Tracked<Float> {
    return lhs.x * rhs.x
  }

  @differentiable(reverse, wrt: (lhs))
  static func multiply_constRhs(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked) -> Tracked<Float> {
    return lhs.x * rhs.x
  }

  @derivative(of: multiply)
  static func dMultiply_wrtAll(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked)
    -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> (CustomParameterTracked, CustomParameterTracked)) {
    let result = multiply(lhs, rhs)
    return (result, { v in (CustomParameterTracked(x: rhs.x.clamped(to: -10.0...10.0) * v),
                            CustomParameterTracked(x: lhs.x.clamped(to: -10.0...10.0) * v)) })
  }

  @derivative(of: multiply_constLhs, wrt: rhs)
  static func dMultiply_wrtRhs(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked)
    -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
    let (r, pb) = dMultiply_wrtAll(lhs, rhs)
    return (r, { v in pb(v).1 })
  }

  @derivative(of: multiply_constRhs, wrt: lhs)
  static func dMultiply_wrtLhs(_ lhs: CustomParameterTracked, _ rhs: CustomParameterTracked)
    -> (value: Tracked<Float>, pullback: (Tracked<Float>) -> CustomParameterTracked) {
    let (r, pb) = dMultiply_wrtAll(lhs, rhs)
    return (r, { v in pb(v).0 })
  }
}

MethodTests.testWithLeakChecking(
  "instance method with custom adjoint, called from differentiated func"
) {
  func f(_ p: CustomParameterTracked) -> Tracked<Float> {
    return 100 * p.squared()
  }
  expectEqual(CustomParameterTracked(x: 4 * 100), gradient(at: CustomParameterTracked(x: 2), of: f))
  expectEqual(CustomParameterTracked(x: 10 * 100), gradient(at: CustomParameterTracked(x: 20), of: f))
}

MethodTests.testWithLeakChecking("instance method with custom adjoint, differentiated directly") {
  // This is our current syntax for taking gradients of instance methods
  // directly. If/when we develop nicer syntax for this, change this test.
  func g(p: CustomParameterTracked) -> Tracked<Float> { p.squared() }
  expectEqual(CustomParameterTracked(x: 4), gradient(at: CustomParameterTracked(x: 2), of: g))
  expectEqual(CustomParameterTracked(x: 10), gradient(at: CustomParameterTracked(x: 20), of: g))
}

MethodTests.testWithLeakChecking("static method with custom adjoint, called from differentiated func") {
  func f(_ p: CustomParameterTracked) -> Tracked<Float> {
    return 100 * CustomParameterTracked.squared(p: p)
  }
  expectEqual(CustomParameterTracked(x: 4 * 100), gradient(at: CustomParameterTracked(x: 2), of: f))
  expectEqual(CustomParameterTracked(x: 10 * 100), gradient(at: CustomParameterTracked(x: 20), of: f))
}

MethodTests.testWithLeakChecking("static method with custom adjoint, differentiated directly") {
  expectEqual(
    CustomParameterTracked(x: 4),
    gradient(at: CustomParameterTracked(x: 2), of: CustomParameterTracked.squared))
  expectEqual(
    CustomParameterTracked(x: 10),
    gradient(at: CustomParameterTracked(x: 20), of: CustomParameterTracked.squared))
}

MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only self") {
  func f(_ p: CustomParameterTracked) -> Tracked<Float> {
    return 100 * p.multiplied_constOther(with: 200)
  }
  expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
  expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
}

MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt only non-self") {
  func f(_ other: Tracked<Float>) -> Tracked<Float> {
    return 100 * CustomParameterTracked(x: 200).multiplied_constSelf(with: other)
  }
  expectEqual(100 * 10, gradient(at: 1, of: f))
  expectEqual(100 * 10, gradient(at: 2, of: f))
}

MethodTests.testWithLeakChecking("instance method with custom adjoint, wrt self and non-self") {
  func g(p: CustomParameterTracked, o: Tracked<Float>) -> Tracked<Float> { p.multiplied(with: o) }
  expectEqual((CustomParameterTracked(x: 5), 10), gradient(at: CustomParameterTracked(x: 100), 5, of: g))
  expectEqual((CustomParameterTracked(x: 10), 5), gradient(at: CustomParameterTracked(x: 5), 100, of: g))
}

MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only lhs") {
  func f(_ p: CustomParameterTracked) -> Tracked<Float> {
    return 100 * CustomParameterTracked.multiply_constRhs(p, CustomParameterTracked(x: 200))
  }
  expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
  expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
}

MethodTests.testWithLeakChecking("static method with custom adjoint, wrt only rhs") {
  func f(_ p: CustomParameterTracked) -> Tracked<Float> {
    return 100 * CustomParameterTracked.multiply_constLhs(CustomParameterTracked(x: 200), p)
  }
  expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 1), of: f))
  expectEqual(CustomParameterTracked(x: 100 * 10), gradient(at: CustomParameterTracked(x: 2), of: f))
}

MethodTests.testWithLeakChecking("static method with custom adjoint, wrt all") {
  func f(_ a: CustomParameterTracked, _ b: CustomParameterTracked) -> Tracked<Float> {
    return CustomParameterTracked.multiply(a, b)
  }
  expectEqual((CustomParameterTracked(x: 5), CustomParameterTracked(x: 10)),
              gradient(at: CustomParameterTracked(x: 100), CustomParameterTracked(x: 5), of: f))
  expectEqual((CustomParameterTracked(x: 10), CustomParameterTracked(x: 5)),
              gradient(at: CustomParameterTracked(x: 5), CustomParameterTracked(x: 100), of: f))
}
*/

runAllTests()