mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
127 lines
4.9 KiB
Swift
127 lines
4.9 KiB
Swift
// RUN: %target-run-simple-swiftgyb(-Xfrontend -enable-experimental-forward-mode-differentiation)
|
|
// REQUIRES: executable_test
|
|
|
|
#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS)
|
|
import Darwin.C.tgmath
|
|
#elseif os(Linux) || os(FreeBSD) || os(PS4) || os(Android) || os(Cygwin) || os(Haiku)
|
|
import Glibc
|
|
#elseif os(Windows)
|
|
import MSVCRT
|
|
#else
|
|
#error("Unsupported platform")
|
|
#endif
|
|
|
|
// Reference values are stored in the widest floating-point type the target
// offers: Float80 on x86 (except Windows and Android), Double elsewhere. This
// keeps the long decimal literals used as expected values below as precise as
// possible before they are compared against the type under test.
#if (arch(i386) || arch(x86_64)) && !os(Windows) && !os(Android)
typealias TestLiteralType = Float80
#else
typealias TestLiteralType = Double
#endif
|
|
|
|
import StdlibUnittest
|
|
import _Differentiation
|
|
|
|
// Test suite on which every generated `${op}_${T}` test below is registered.
let DerivativeTests: TestSuite = TestSuite("TGMath")
|
|
|
|
/// Expects `actual` to match the reference value `expected` to within
/// `allowed` units in the last place (ulps) of `T`.
///
/// The check succeeds immediately when `actual` equals `expected` converted
/// to `T`, or when both values are NaN. Otherwise the absolute difference is
/// computed in `TestLiteralType`, converted to `T`, and divided by the ulp of
/// `T(expected)`. `expectTrue` then reports a failure at `file`/`line` if that
/// ulp error exceeds `allowed`.
///
/// - Parameters:
///   - expected: Reference value, expressed as `TestLiteralType`.
///   - actual: Value produced by the code under test.
///   - allowed: Maximum tolerated error in ulps of `T` (3 by default).
///   - file: Source file attributed to a failure.
///   - line: Source line attributed to a failure.
func expectEqualWithTolerance<T>(
  _ expected: TestLiteralType,
  _ actual: T,
  ulps allowed: T = 3,
  file: String = #file,
  line: UInt = #line
) where T: BinaryFloatingPoint {
  let reference = T(expected)
  // NaN never compares equal, so "both NaN" is accepted explicitly.
  let isExactMatch = actual == reference || (actual.isNaN && expected.isNaN)
  guard !isExactMatch else { return }

  // Distance to the reference, measured in ulps of T at the reference value.
  let distance = T(abs(TestLiteralType(actual) - expected))
  let ulpError = distance / reference.ulp
  expectTrue(
    ulpError <= allowed,
    "\(actual) != \(expected) as \(T.self)"
      + "\n \(ulpError)-ulp error exceeds \(allowed)-ulp tolerance.",
    file: file,
    line: line)
}
|
|
|
|
/// Checks the gradient of the binary function `f` at `(x, y)` against
/// forward finite-difference estimates.
///
/// Each partial derivative is approximated as `(f(p + eps) - f(p)) / eps`
/// with `eps = 0.01`. The estimate is compared with the matching component of
/// `gradient(at:_:in:)` using a 192-ulp tolerance. Forward differences are only
/// meaningful where `f` has no discontinuity within `eps` of the point, so
/// callers are expected to filter out such points.
///
/// - Parameters:
///   - f: Differentiable function under test.
///   - x: First argument at which the gradient is evaluated.
///   - y: Second argument at which the gradient is evaluated.
func checkGradient<T: BinaryFloatingPoint & Differentiable>(
  _ f: @differentiable (T, T) -> T,
  _ x: T,
  _ y: T)
  where T == T.TangentVector {
  let eps = T(0.01)
  let grad = gradient(at: x, y, in: f)
  // Both difference quotients share the same base value; evaluate it once.
  let base = f(x, y)
  let dfdx = (f(x + eps, y) - base) / eps
  let dfdy = (f(x, y + eps) - base) / eps
  // The loose tolerance presumably absorbs the rounding of `x + eps` and
  // `y + eps` in T, which makes the effective step differ slightly from `eps`.
  expectEqualWithTolerance(TestLiteralType(dfdx), grad.0, ulps: 192)
  expectEqualWithTolerance(TestLiteralType(dfdy), grad.1, ulps: 192)
}
|
|
|
|
// Generate one test per (differential operator, floating-point type) pair.
// `derivative` relies on JVPs (forward mode; see the TF-1108 TODO below) and
// `gradient` relies on VJPs (reverse mode). Expected values are the analytic
// first derivatives, checked with expectEqualWithTolerance.
%for op in ['derivative', 'gradient']:
|
|
%for T in ['Float', 'Float80']:
|
|
|
|
%if T == 'Float80':
|
|
// Float80 is tested only under the same platform condition that selects
// Float80 as TestLiteralType.
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
|
|
%end
|
|
|
|
DerivativeTests.test("${op}_${T}") {
|
|
  // Unary functions: each expected value is the analytic derivative f'(x)
  // evaluated at the point passed to `at:`.
  expectEqualWithTolerance(7.3890560989306502274, ${op}(at: 2 as ${T}, in: exp))
|
|
  expectEqualWithTolerance(2.772588722239781145, ${op}(at: 2 as ${T}, in: exp2))
|
|
  expectEqualWithTolerance(7.3890560989306502274, ${op}(at: 2 as ${T}, in: expm1))
|
|
  expectEqualWithTolerance(0.5, ${op}(at: 2 as ${T}, in: log))
|
|
  expectEqualWithTolerance(0.21714724095162590833, ${op}(at: 2 as ${T}, in: log10))
|
|
  expectEqualWithTolerance(0.7213475204444817278, ${op}(at: 2 as ${T}, in: log2))
|
|
  expectEqualWithTolerance(0.33333333333333333334, ${op}(at: 2 as ${T}, in: log1p))
|
|
  expectEqualWithTolerance(5.774399204041917612, ${op}(at: 2 as ${T}, in: tan))
|
|
  expectEqualWithTolerance(-0.9092974268256816954, ${op}(at: 2 as ${T}, in: cos))
|
|
  expectEqualWithTolerance(-0.416146836547142387, ${op}(at: 2 as ${T}, in: sin))
|
|
  expectEqualWithTolerance(1.154700538379251529, ${op}(at: 0.5 as ${T}, in: asin))
|
|
  expectEqualWithTolerance(-1.154700538379251529, ${op}(at: 0.5 as ${T}, in: acos))
|
|
  expectEqualWithTolerance(0.8, ${op}(at: 0.5 as ${T}, in: atan))
|
|
  expectEqualWithTolerance(3.7621956910836314597, ${op}(at: 2 as ${T}, in: sinh))
|
|
  expectEqualWithTolerance(3.6268604078470187677, ${op}(at: 2 as ${T}, in: cosh))
|
|
  expectEqualWithTolerance(0.07065082485316446565, ${op}(at: 2 as ${T}, in: tanh))
|
|
  expectEqualWithTolerance(0.44721359549995793928, ${op}(at: 2 as ${T}, in: asinh))
|
|
  expectEqualWithTolerance(0.5773502691896257645, ${op}(at: 2 as ${T}, in: acosh))
|
|
  expectEqualWithTolerance(1.3333333333333333334, ${op}(at: 0.5 as ${T}, in: atanh))
|
|
  expectEqualWithTolerance(0.020666985354092053575, ${op}(at: 2 as ${T}, in: erf))
|
|
  expectEqualWithTolerance(-0.020666985354092053575, ${op}(at: 2 as ${T}, in: erfc))
|
|
  expectEqualWithTolerance(0.35355339059327376222, ${op}(at: 2 as ${T}, in: { sqrt($0) }))
|
|
  // ceil/floor/round/trunc are piecewise constant, so the expected derivative
  // is 0. Note that x = 2 is a jump point of ceil/floor/trunc; these checks
  // pin the value 0 there as well.
  expectEqualWithTolerance(0, ${op}(at: 2 as ${T}, in: { ceil($0) }))
|
|
  expectEqualWithTolerance(0, ${op}(at: 2 as ${T}, in: { floor($0) }))
|
|
  expectEqualWithTolerance(0, ${op}(at: 2 as ${T}, in: { round($0) }))
|
|
  expectEqualWithTolerance(0, ${op}(at: 2 as ${T}, in: { trunc($0) }))
|
|
|
|
  // Differential operator specific tests.
|
|
|
|
  // fma
  // fma(x, y, z) computes x * y + z, so its partial derivatives at (4, 5, 6)
  // are (y, x, 1) = (5, 4, 1).
|
|
  let dfma = ${op}(at: 4 as ${T}, 5 as ${T}, 6 as ${T}, in: fma)
|
|
%if op == 'gradient':
|
|
  expectEqualWithTolerance(5, dfma.0)
|
|
  expectEqualWithTolerance(4, dfma.1)
|
|
  expectEqualWithTolerance(1, dfma.2)
|
|
%else: # if op == 'derivative'
|
|
  // The single derivative value equals 5 + 4 + 1, i.e. the directional
  // derivative along (1, 1, 1).
  expectEqualWithTolerance(10, dfma)
|
|
%end
|
|
|
|
  // remainder, fmod
  // Compare gradients against finite differences (checkGradient) on the
  // integer grid [-10, 10] x [-10, 10]. A point is skipped when y == 0, or when
  // nudging x or y by 0.001 changes the sign of remainder(x, y), presumably as
  // a cheap probe for a nearby discontinuity.
  // NOTE(review): the probe step (0.001) is smaller than checkGradient's
  // finite-difference step (0.01); confirm no discontinuity can fall between
  // the two on this grid.
|
|
  for a in -10...10 {
|
|
    let x = ${T}(a)
|
|
    for b in -10...10 {
|
|
      let y = ${T}(b)
|
|
      guard b != 0 && remainder(x, y).sign == remainder(x + ${T}(0.001), y).sign &&
|
|
            remainder(x, y).sign == remainder(x, y + ${T}(0.001)).sign
|
|
      else { continue }
|
|
%if op == 'gradient':
|
|
      checkGradient({ remainder($0, $1) }, x, y)
|
|
      checkGradient({ fmod($0, $1) }, x, y)
|
|
%else: # if op == 'derivative'
|
|
      // TODO(TF-1108): Implement JVPs for `remainder` and `fmod`.
|
|
%end
|
|
    }
|
|
  }
|
|
}
|
|
|
|
%if T == 'Float80':
|
|
// Closes the Float80 platform guard opened before this test.
#endif
|
|
%end
|
|
|
|
%end # for T in ['Float', 'Float80']:
|
|
%end # for op in ['derivative', 'gradient']:
|
|
|
|
// Hand control to StdlibUnittest to run the registered tests.
runAllTests()
|