mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Update LoadableByAddress to handle AutoDiff-related instructions: `differentiable_function`, `differentiable_function_extract`, `linear_function`, `linear_function_extract`, and `differentiability_witness_function`.
80 lines · 4.0 KiB · Swift
// RUN: %target-swift-frontend -c -enable-large-loadable-types -Xllvm -sil-verify-after-pass=loadable-address %s
|
|
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s -check-prefix=CHECK-SIL
|
|
// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %s 2>&1 | %FileCheck %s -check-prefix=CHECK-LBA-SIL
|
|
// RUN: %target-run-simple-swift
|
|
// REQUIRES: executable_test
|
|
|
|
// `isLargeLoadableType` depends on the ABI and differs between architectures.
|
|
// REQUIRES: CPU=x86_64
|
|
|
|
// TF-11: Verify that LoadableByAddress works with differentiation-related instructions:
|
|
// - `differentiable_function`
|
|
// - `differentiable_function_extract`
|
|
|
|
// TODO: Add tests for `@differentiable(linear)` functions.
|
|
|
|
import _Differentiation
|
|
import StdlibUnittest
|
|
|
|
// Suite hosting the runtime correctness checks for LoadableByAddress.
// The binding is only used to register tests via `.test(...)` and is never
// reassigned, so it is declared with `let`.
let LBATests = TestSuite("LoadableByAddress")
|
|
|
|
// `Large` is a large loadable type.
|
|
// `Large.TangentVector` is not a large loadable type.
|
|
/// Five-`Float` value type used as the argument/result of the differentiable
/// functions under test.
///
/// Only `a` through `d` take part in differentiation; `e` is excluded via
/// `@noDerivative`, so `Large.TangentVector` has just the four components
/// `a`, `b`, `c`, `d`.
struct Large: Differentiable {
  // Differentiable stored properties, in declaration order a, b, c, d.
  var a, b, c, d: Float

  // Excluded from differentiation, but still part of the memberwise init.
  @noDerivative let e: Float
}
|
|
|
|
/// Identity over `Large`: differentiating it puts a `Large` value in both the
/// argument and the result position of the derivative functions.
@_silgen_name("large2large")
@differentiable
func large2large(_ value: Large) -> Large {
  return value
}
|
|
|
|
// Verification error that `large2large` triggered before LoadableByAddress handled differentiation instructions:
|
|
// SIL verification failed: JVP type does not match expected JVP type
|
|
// $@callee_guaranteed (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
|
|
// $@callee_guaranteed (@in_constant Large) -> (@out Large, @owned @callee_guaranteed (@in_constant Large.TangentVector) -> @out Large.TangentVector)
|
|
|
|
/// Projects the `a` component of a `Large`: differentiating it pairs a `Large`
/// argument with a plain `Float` result.
@_silgen_name("large2small")
@differentiable
func large2small(_ value: Large) -> Float {
  return value.a
}
|
|
|
|
// Verification error that `large2small` triggered before LoadableByAddress handled differentiation instructions:
|
|
// SIL verification failed: JVP type does not match expected JVP type
|
|
// $@callee_guaranteed (@in_constant Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float)
|
|
// $@callee_guaranteed (@in_constant Large) -> (Float, @owned @callee_guaranteed (@in_constant Large.TangentVector) -> Float)
|
|
|
|
// CHECK-SIL: sil hidden @large2large : $@convention(thin) (Large) -> Large {
|
|
// CHECK-LBA-SIL: sil hidden @large2large : $@convention(thin) (@in_constant Large) -> @out Large {
|
|
|
|
// CHECK-SIL-LABEL: sil hidden @large2small : $@convention(thin) (Large) -> Float {
|
|
// CHECK-LBA-SIL: sil hidden @large2small : $@convention(thin) (@in_constant Large) -> Float {
|
|
|
|
// CHECK-SIL: sil hidden @AD__large2large__jvp_src_0_wrt_0 : $@convention(thin) (Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
|
|
// CHECK-LBA-SIL: sil hidden @AD__large2large__jvp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
|
|
|
|
// CHECK-SIL: sil hidden @AD__large2large__vjp_src_0_wrt_0 : $@convention(thin) (Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
|
|
// CHECK-LBA-SIL: sil hidden @AD__large2large__vjp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
|
|
|
|
// CHECK-SIL: sil hidden @AD__large2small__jvp_src_0_wrt_0 : $@convention(thin) (Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float) {
|
|
// CHECK-LBA-SIL: sil hidden @AD__large2small__jvp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float) {
|
|
|
|
// CHECK-SIL: sil hidden @AD__large2small__vjp_src_0_wrt_0 : $@convention(thin) (Large) -> (Float, @owned @callee_guaranteed (Float) -> Large.TangentVector) {
|
|
// CHECK-LBA-SIL: sil hidden @AD__large2small__vjp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Float, @owned @callee_guaranteed (Float) -> Large.TangentVector) {
|
|
|
|
// Runtime check that derivatives through the `Large`-typed functions produce
// the expected values.
LBATests.test("Correctness") {
  let origin = Large(a: 0, b: 0, c: 0, d: 0, e: 0)
  let ones = Large.TangentVector(a: 1, b: 1, c: 1, d: 1)

  // The identity's pullback returns the incoming cotangent unchanged.
  let identityPullback = pullback(at: origin, in: large2large)
  expectEqual(ones, identityPullback(ones))

  // The gradient of `value.a` is 1 along `a` and 0 along every other component.
  let expectedGradient = Large.TangentVector(a: 1, b: 0, c: 0, d: 0)
  expectEqual(expectedGradient, gradient(at: origin, in: large2small))
}
|
|
|
|
// Hand control to StdlibUnittest so the cases registered on `LBATests` execute.
runAllTests()
|