// RUN: %empty-directory(%t)
// RUN: %gyb %s -o %t/differential_operators.swift
// RUN: %target-build-swift %t/differential_operators.swift -o %t/differential_operators
// RUN: %target-codesign %t/differential_operators
// RUN: %target-run %t/differential_operators
// REQUIRES: executable_test

import _Differentiation
import StdlibUnittest

var DifferentialOperatorTestSuite = TestSuite("DifferentialOperator")

// The template section that follows is expanded once for each arity in 1...3,
// producing one example function, one derivative and one set of tests per arity.
% for arity in range(1, 3 + 1):

%   params = ', '.join('_ x{0}: Float'.format(i) for i in range(arity))
%   pb_return_type = '({0})'.format(', '.join(['Float'] * arity))
%   value_expr = ' + '.join('x{0} * x{0}'.format(i) for i in range(arity))
%   pullback_expr = ', '.join('2 * x{0} * seed'.format(i) for i in range(arity))

/// Function under differentiation. Its own body traps, so the expectations in
/// the tests below can only hold if the differential operators use the
/// derivative registered for it rather than running this body.
func exampleDiffFunc_${arity}(${params}) -> Float {
  fatalError()
}

/// Derivative registered for `exampleDiffFunc_${arity}`.
///
/// - Returns: The sum of the squared arguments as `value`, and a `pullback`
///   that scales each argument's `2 * x` term by the incoming seed.
@derivative(of: exampleDiffFunc_${arity})
func exampleVJP_${arity}(${params}) -> (value: Float, pullback: (Float) -> ${pb_return_type}) {
  return (
    value: ${value_expr},
    pullback: { seed in (${pullback_expr}) }
  )
}

%   argValues = [10 * k for k in range(1, arity + 1)]
%   args = ', '.join(str(v) for v in argValues)
%   expectedValue = sum(v * v for v in argValues)
%   expectedGradients = '({0})'.format(', '.join(str(2 * v) for v in argValues))

// Value and pullback are requested together; the pullback is seeded with 1.
DifferentialOperatorTestSuite.test("valueWithPullback_${arity}") {
  let (result, pullbackFn) = valueWithPullback(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedValue}, result)
  expectEqual(${expectedGradients}, pullbackFn(1))
}

// Only the pullback is requested; it is seeded with 1.
DifferentialOperatorTestSuite.test("pullback_${arity}") {
  let pullbackFn = pullback(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, pullbackFn(1))
}

// The gradient is evaluated directly at the argument values.
DifferentialOperatorTestSuite.test("gradient_${arity}") {
  let computedGradient = gradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, computedGradient)
}

// Value and gradient are requested together.
DifferentialOperatorTestSuite.test("valueWithGradient_${arity}") {
  let (result, computedGradient) = valueWithGradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedValue}, result)
  expectEqual(${expectedGradients}, computedGradient)
}

// `gradient(of:)` receives only the function; the closure it returns is then
// applied to the argument values.
DifferentialOperatorTestSuite.test("gradient_curried_${arity}") {
  let curriedGradient = gradient(of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, curriedGradient(${args}))
}

% end

// Runs every test registered on the suite above.
runAllTests()