mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
`DifferentiableFunctionInst` now stores result indices. `SILAutoDiffIndices` now stores result indices instead of a source index. `@differentiable` SIL function types may now have multiple differentiability result indices and `@noDerivative` results. `@differentiable` AST function types do not have `@noDerivative` results (yet), so this functionality is not exposed to users. Resolves TF-689 and TF-1256. Infrastructural support for TF-983: supporting differentiation of `apply` instructions with multiple active semantic results.
118 lines
7.7 KiB
Plaintext
118 lines
7.7 KiB
Plaintext
// RUN: %target-swift-frontend -emit-ir %s | %FileCheck %s
|
|
|
|
sil_stage raw
|
|
|
|
import Swift
|
|
import Builtin
|
|
import _Differentiation
|
|
|
|
// Placeholder "original" function of type (Float) -> Float.
// The body only returns `undef`; the function exists so that
// @test_form_diff_func can reference it via `function_ref @f`.
sil @f : $@convention(thin) (Float) -> Float {
|
|
bb0(%0 : $Float):
|
|
return undef : $Float
|
|
}
|
|
|
|
// Placeholder JVP for @f: takes a Float and is typed to return a
// (Float, closure) pair. The body only returns `undef`; the function is
// referenced by @test_form_diff_func as the first `with_derivative` operand.
sil @f_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
bb0(%0 : $Float):
|
|
return undef : $(Float, @callee_guaranteed (Float) -> Float)
|
|
}
|
|
|
|
// Placeholder VJP for @f: same signature as @f_jvp, with a body that only
// returns `undef`. Referenced by @test_form_diff_func as the second
// `with_derivative` operand.
sil @f_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
bb0(%0 : $Float):
|
|
return undef : $(Float, @callee_guaranteed (Float) -> Float)
|
|
}
|
|
|
|
// Forms a `@differentiable` function value from its three components.
// Each of @f, @f_jvp and @f_vjp is referenced with `function_ref`, converted
// to a thick function with `thin_to_thick_function`, and the three thick
// values are bundled by `differentiable_function [parameters 0] [results 0]`.
// The bundle is returned to the caller.
// The FileCheck directives after this function expect each symbol to be
// stored, together with a null context, into elements 0, 1 and 2 of the
// output aggregate.
sil @test_form_diff_func : $@convention(thin) () -> @owned @differentiable @callee_guaranteed (Float) -> Float {
|
|
bb0:
|
|
%orig = function_ref @f : $@convention(thin) (Float) -> Float
|
|
%origThick = thin_to_thick_function %orig : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
|
|
%jvp = function_ref @f_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
|
|
%jvpThick = thin_to_thick_function %jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) to $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
|
|
%vjp = function_ref @f_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
|
|
%vjpThick = thin_to_thick_function %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) to $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
|
|
%result = differentiable_function [parameters 0] [results 0] %origThick : $@callee_guaranteed (Float) -> Float with_derivative {%jvpThick : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjpThick : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
|
|
return %result : $@differentiable @callee_guaranteed (Float) -> Float
|
|
}
|
|
|
|
// CHECK-LABEL: define{{.*}}test_form_diff_func(<{ %swift.function, %swift.function, %swift.function }>*
|
|
// CHECK-SAME: [[OUT:%.*]])
|
|
// CHECK: [[OUT_ORIG:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 0
|
|
// CHECK: [[OUT_ORIG_FN:%.*]] = getelementptr{{.*}}[[OUT_ORIG]], i32 0, i32 0
|
|
// CHECK: store{{.*}}@f{{.*}}[[OUT_ORIG_FN]]
|
|
// CHECK: [[OUT_ORIG_DATA:%.*]] = getelementptr{{.*}} [[OUT_ORIG]], i32 0, i32 1
|
|
// CHECK: store{{.*}}null{{.*}} [[OUT_ORIG_DATA]]
|
|
// CHECK: [[OUT_JVP:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 1
|
|
// CHECK: [[OUT_JVP_FN:%.*]] = getelementptr{{.*}}[[OUT_JVP]], i32 0, i32 0
|
|
// CHECK: store{{.*}}@f_jvp{{.*}}[[OUT_JVP_FN]]
|
|
// CHECK: [[OUT_JVP_DATA:%.*]] = getelementptr{{.*}}[[OUT_JVP]], i32 0, i32 1
|
|
// CHECK: store{{.*}} null{{.*}}[[OUT_JVP_DATA]]
|
|
// CHECK: [[OUT_VJP:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 2
|
|
// CHECK: [[OUT_VJP_FN:%.*]] = getelementptr{{.*}}[[OUT_VJP]], i32 0, i32 0
|
|
// CHECK: store{{.*}}@f_vjp{{.*}}[[OUT_VJP_FN]]
|
|
// CHECK: [[OUT_VJP_DATA:%.*]] = getelementptr{{.*}}[[OUT_VJP]], i32 0, i32 1
|
|
// CHECK: store{{.*}}null{{.*}}[[OUT_VJP_DATA]]
|
|
|
|
// Splits a `@differentiable` function value back into its components.
// `differentiable_function_extract` is applied three times to the incoming
// bundle, with the [original], [jvp] and [vjp] selectors, and the three
// extracted functions are returned as a tuple in that order.
// The FileCheck directives after this function expect loads from elements
// 0, 1 and 2 of the input aggregate followed by stores into the matching
// elements of the output aggregate.
sil @test_extract_components : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float) -> (@owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), @owned @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
|
|
bb0(%0 : $@differentiable @callee_guaranteed (Float) -> Float):
|
|
%orig = differentiable_function_extract [original] %0 : $@differentiable @callee_guaranteed (Float) -> Float
|
|
%jvp = differentiable_function_extract [jvp] %0 : $@differentiable @callee_guaranteed (Float) -> Float
|
|
%vjp = differentiable_function_extract [vjp] %0 : $@differentiable @callee_guaranteed (Float) -> Float
|
|
%result = tuple (%orig : $@callee_guaranteed (Float) -> Float, %jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))
|
|
return %result : $(@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))
|
|
}
|
|
|
|
// CHECK-LABEL: define{{.*}}@test_extract_components(<{ %swift.function, %swift.function, %swift.function }>*
|
|
// CHECK-SAME: [[OUT:%.*]], <{ %swift.function, %swift.function, %swift.function }>*{{.*}}[[IN:%.*]])
|
|
// CHECK: [[ORIG:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 0
|
|
// CHECK: [[ORIG_FN_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 0
|
|
// CHECK: [[ORIG_FN:%.*]] = load{{.*}}[[ORIG_FN_ADDR]]
|
|
// CHECK: [[ORIG_DATA_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 1
|
|
// CHECK: [[ORIG_DATA:%.*]] = load{{.*}}[[ORIG_DATA_ADDR]]
|
|
// CHECK: [[JVP:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 1
|
|
// CHECK: [[JVP_FN_ADDR:%.*]] = getelementptr{{.*}}[[JVP]], i32 0, i32 0
|
|
// CHECK: [[JVP_FN:%.*]] = load{{.*}}[[JVP_FN_ADDR]]
|
|
// CHECK: [[JVP_DATA_ADDR:%.*]] = getelementptr{{.*}}[[JVP]], i32 0, i32 1
|
|
// CHECK: [[JVP_DATA:%.*]] = load{{.*}}[[JVP_DATA_ADDR]]
|
|
// CHECK: [[VJP:%.*]] = getelementptr{{.*}}[[IN]], i32 0, i32 2
|
|
// CHECK: [[VJP_FN_ADDR:%.*]] = getelementptr{{.*}}[[VJP]], i32 0, i32 0
|
|
// CHECK: [[VJP_FN:%.*]] = load{{.*}}[[VJP_FN_ADDR]]
|
|
// CHECK: [[VJP_DATA_ADDR:%.*]] = getelementptr{{.*}}[[VJP]], i32 0, i32 1
|
|
// CHECK: [[VJP_DATA:%.*]] = load{{.*}}[[VJP_DATA_ADDR]]
|
|
// CHECK: [[OUT_0:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 0
|
|
// CHECK: [[OUT_0_FN:%.*]] = getelementptr{{.*}}[[OUT_0]], i32 0, i32 0
|
|
// CHECK: store{{.*}}[[ORIG_FN]]{{.*}}[[OUT_0_FN]]
|
|
// CHECK: [[OUT_0_DATA:%.*]] = getelementptr{{.*}}[[OUT_0]], i32 0, i32 1
|
|
// CHECK: store{{.*}}[[ORIG_DATA]]{{.*}}[[OUT_0_DATA]]
|
|
// CHECK: [[OUT_1:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 1
|
|
// CHECK: [[OUT_1_FN:%.*]] = getelementptr{{.*}}[[OUT_1]], i32 0, i32 0
|
|
// CHECK: store{{.*}}[[JVP_FN]]{{.*}}[[OUT_1_FN]]
|
|
// CHECK: [[OUT_1_DATA:%.*]] = getelementptr{{.*}}[[OUT_1]], i32 0, i32 1
|
|
// CHECK: store{{.*}}[[JVP_DATA]]{{.*}}[[OUT_1_DATA]]
|
|
// CHECK: [[OUT_2:%.*]] = getelementptr{{.*}}[[OUT]], i32 0, i32 2
|
|
// CHECK: [[OUT_2_FN:%.*]] = getelementptr{{.*}}[[OUT_2]], i32 0, i32 0
|
|
// CHECK: store{{.*}}[[VJP_FN]]{{.*}}[[OUT_2_FN]]
|
|
// CHECK: [[OUT_2_DATA:%.*]] = getelementptr{{.*}}[[OUT_2]], i32 0, i32 1
|
|
// CHECK: store{{.*}}[[VJP_DATA]]{{.*}}[[OUT_2_DATA]]
|
|
|
|
// Applies a `@differentiable` function value directly, without extracting
// a component first: the bundle %0 is the callee of `apply`, %1 is its
// Float argument, and the Float result is returned.
// The FileCheck directives after this function expect the function pointer
// and context to be loaded from element 0 of the incoming aggregate and
// then called with the Float input.
sil @test_call_diff_fn : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float, Float) -> Float {
|
|
bb0(%0 : $@differentiable @callee_guaranteed (Float) -> Float, %1 : $Float):
|
|
%2 = apply %0(%1) : $@differentiable @callee_guaranteed (Float) -> Float
|
|
return %2 : $Float
|
|
}
|
|
|
|
// CHECK-LABEL: define{{.*}}@test_call_diff_fn(<{ %swift.function, %swift.function, %swift.function }>*
|
|
// CHECK-SAME: [[DIFF_FN:%.*]], float [[INPUT_FLOAT:%.*]])
|
|
// CHECK: [[ORIG:%.*]] = getelementptr{{.*}}[[DIFF_FN]], i32 0, i32 0
|
|
// CHECK: [[ORIG_FN_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 0
|
|
// CHECK: [[ORIG_FN:%.*]] = load{{.*}}[[ORIG_FN_ADDR]]
|
|
// CHECK: [[ORIG_DATA_ADDR:%.*]] = getelementptr{{.*}}[[ORIG]], i32 0, i32 1
|
|
// CHECK: [[ORIG_DATA:%.*]] = load{{.*}}[[ORIG_DATA_ADDR]]
|
|
// CHECK: [[ORIG_FN_CAST:%.*]] = bitcast{{.*}}[[ORIG_FN]]
|
|
// CHECK: [[RESULT:%.*]] = call swiftcc float [[ORIG_FN_CAST]](float [[INPUT_FLOAT]], %swift.refcounted* swiftself [[ORIG_DATA]])
|
|
// CHECK: ret float [[RESULT]]
|
|
|
|
// Exercises `convert_escape_to_noescape` on a `@differentiable` function
// value, converting it to the `@noescape @differentiable` form of the same
// type. The converted value %1 is never used and the function returns
// `undef`. No FileCheck directives for this function appear in the visible
// part of the file.
sil @test_convert_escape_to_noescape : $@convention(thin) (@guaranteed @differentiable @callee_guaranteed (Float) -> Float) -> () {
|
|
bb0(%0 : $@differentiable @callee_guaranteed (Float) -> Float):
|
|
%1 = convert_escape_to_noescape %0 : $@differentiable @callee_guaranteed (Float) -> Float to $@noescape @differentiable @callee_guaranteed (Float) -> Float
|
|
return undef : $()
|
|
}
|