Files
swift-mirror/test/AutoDiff/stdlib/differential_operators.swift.gyb
Richard Wei f79391b684 [AutoDiff] Remove 'differentiableFunction(from:)' and 'linearFunction(from:)'.
Remove unused APIs `differentiableFunction(from:)` and `linearFunction(from:)`. They were never official APIs, are not included in the [initial proposal](https://github.com/rxwei/swift-evolution/blob/autodiff/proposals/0000-differentiable-programming.md#make-a-function-differentiable-using-derivative), and are unused by existing supported client libraries (SwiftFusion and S4TF). Most importantly, they block crucial optimizations on linear map closures (#34935) and would need nontrivial work in SILGen to support.
2021-01-12 11:58:48 -08:00

72 lines
2.4 KiB
Swift

// RUN: %empty-directory(%t)
// RUN: %gyb %s -o %t/differential_operators.swift
// RUN: %target-build-swift %t/differential_operators.swift -o %t/differential_operators
// RUN: %target-codesign %t/differential_operators
// RUN: %target-run %t/differential_operators
// REQUIRES: executable_test
import _Differentiation
import StdlibUnittest
/// Suite that every generated differential-operator test case registers with;
/// the `runAllTests()` call at the end of this file executes it.
var DifferentialOperatorTestSuite: TestSuite = TestSuite("DifferentialOperator")
// For each supported arity N, generate a function
//   f(x0, ..., x{N-1}) = x0 * x0 + ... + x{N-1} * x{N-1}
// whose value and pullback come only from a registered custom VJP. Then check
// every reverse-mode differential operator against hand-computed results at
// the arguments (10, 20, ..., 10 * N).
% for arity in range(1, 3 + 1):
% params = ', '.join(['_ x%d: Float' % i for i in range(arity)])
% pb_return_type = '(' + ', '.join(['Float' for _ in range(arity)]) + ')'
% value_expr = ' + '.join(['x%d * x%d' % (i, i) for i in range(arity)])
% pullback_expr = '(' + ', '.join(['2 * x%d * $0' % i for i in range(arity)]) + ')'
% arg_values = [i * 10 for i in range(1, arity + 1)]
% args = ', '.join([str(v) for v in arg_values])
% expected_value = sum([v * v for v in arg_values])
% expected_gradients = '(' + ', '.join([str(2 * v) for v in arg_values]) + ')'

/// Original function of arity ${arity}. Its body deliberately traps, so any
/// operator that evaluates the original instead of the registered VJP
/// (`exampleVJP_${arity}`) fails loudly.
func exampleDiffFunc_${arity}(${params}) -> Float {
  fatalError("exampleDiffFunc_${arity} must not be called directly; differential operators should use exampleVJP_${arity}")
}

/// Custom VJP for `exampleDiffFunc_${arity}`. The value is the sum of the squared
/// arguments. The pullback multiplies each partial derivative `2 * xi` by the
/// incoming seed `$0`.
@derivative(of: exampleDiffFunc_${arity})
func exampleVJP_${arity}(${params}) -> (value: Float, pullback: (Float) -> ${pb_return_type}) {
  (
    ${value_expr},
    { ${pullback_expr} }
  )
}

DifferentialOperatorTestSuite.test("valueWithPullback_${arity}") {
  let (value, pb) = valueWithPullback(at: ${args}, in: exampleDiffFunc_${arity})
  expectEqual(${expected_value}, value)
  expectEqual(${expected_gradients}, pb(1))
}

DifferentialOperatorTestSuite.test("pullback_${arity}") {
  let pb = pullback(at: ${args}, in: exampleDiffFunc_${arity})
  expectEqual(${expected_gradients}, pb(1))
}

DifferentialOperatorTestSuite.test("gradient_${arity}") {
  let grad = gradient(at: ${args}, in: exampleDiffFunc_${arity})
  expectEqual(${expected_gradients}, grad)
}

DifferentialOperatorTestSuite.test("valueWithGradient_${arity}") {
  let (value, grad) = valueWithGradient(at: ${args}, in: exampleDiffFunc_${arity})
  expectEqual(${expected_value}, value)
  expectEqual(${expected_gradients}, grad)
}

// Curried forms: `gradient(of:)` / `valueWithGradient(of:)` return functions
// that are applied to the arguments afterwards.
DifferentialOperatorTestSuite.test("gradient_curried_${arity}") {
  let gradF = gradient(of: exampleDiffFunc_${arity})
  expectEqual(${expected_gradients}, gradF(${args}))
}

DifferentialOperatorTestSuite.test("valueWithGradient_curried_${arity}") {
  let valueWithGradF = valueWithGradient(of: exampleDiffFunc_${arity})
  let (value, grad) = valueWithGradF(${args})
  expectEqual(${expected_value}, value)
  expectEqual(${expected_gradients}, grad)
}
% end

runAllTests()