Files
swift-mirror/test/AutoDiff/stdlib/differential_operators.swift.gyb
Slava Pestov 4f6ba29715 AutoDiff: Disable requirement machine when building or testing Differentiation library
The SIL type lowering logic for AutoDiff gets the substituted generic signature
mixed up with the invocation generic signature, so it tries to ask questions
about DependentMemberTypes in a signature with no requirements. This triggers
assertions when the requirement machine is enabled.

Disable the requirement machine until this is fixed.
2021-07-30 19:42:31 -04:00

65 lines
2.2 KiB
Swift

// RUN: %empty-directory(%t)
// RUN: %gyb %s -o %t/differential_operators.swift
// RUN: %target-build-swift %t/differential_operators.swift -o %t/differential_operators -Xfrontend -requirement-machine=off
// RUN: %target-codesign %t/differential_operators
// RUN: %target-run %t/differential_operators
// REQUIRES: executable_test
import _Differentiation
import StdlibUnittest
// Suite that every generated differential-operator test case below registers with.
var DifferentialOperatorTestSuite: TestSuite = TestSuite("DifferentialOperator")
% # Generate one primal function, one custom VJP and five operator tests
% # for each arity from 1 through max_arity.
% max_arity = 3
% for arity in range(1, max_arity + 1):
%   indices = list(range(arity))
%   # Swift parameter list, e.g. "_ x0: Float, _ x1: Float".
%   params = ', '.join('_ x{0}: Float'.format(i) for i in indices)
%   # Pullback result type: one Float per parameter, e.g. "(Float, Float)".
%   pb_return_type = '({0})'.format(', '.join(['Float'] * arity))
%   # The VJP computes f(x) = sum of x_i * x_i, whose partials are 2 * x_i.
%   value_expr = ' + '.join('x{0} * x{0}'.format(i) for i in indices)
%   pullback_expr = ', '.join('2 * x{0} * seed'.format(i) for i in indices)
%   # Evaluate at (10, 20, ...) and precompute the expected results as literals.
%   arg_values = [10 * (i + 1) for i in indices]
%   args = ', '.join(str(v) for v in arg_values)
%   expected_value = sum(v * v for v in arg_values)
%   expected_gradients = '({0})'.format(', '.join(str(2 * v) for v in arg_values))
// Primal function. Its body traps; the tests below rely on the custom VJP
// registered for it to supply both the value and the derivative.
func exampleDiffFunc_${arity}(${params}) -> Float {
  fatalError()
}

// Custom VJP: returns the sum of squares of the arguments, plus a pullback
// that scales each partial derivative (2 * x_i) by the incoming seed.
@derivative(of: exampleDiffFunc_${arity})
func exampleVJP_${arity}(${params}) -> (value: Float, pullback: (Float) -> ${pb_return_type}) {
  let value = ${value_expr}
  let pullbackClosure: (Float) -> ${pb_return_type} = { seed in (${pullback_expr}) }
  return (value: value, pullback: pullbackClosure)
}

DifferentialOperatorTestSuite.test("valueWithPullback_${arity}") {
  let (value, pullbackFn) = valueWithPullback(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expected_value}, value)
  expectEqual(${expected_gradients}, pullbackFn(1))
}

DifferentialOperatorTestSuite.test("pullback_${arity}") {
  let pullbackFn = pullback(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expected_gradients}, pullbackFn(1))
}

DifferentialOperatorTestSuite.test("gradient_${arity}") {
  let gradients = gradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expected_gradients}, gradients)
}

DifferentialOperatorTestSuite.test("valueWithGradient_${arity}") {
  let (value, gradients) = valueWithGradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expected_value}, value)
  expectEqual(${expected_gradients}, gradients)
}

DifferentialOperatorTestSuite.test("gradient_curried_${arity}") {
  let gradientFn = gradient(of: exampleDiffFunc_${arity})
  expectEqual(${expected_gradients}, gradientFn(${args}))
}
% end
// Hand control to StdlibUnittest to run the tests registered above.
runAllTests()