mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Consider: (1) a file struct.swift defining `struct Struct` with a `static func max` member; (2) a file derivatives.swift defining `extension Struct` with a custom derivative of the `max` function; (3) a file error.swift defining a differentiable function that uses `Struct.max`. Previously, passing error.swift as the primary file and derivatives.swift as a secondary file to swift-frontend (while forgetting to also pass struct.swift as a secondary file) triggered an assertion failure. This patch fixes the issue by adding a check against `ErrorType` in `findAutoDiffOriginalFunctionDecl` before calling `lookupMember`. Co-authored-by: Anton Korobeynikov <anton@korobeynikov.info>
38 lines
993 B
Swift
38 lines
993 B
Swift
import _Differentiation
|
|
|
|
/// Custom reverse-mode derivative (VJP) registered for `min`
/// (presumably the standard library's `min(_:_:)` — confirm no other `min`
/// is visible at module scope).
///
/// - Parameters:
///   - x: First operand.
///   - y: Second operand.
/// - Returns: The value `min(x, y)` paired with a pullback. The pullback hands
///   the incoming cotangent to `x` when `x <= y` (ties included) and to `y`
///   otherwise; the other operand receives `.zero`.
@inlinable
@derivative(of: min)
func minVJP<T: Comparable & Differentiable>(
  _ x: T,
  _ y: T
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
  let smaller = min(x, y)

  // The comparison is evaluated each time the pullback runs, not up front.
  func routeCotangent(
    _ cotangent: T.TangentVector
  ) -> (T.TangentVector, T.TangentVector) {
    guard x <= y else {
      return (.zero, cotangent)
    }
    return (cotangent, .zero)
  }

  return (value: smaller, pullback: routeCotangent)
}
|
|
|
|
extension Struct {
|
|
@inlinable
|
|
@derivative(of: max)
|
|
static func maxVJP<T: Comparable & Differentiable>(
|
|
_ x: T,
|
|
_ y: T
|
|
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
|
|
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
|
|
if x < y {
|
|
return (.zero, v)
|
|
}
|
|
else {
|
|
return (v, .zero)
|
|
}
|
|
}
|
|
return (value: max(x, y), pullback: pullback)
|
|
}
|
|
}
|