Files
swift-mirror/test/AutoDiff/SILOptimizer/fragile_curry_thunk.swift
Daniil Kovalev 1a42a0ce5f [AutoDiff] Support curry thunks differentiation in fragile funcs (#77615)
Inside fragile functions, we expect function derivatives to be public, which can be achieved either by explicitly marking the functions as differentiable or by defining a public explicit derivative for them. This is obviously not possible for single and double curry thunks, which are special cases of `AutoClosureExpr`.

Instead of looking at the thunk itself, we unwrap it and look at the function being wrapped. Although the thunk itself and its differentiability witness will not have public visibility, this is not an issue when the wrapped function (and its witness) has public visibility.

Fixes #54819
Fixes #75776
2025-02-17 14:43:50 -08:00

74 lines
2.1 KiB
Swift

// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s -o /dev/null
import _Differentiation
/// Minimal reproducer for both single and double curry thunks.
///
/// Helper that only consumes a `@differentiable(reverse)` function argument.
/// Passing an unapplied method reference here from `@inlinable` (fragile) code
/// makes the compiler form a differentiable function value from the thunk that
/// wraps that method. The returned constant is irrelevant to the test, which
/// only checks that SIL emission produces no diagnostics.
@inlinable
func caller<Thing: Differentiable & FloatingPoint>(
of f: @differentiable(reverse) (_: Thing) -> Thing
) -> Int where Thing.TangentVector == Thing {
return 42
}
/// Exercises curry thunks formed inside `@inlinable` members of a public type.
///
/// Both callees are `@inlinable` and `@differentiable(reverse)`. The thunks the
/// compiler forms around them are not public, so differentiating through them
/// from fragile code has to rely on the wrapped callee's differentiability.
public struct Struct<Thing: Differentiable & FloatingPoint>: Differentiable where Thing.TangentVector == Thing {
/// Passes the static method `callee_single`, unapplied, as a differentiable
/// function value from inlinable code. This is the "single" curry-thunk case.
@inlinable
static func foo_single() -> Int {
return caller(of: callee_single) // No error expected
}
/// Static identity function used as the target of the "single" case.
@inlinable
@differentiable(reverse)
static func callee_single(input: Thing) -> Thing {
return input
}
/// Passes the instance method `callee_double`, unapplied and implicitly bound
/// to `self`, as a differentiable function value from inlinable code. This is
/// the "double" curry-thunk case.
@inlinable
func foo_double() -> Int {
return caller(of: callee_double) // No error expected
}
/// Instance identity function used as the target of the "double" case.
@inlinable
@differentiable(reverse)
func callee_double(input: Thing) -> Thing {
return input
}
}
/// Reproducer from https://github.com/swiftlang/swift/issues/75776
///
/// A public `@inlinable` function differentiates a non-public `@inlinable`
/// static method via `valueWithGradient(at:of:)`. Passing `simulationWithLoss`
/// unapplied forms a curry thunk inside fragile code.
public struct Solution2<Thing: Differentiable & FloatingPoint>: Differentiable where Thing.TangentVector == Thing {
/// Starts from zero and moves once by the computed gradient.
/// The numeric result is not checked by this test; only compilation is.
@inlinable
public static func optimization() -> Thing {
var initial = Thing.zero // `var`: mutated by `move(by:)` below
let (_, delta) = valueWithGradient(at: initial, of: simulationWithLoss) // No error expected
initial.move(by: delta)
return initial
}
/// Identity placeholder; the test only needs a differentiable static method here.
@inlinable
@differentiable(reverse)
static func simulationWithLoss(input: Thing) -> Thing {
return input // placeholder body
}
}
/// Reproducer from https://github.com/swiftlang/swift/issues/54819
///
/// Generic wrapper that conforms to `Differentiable` only when `Scalar` does,
/// via the conditional conformance declared in the extension that follows.
public struct TF_688_Struct<Scalar> {
var x: Scalar
}
/// Conditional `Differentiable` conformance for `TF_688_Struct`.
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
/// Public, differentiable identity. It is referenced unapplied as the default
/// value of `TF_688`'s `reduction` parameter.
@differentiable(reverse)
public static func id(x: Self) -> Self {
return x
}
}
/// Applies `reduction` to `x`; differentiable with respect to `x` only.
///
/// The default argument `TF_688_Struct.id` is an unapplied static method
/// reference, so the compiler wraps it in a curry thunk. Default-argument
/// expressions of a public function are treated as fragile code (they are
/// exposed to clients, much like an `@inlinable` body). This makes it the
/// curry-thunk-in-fragile-code case this test guards.
@differentiable(reverse, wrt: x)
public func TF_688<Scalar: Differentiable>(
_ x: TF_688_Struct<Scalar>,
reduction: @differentiable(reverse) (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id // No error expected
) -> TF_688_Struct<Scalar> {
reduction(x)
}