Files
swift-mirror/test/AutoDiff/validation-test/custom_derivatives.swift
Richard Wei f79391b684 [AutoDiff] Remove 'differentiableFunction(from:)' and 'linearFunction(from:)'.
Remove unused APIs `differentiableFunction(from:)` and `linearFunction(from:)`. They were never official APIs, are not included in the [initial proposal](https://github.com/rxwei/swift-evolution/blob/autodiff/proposals/0000-differentiable-programming.md#make-a-function-differentiable-using-derivative), and are unused by existing supported client libraries (SwiftFusion and S4TF). Most importantly, they block crucial optimizations on linear map closures (#34935) and would need nontrivial work in SILGen to support.
2021-01-12 11:58:48 -08:00

59 lines
1.4 KiB
Swift

// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import StdlibUnittest
#if canImport(Darwin)
import Darwin.C
#elseif canImport(Glibc)
import Glibc
#elseif os(Windows)
import CRT
#else
#error("Unsupported platform")
#endif
import DifferentiationUnittest
/// Test suite collecting the custom-derivative tests registered below.
var CustomDerivativesTests: TestSuite = TestSuite("CustomDerivatives")
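// Plain functions over `Tracked<Float>` with no `@differentiable` attribute or
// registered custom derivative. They were previously wrapped with
// `differentiableFunction(from:)` and tested; that API has been removed, and no
// test in this file currently uses them.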
/// Returns `x` doubled, computed with an in-place `*=` on a local copy.
func unary(_ x: Tracked<Float>) -> Tracked<Float> {
  var doubled = x
  doubled *= 2
  return doubled
}
/// Returns the product of `x` and `y`, computed with an in-place `*=` on a
/// local copy of `x`.
func binary(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
  var product = x
  product *= y
  return product
}
CustomDerivativesTests.testWithLeakChecking("SumOfGradPieces") {
  // Running total of every derivative value seen by the `withDerivative` hooks.
  var accumulated: Tracked<Float> = 0
  func accumulate(_ partial: inout Tracked<Float>) {
    accumulated += partial
  }
  // f(x) = x * x * x, so f'(4) = 3 * 4^2 = 48. Each of the three factors sees a
  // derivative of 4 * 4 = 16, and the hooks sum them back to 48.
  _ = gradient(at: 4) { (x: Tracked<Float>) in
    x.withDerivative(accumulate)
      * x.withDerivative(accumulate)
      * x.withDerivative(accumulate)
  }
  expectEqual(48, accumulated)
}
CustomDerivativesTests.testWithLeakChecking("ModifyGradientOfSum") {
  // Each summand receives a derivative of 1, which its hook scales in place:
  // 1 * 10 + 1 * 20 = 30.
  let scaledGradient = gradient(at: 4) { (x: Tracked<Float>) -> Tracked<Float> in
    let tenfold = x.withDerivative { $0 *= 10 }
    let twentyfold = x.withDerivative { $0 *= 20 }
    return tenfold + twentyfold
  }
  expectEqual(30, scaledGradient)
}
CustomDerivativesTests.testWithLeakChecking("WithoutDerivative") {
  // The whole result is computed inside `withoutDerivative(at:)`, so no
  // derivative flows back to the input and the gradient is expected to be 0.
  let detachedGradient = gradient(at: Tracked<Float>(4)) { input in
    withoutDerivative(at: input) { detached in
      Tracked<Float>(sinf(detached.value) + cosf(detached.value))
    }
  }
  expectEqual(0, detachedGradient)
}
// Run every test registered above (entry point for `%target-run-simple-swift`).
runAllTests()