// RUN: %target-run-simple-swift
// REQUIRES: executable_test

// Test is unexpectedly passing on no_assert config on Linux
// REQUIRES: rdar89860761

// FIXME: Disabled due to test failure with `-O` (https://github.com/apple/swift/issues/55690).
// XFAIL: swift_test_mode_optimize
// XFAIL: swift_test_mode_optimize_size
// XFAIL: swift_test_mode_optimize_unchecked

import StdlibUnittest
import DifferentiationUnittest

// Test end-to-end differentiation of `@differentiable` protocol requirements.

var ProtocolRequirementAutodiffTests = TestSuite("ProtocolRequirementDifferentiation")

// MARK: - Method requirements.

/// A differentiable type with a method requirement that is differentiable
/// with respect to both `self` and its argument.
protocol DiffReq: Differentiable {
  @differentiable(reverse, wrt: (self, x))
  func f(_ x: Tracked<Float>) -> Tracked<Float>
}

extension DiffReq where TangentVector: AdditiveArithmetic {
  /// Evaluates the pullback of `f` with respect to `(self, x)` at `x`, applied
  /// to a seed of `1`.
  ///
  /// - Parameter x: The point at which `f` is differentiated.
  /// - Returns: The pair (derivative with respect to `self`, derivative with
  ///   respect to `x`).
  @inline(never)  // Prevent specialization, to test all witness code.
  func gradF(at x: Tracked<Float>) -> (Self.TangentVector, Tracked<Float>) {
    return (valueWithPullback(at: self, x) { s, x in s.f(x) }).1(1)
  }
}

/// The polynomial `a * x * x + b * x + c`, differentiable with respect to its
/// coefficients and its input. The type serves as its own tangent vector.
struct Quadratic: DiffReq, AdditiveArithmetic {
  typealias TangentVector = Quadratic

  @differentiable(reverse)
  let a: Tracked<Float>

  @differentiable(reverse)
  let b: Tracked<Float>

  @differentiable(reverse)
  let c: Tracked<Float>

  init(_ a: Tracked<Float>, _ b: Tracked<Float>, _ c: Tracked<Float>) {
    self.a = a
    self.b = b
    self.c = c
  }

  @differentiable(reverse, wrt: (self, x))
  func f(_ x: Tracked<Float>) -> Tracked<Float> {
    return a * x * x + b * x + c
  }
}

ProtocolRequirementAutodiffTests.testWithLeakChecking("func") {
  // For f = a * x * x + b * x + c, the derivative with respect to (a, b, c) is
  // (x * x, x, 1) and the derivative with respect to x is 2 * a * x + b.
  expectEqual((Quadratic(0, 0, 1), 12), Quadratic(11, 12, 13).gradF(at: 0))
  expectEqual(
    (Quadratic(1, 1, 1), 2 * 11 + 12),
    Quadratic(11, 12, 13).gradF(at: 1))
  expectEqual(
    (Quadratic(4, 2, 1), 2 * 11 * 2 + 12),
    Quadratic(11, 12, 13).gradF(at: 2))
}

// MARK: - Constructor, accessor, and subscript requirements.
/// Requirements covering a differentiable initializer, differentiable property
/// getters, and a differentiable subscript getter.
protocol FunctionsOfX: Differentiable {
  @differentiable(reverse)
  init(x: Tracked<Float>)

  @differentiable(reverse)
  var x: Tracked<Float> { get }

  @differentiable(reverse)
  var y: Tracked<Float> { get }

  @differentiable(reverse)
  var z: Tracked<Float> { get }

  @differentiable(reverse)
  subscript() -> Tracked<Float> { get }
}

/// Conformance that witnesses the requirements with two stored properties, one
/// computed property, and a subscript forwarding to the computed property.
struct TestFunctionsOfX: FunctionsOfX {
  @differentiable(reverse)
  init(x: Tracked<Float>) {
    self.x = x
    self.y = x * x
  }

  /// x = x
  var x: Tracked<Float>

  /// y = x * x
  var y: Tracked<Float>

  /// z = x * x + x
  var z: Tracked<Float> {
    return y + x
  }

  @differentiable(reverse)
  subscript() -> Tracked<Float> {
    return z
  }
}

/// Computes the gradient, at `x`, of each `FunctionsOfX` accessor composed with
/// the `init(x:)` requirement, going through the protocol witnesses of `F`.
///
/// - Parameters:
///   - x: The point at which the gradients are evaluated.
///   - of: The conforming type whose witnesses are exercised.
/// - Returns: The gradients of `.x`, `.y`, `.z`, and `[]`, in that order.
@inline(never)  // Prevent specialization, to test all witness code.
func derivatives<F: FunctionsOfX>(at x: Tracked<Float>, of: F.Type)
  -> (Tracked<Float>, Tracked<Float>, Tracked<Float>, Tracked<Float>)
{
  let dxdx = gradient(at: x) { x in F(x: x).x }
  let dydx = gradient(at: x) { x in F(x: x).y }
  let dzdx = gradient(at: x) { x in F(x: x).z }
  let dsubscriptdx = gradient(at: x) { x in F(x: x)[] }
  return (dxdx, dydx, dzdx, dsubscriptdx)
}

ProtocolRequirementAutodiffTests.testWithLeakChecking("constructor, accessor, subscript") {
  // At x = 2: d(x)/dx = 1, d(x * x)/dx = 4, d(x * x + x)/dx = 5; the subscript
  // returns `z`, so its gradient is 5 as well.
  expectEqual(
    (1.0, 4.0, 5.0, 5.0),
    derivatives(at: 2.0, of: TestFunctionsOfX.self))
}

// MARK: - Test witness method SIL type computation.

/// A requirement that is differentiable with respect to its two parameters,
/// which have different types, but not with respect to `self`.
protocol P: Differentiable {
  @differentiable(reverse, wrt: (x, y))
  func foo(_ x: Tracked<Float>, _ y: Double) -> Tracked<Float>
}

struct S: P {
  @differentiable(reverse, wrt: (x, y))
  func foo(_ x: Tracked<Float>, _ y: Double) -> Tracked<Float> {
    return x
  }
}

// MARK: - Overriding protocol method adding `@differentiable` attribute.
/// Base protocol whose `logProbability(of:)` requirement carries no
/// `@differentiable` attribute.
public protocol Distribution {
  associatedtype Value
  func logProbability(of value: Value) -> Tracked<Float>
}

/// Refinement that re-declares `logProbability(of:)` as differentiable with
/// respect to `self`.
public protocol DifferentiableDistribution: Differentiable, Distribution {
  @differentiable(reverse, wrt: self)
  func logProbability(of value: Value) -> Tracked<Float>
}

struct Foo: DifferentiableDistribution {
  @differentiable(reverse, wrt: self)
  func logProbability(of value: Tracked<Float>) -> Tracked<Float> {
    .zero
  }
}

/// Differentiable generic function that calls the overriding requirement
/// through `DifferentiableDistribution`.
@differentiable(reverse)
func blah<T: DifferentiableDistribution>(_ x: T) -> Tracked<Float>
where T.Value: AdditiveArithmetic {
  x.logProbability(of: .zero)
}

// Adding a more general `@differentiable` attribute.
public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
where Value: Differentiable {
  @differentiable(reverse, wrt: self)
  @differentiable(reverse, wrt: (self, value))
  func logProbability(of value: Value) -> Tracked<Float>
}

/// Differentiable generic function that differentiates the requirement with
/// respect to both the distribution and the value.
@differentiable(reverse)
func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Tracked<Float>
where T.Value: AdditiveArithmetic {
  x.logProbability(of: value)
}

// Satisfying the requirement with more wrt parameter indices than are necessary.

protocol DifferentiableFoo {
  associatedtype T: Differentiable
  @differentiable(reverse, wrt: x)
  func foo(_ x: T) -> Tracked<Float>
}

protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo {
  @differentiable(reverse, wrt: (self, x))
  func foo(_ x: T) -> Tracked<Float>
}

struct MoreDifferentiableFooStruct: MoreDifferentiableFoo {
  @differentiable(reverse, wrt: (self, x))
  func foo(_ x: Tracked<Float>) -> Tracked<Float> {
    x
  }
}

// Satisfying the requirement with a less-constrained derivative than is necessary.

protocol ExtraDerivativeConstraint {}

// NOTE(review): `ExtraDerivativeConstraint` is not referenced by the
// requirement below as written; confirm whether the requirement's generic
// parameter or derivative was meant to be constrained by it.
protocol HasExtraConstrainedDerivative {
  @differentiable(reverse)
  func requirement<T: Differentiable>(_ x: T) -> T
}

struct SatisfiesDerivativeWithLessConstraint: HasExtraConstrainedDerivative {
  @differentiable(reverse)
  func requirement<T: Differentiable>(_ x: T) -> T {
    x
  }
}

runAllTests()