mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Remove unused APIs `differentiableFunction(from:)` and `linearFunction(from:)`. They were never official APIs, are not included in the [initial proposal](https://github.com/rxwei/swift-evolution/blob/autodiff/proposals/0000-differentiable-programming.md#make-a-function-differentiable-using-derivative), and are unused by existing supported client libraries (SwiftFusion and S4TF). Most importantly, they block crucial optimizations on linear map closures (#34935) and would need nontrivial work in SILGen to support.
59 lines
1.4 KiB
Swift
59 lines
1.4 KiB
Swift
// RUN: %target-run-simple-swift
|
|
// REQUIRES: executable_test
|
|
|
|
import StdlibUnittest
|
|
#if canImport(Darwin)
|
|
import Darwin.C
|
|
#elseif canImport(Glibc)
|
|
import Glibc
|
|
#elseif os(Windows)
|
|
import CRT
|
|
#else
|
|
#error("Unsupported platform")
|
|
#endif
|
|
import DifferentiationUnittest
|
|
|
|
// Suite that collects the derivative-customization tests registered below.
// Bound with `let`: the suite is never reassigned; tests are only registered
// on it through method calls (`testWithLeakChecking`).
let CustomDerivativesTests = TestSuite("CustomDerivatives")
|
|
|
|
// Non-differentiable helper functions. They were previously wrapped with
// `differentiableFunction(from:)` and tested; that API has been removed, and
// no test in this file currently calls them.
|
|
|
|
/// Returns `x` multiplied by 2.
///
/// The doubling is done with `*=` on a mutable local copy rather than a plain
/// `x * 2`, so the function body contains an in-place mutation.
func unary(_ x: Tracked<Float>) -> Tracked<Float> {
  var doubled = x
  doubled *= 2
  return doubled
}
|
|
|
|
/// Returns the product of `x` and `y`.
///
/// The product is accumulated with `*=` into a mutable local copy of `x`,
/// so the function body contains an in-place mutation.
func binary(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
  var product = x
  product *= y
  return product
}
|
|
|
|
CustomDerivativesTests.testWithLeakChecking("SumOfGradPieces") {
  // Running total of every cotangent that reaches a `withDerivative` hook.
  var accumulated: Tracked<Float> = 0
  func accumulate(_ cotangent: inout Tracked<Float>) {
    accumulated += cotangent
  }

  // f(x) = x * x * x. At x = 4, each factor's cotangent is the product of the
  // other two factors (4 * 4 = 16), so the three hooks together see 3 * 16.
  _ = gradient(at: 4) { (x: Tracked<Float>) in
    x.withDerivative(accumulate)
      * x.withDerivative(accumulate)
      * x.withDerivative(accumulate)
  }
  expectEqual(48, accumulated)
}
|
|
|
|
CustomDerivativesTests.testWithLeakChecking("ModifyGradientOfSum") {
  // Each addend of the sum receives a cotangent of 1. The hooks rescale those
  // to 10 and 20 respectively, so the reported derivative is 10 + 20 = 30.
  let derivative = gradient(at: 4) { (x: Tracked<Float>) in
    x.withDerivative { $0 *= 10 } + x.withDerivative { $0 *= 20 }
  }
  expectEqual(30, derivative)
}
|
|
|
|
CustomDerivativesTests.testWithLeakChecking("WithoutDerivative") {
  // The result is computed from `x`, but only inside the body passed to
  // `withoutDerivative(at:)`. The test expects that computation to contribute
  // nothing to the gradient, so the derivative must be 0.
  let derivative = gradient(at: Tracked<Float>(4)) { x in
    withoutDerivative(at: x) { constant in
      Tracked<Float>(sinf(constant.value) + cosf(constant.value))
    }
  }
  expectEqual(0, derivative)
}
|
|
|
|
// Hand control to the StdlibUnittest runner, which executes the tests
// registered on `CustomDerivativesTests` above.
runAllTests()
|