Files
Daniil Kovalev 194199643f [AutoDiff] Fix assert on missing struct decl on cross-file derivative search (#77183)
Consider:

1. File struct.swift defining `struct Struct` with `static func max` member
2. File derivatives.swift defining `extension Struct` with custom derivative of the `max` function
3. File error.swift defining a differentiable function which uses `Struct.max`.

Previously, when error.swift was passed to swift-frontend as the primary file and derivatives.swift as a secondary file (but struct.swift was mistakenly not passed as a secondary file as well), an assertion failure was triggered.

This patch fixes the issue by adding a check for `ErrorType` in `findAutoDiffOriginalFunctionDecl` before calling `lookupMember`.

Co-authored-by: Anton Korobeynikov <anton@korobeynikov.info>
2024-10-29 02:20:50 -07:00

38 lines
1.0 KiB
Swift

import _Differentiation
/// Custom reverse-mode derivative (VJP) registered for the global
/// `min(_:_:)` function.
///
/// - Parameters:
///   - x: First operand.
///   - y: Second operand.
/// - Returns: The primal value `min(x, y)` paired with a pullback that sends
///   the incoming cotangent to the operand selected by the comparison
///   `x <= y` (so ties go to `x`); the other operand receives `.zero`.
@inlinable
@derivative(of: min)
func minVJP<T: Comparable & Differentiable>(
  _ x: T,
  _ y: T
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
  let primal = min(x, y)
  // The comparison runs each time the pullback is invoked, using the
  // captured `x` and `y`.
  let backward: (T.TangentVector) -> (T.TangentVector, T.TangentVector) = { cotangent in
    guard x <= y else { return (.zero, cotangent) }
    return (cotangent, .zero)
  }
  return (value: primal, pullback: backward)
}
/// Extension on `Struct`, a type that is not declared in this file.
///
/// The `@derivative(of: max)` attribute below carries a verify-mode
/// `expected-error` directive. In this fixture the lookup of `max` is
/// expected to fail with a diagnostic rather than an assertion failure.
/// Keep that attribute line byte-identical.
extension Struct {
  /// Custom VJP intended for a `max` function.
  ///
  /// - Parameters:
  ///   - x: First operand.
  ///   - y: Second operand.
  /// - Returns: The primal `max(x, y)` paired with a pullback that sends the
  ///   cotangent to `y` when `x < y` and to `x` otherwise (including ties).
  ///   The other operand receives `.zero`.
  @inlinable
  @derivative(of: max) // expected-error {{cannot find 'max' in scope}}
  static func maxVJP<T: Comparable & Differentiable>(
    _ x: T,
    _ y: T
  ) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
    let primal = max(x, y)
    let backward: (T.TangentVector) -> (T.TangentVector, T.TangentVector) = { cotangent in
      guard x < y else { return (cotangent, .zero) }
      return (.zero, cotangent)
    }
    return (value: primal, pullback: backward)
  }
}