mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
* Look up custom derivatives in non-primary source files once type checking of the primary source file has finished. This registers all custom derivatives before the autodiff transformations run, so those transformations can use them. Fully resolves #55170
36 lines
901 B
Swift
36 lines
901 B
Swift
import _Differentiation
|
|
|
|
/// Custom reverse-mode derivative (VJP) registered for the two-argument
/// `min(_:_:)`.
///
/// - Parameters:
///   - x: The first operand.
///   - y: The second operand.
/// - Returns: A tuple of `min(x, y)` and a pullback. The pullback routes an
///   incoming cotangent to `x` when `x <= y` (ties included). Otherwise it
///   routes the cotangent to `y`. The other slot is always `.zero`.
@inlinable
@derivative(of: min)
func minVJP<T: Comparable & Differentiable>(
  _ x: T,
  _ y: T
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
  // The comparison runs each time the pullback is invoked, using the
  // captured operands.
  func routeCotangent(
    _ cotangent: T.TangentVector
  ) -> (T.TangentVector, T.TangentVector) {
    guard x <= y else {
      return (.zero, cotangent)
    }
    return (cotangent, .zero)
  }
  let smaller = min(x, y)
  return (value: smaller, pullback: routeCotangent)
}
|
|
|
|
/// Custom reverse-mode derivative (VJP) registered for the two-argument
/// `max(_:_:)`.
///
/// - Parameters:
///   - x: The first operand.
///   - y: The second operand.
/// - Returns: A tuple of `max(x, y)` and a pullback. The pullback routes an
///   incoming cotangent to `y` when `x < y`. Otherwise, including ties, it
///   routes the cotangent to `x`. The other slot is always `.zero`.
///
/// NOTE(review): the standard library's `max(_:_:)` returns `y` when the
/// operands are equal, but on ties this pullback credits `x`. Confirm this
/// subgradient choice is intended before relying on it.
@inlinable
@derivative(of: max)
func maxVJP<T: Comparable & Differentiable>(
  _ x: T,
  _ y: T
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
  // The comparison runs each time the pullback is invoked, using the
  // captured operands.
  func routeCotangent(
    _ cotangent: T.TangentVector
  ) -> (T.TangentVector, T.TangentVector) {
    guard x < y else {
      return (cotangent, .zero)
    }
    return (.zero, cotangent)
  }
  let larger = max(x, y)
  return (value: larger, pullback: routeCotangent)
}
|