mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
The SIL type lowering logic for AutoDiff confuses the substituted generic signature with the invocation generic signature, so it queries DependentMemberTypes against a signature that has no requirements. This triggers assertions when the requirement machine is enabled. Disable the requirement machine for this test (`-Xfrontend -requirement-machine=off`) until this is fixed.
65 lines
2.2 KiB
Swift
65 lines
2.2 KiB
Swift
// RUN: %empty-directory(%t)
|
|
// RUN: %gyb %s -o %t/differential_operators.swift
|
|
// RUN: %target-build-swift %t/differential_operators.swift -o %t/differential_operators -Xfrontend -requirement-machine=off
|
|
// RUN: %target-codesign %t/differential_operators
|
|
// RUN: %target-run %t/differential_operators
|
|
// REQUIRES: executable_test
|
|
|
|
import _Differentiation
|
|
|
|
import StdlibUnittest
|
|
|
|
// Suite that every generated differential-operator test case below registers
// itself on. It is never reassigned, so it is bound with `let`.
let DifferentialOperatorTestSuite = TestSuite("DifferentialOperator")
|
|
|
|
% for arity in range(1, 3 + 1):
%{
# Each iteration generates one primal/VJP pair plus five operator tests
# for a function of `arity` Float parameters.

# Parameter names shared by the primal and its VJP: x0, x1, ...
names = ['x%d' % i for i in range(arity)]

# Swift parameter list, e.g. "_ x0: Float, _ x1: Float".
params = ', '.join('_ %s: Float' % n for n in names)

# The pullback returns one Float per parameter; for arity 1 this is "(Float)".
pb_return_type = '(%s)' % ', '.join(['Float'] * arity)

# f(x) = sum of x_i * x_i; each partial derivative 2 * x_i is scaled by the
# pullback's incoming seed ($0).
value_expr = ' + '.join('%s * %s' % (n, n) for n in names)
pullback_expr = '(%s)' % ', '.join('2 * %s * $0' % n for n in names)

# Evaluation point (10, 20, 30, ...) and the exact results expected there.
argValues = [10 * k for k in range(1, arity + 1)]
args = ', '.join(map(str, argValues))
expectedValue = sum(v * v for v in argValues)
expectedGradients = '(%s)' % ', '.join(str(2 * v) for v in argValues)
}%
// Primal under test. Its body traps, so the tests below pass only if the
// differential operators use the VJP registered for it and never evaluate
// this body.
func exampleDiffFunc_${arity}(${params}) -> Float {
  fatalError()
}

// Custom VJP: the value is the sum of squares, and the pullback maps a seed
// to 2 * x_i * seed for every parameter x_i.
@derivative(of: exampleDiffFunc_${arity})
func exampleVJP_${arity}(${params}) -> (value: Float, pullback: (Float) -> ${pb_return_type}) {
  return (value: ${value_expr}, pullback: { ${pullback_expr} })
}

DifferentialOperatorTestSuite.test("valueWithPullback_${arity}") {
  let (result, backprop) = valueWithPullback(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedValue}, result)
  expectEqual(${expectedGradients}, backprop(1))
}

DifferentialOperatorTestSuite.test("pullback_${arity}") {
  let backprop = pullback(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, backprop(1))
}

DifferentialOperatorTestSuite.test("gradient_${arity}") {
  let partials = gradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, partials)
}

DifferentialOperatorTestSuite.test("valueWithGradient_${arity}") {
  let (result, partials) = valueWithGradient(at: ${args}, of: exampleDiffFunc_${arity})
  expectEqual(${expectedValue}, result)
  expectEqual(${expectedGradients}, partials)
}

DifferentialOperatorTestSuite.test("gradient_curried_${arity}") {
  let gradientFn = gradient(of: exampleDiffFunc_${arity})
  expectEqual(${expectedGradients}, gradientFn(${args}))
}

% end
|
|
|
|
// StdlibUnittest entry point: run the test suites registered above
// (here, DifferentialOperatorTestSuite).
runAllTests()
|