mirror of https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
65 lines · 2.2 KiB · Swift
// RUN: %empty-directory(%t)
|
|
// RUN: %gyb %s -o %t/differential_operators.swift
|
|
// RUN: %target-build-swift %t/differential_operators.swift -o %t/differential_operators
|
|
// RUN: %target-codesign %t/differential_operators
|
|
// RUN: %target-run %t/differential_operators
|
|
// REQUIRES: executable_test
|
|
|
|
import _Differentiation
|
|
|
|
import StdlibUnittest
|
|
|
|
/// Test suite that every generated differential-operator test case registers with.
///
/// The binding is never reassigned (test cases are only added through `.test(_:)` calls),
/// so it is declared with `let`.
let DifferentialOperatorTestSuite = TestSuite("DifferentialOperator")
|
|
|
|
% for arity in range(1, 3 + 1):
|
|
|
|
%# Swift parameter list for the generated functions: `_ x0: Float, _ x1: Float, ...`.
% params = ', '.join('_ x%d: Float' % i for i in range(arity))
%# Pullback result type: `arity` Floats wrapped in parentheses, e.g. `(Float, Float)`.
% pb_return_type = '(%s)' % ', '.join(['Float'] * arity)
|
|
/// Primal function whose derivative is supplied by the custom VJP registered below.
///
/// The body is not expected to run. The tests only reach this function through
/// differential operators, which use the `@derivative(of:)`-registered VJP. That VJP
/// computes the primal value (the sum of the squared inputs) itself.
func exampleDiffFunc_${arity}(${params}) -> Float {
  fatalError("exampleDiffFunc_${arity} should not be called directly; its registered VJP provides the value")
}
|
|
/// Custom VJP for `exampleDiffFunc_${arity}`.
///
/// Returns the sum of the squared inputs, together with a pullback that maps a seed `v`
/// to `(2 * x0 * v, 2 * x1 * v, ...)`, one component per input.
@derivative(of: exampleDiffFunc_${arity})
func exampleVJP_${arity}(${params}) -> (value: Float, pullback: (Float) -> ${pb_return_type}) {
  let squaredSum = ${' + '.join('x%d * x%d' % (i, i) for i in range(arity))}
  return (
    value: squaredSum,
    pullback: { v in (${', '.join('2 * x%d * v' % i for i in range(arity))}) }
  )
}
|
|
|
|
%# Evaluation point: 10, 20, ..., 10 * arity, rendered as a Swift argument list.
% argValues = [10 * i for i in range(1, arity + 1)]
% args = ', '.join(map(str, argValues))
%# The VJP computes the sum of squares, so each gradient component is twice its input.
% expectedValue = sum(v * v for v in argValues)
% expectedGradientValues = [2 * v for v in argValues]
% expectedGradients = '(%s)' % ', '.join(map(str, expectedGradientValues))
|
|
|
|
DifferentialOperatorTestSuite.test("valueWithPullback_${arity}") {
  // Both the primal value and the pullback come from the registered VJP.
  let (primal, backprop) = valueWithPullback(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedValue}, primal)
  // Seeding the pullback with 1 should produce the gradient at the evaluation point.
  let seed: Float = 1
  expectEqual(${expectedGradients}, backprop(seed))
}
|
|
|
|
DifferentialOperatorTestSuite.test("pullback_${arity}") {
  // Seeding the pullback with 1 should produce the gradient at the evaluation point.
  let backprop = pullback(at: ${args}, of: exampleDiffFunc_${arity})
  let seed: Float = 1
  expectEqual(${expectedGradients}, backprop(seed))
}
|
|
|
|
DifferentialOperatorTestSuite.test("gradient_${arity}") {
  // Gradient at the evaluation point, one component per input.
  let computedGradient = gradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, computedGradient)
}
|
|
|
|
DifferentialOperatorTestSuite.test("valueWithGradient_${arity}") {
  // Checks the primal value and the gradient returned together by one call.
  let (primal, computedGradient) =
    valueWithGradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedValue}, primal)
  expectEqual(${expectedGradients}, computedGradient)
}
|
|
|
|
DifferentialOperatorTestSuite.test("gradient_curried_${arity}") {
  // `gradient(of:)` returns a function that is applied to the evaluation point afterwards.
  let gradientFunction = gradient(of: exampleDiffFunc_${arity})
  let computedGradient = gradientFunction(${args})
  expectEqual(${expectedGradients}, computedGradient)
}
|
|
|
|
% end
|
|
|
|
runAllTests()
|