mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
198 lines
14 KiB
Plaintext
198 lines
14 KiB
Plaintext
// RUN: %target-sil-opt -sil-print-types -enable-sil-verify-all %s -sil-combine | %FileCheck %s
|
|
|
|
// SILCombine tests for differentiation-related instructions.
|
|
|
|
sil_stage canonical
|
|
|
|
import Swift
|
|
import _Differentiation
|
|
|
|
// MARK: `differentiable_function_extract` folding
|
|
|
|
// Test input: wraps %orig in a `differentiable_function` that has no
// derivative operands, then immediately extracts the [original] component
// and returns it. The expectations that follow this function require the
// `return` to use the bb0 argument directly.
sil @differentiable_function_extract_orig : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> (@callee_guaranteed (Float) -> Float) {
|
|
bb0(%orig : $@callee_guaranteed (Float) -> Float):
|
|
%diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float
|
|
%extracted_orig = differentiable_function_extract [original] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
|
|
return %extracted_orig : $@callee_guaranteed (Float) -> Float
|
|
}
|
|
|
|
// CHECK-LABEL: sil @differentiable_function_extract_orig : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> @callee_guaranteed (Float) -> Float {
|
|
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float):
|
|
// CHECK: return [[ORIG_FN]] : $@callee_guaranteed (Float) -> Float
|
|
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_orig'
|
|
|
|
// Test input: builds a `differentiable_function` whose derivative list is
// {%jvp, undef}, then extracts the [jvp] component and returns it. The
// expectations that follow this function require the `return` to use the
// %jvp block argument directly.
sil @differentiable_function_extract_jvp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
|
|
bb0(%orig : $@callee_guaranteed (Float) -> Float, %jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
|
|
%diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float with_derivative {%jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), undef : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
|
|
%extracted_jvp = differentiable_function_extract [jvp] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
|
|
return %extracted_jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
|
|
}
|
|
|
|
// CHECK-LABEL: sil @differentiable_function_extract_jvp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float, [[JVP_FN:%.*]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
|
|
// CHECK: return [[JVP_FN]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
|
|
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_jvp'
|
|
|
|
// Test input: builds a `differentiable_function` whose derivative list is
// {undef, %vjp}, then extracts the [vjp] component and returns it. The
// expectations that follow this function require the `return` to use the
// %vjp block argument directly.
sil @differentiable_function_extract_vjp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
|
|
bb0(%orig : $@callee_guaranteed (Float) -> Float, %vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
|
|
%diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float with_derivative {undef : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
|
|
%extracted_vjp = differentiable_function_extract [vjp] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
|
|
return %extracted_vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
|
|
}
|
|
|
|
// CHECK-LABEL: sil @differentiable_function_extract_vjp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float, [[VJP_FN:%.*]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
|
|
// CHECK: return [[VJP_FN]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
|
|
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp'
|
|
|
|
// Test `differentiable_function_extract [vjp] %diff_fn` where `%diff_fn` does not have a VJP function operand.
|
|
// Test input: builds a `differentiable_function` with no `with_derivative`
// clause and then extracts the [vjp] component from it. The expectations
// that follow this function require both the `differentiable_function` and
// the `differentiable_function_extract` to still be present in the output.
sil @differentiable_function_extract_vjp_undefined : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> (@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
|
|
bb0(%orig : $@callee_guaranteed (Float) -> Float):
|
|
%diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float
|
|
%extracted_vjp = differentiable_function_extract [vjp] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
|
|
return %extracted_vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
|
|
}
|
|
|
|
// CHECK-LABEL: sil @differentiable_function_extract_vjp_undefined : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float):
|
|
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]] : $@callee_guaranteed (Float) -> Float
|
|
// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
|
|
// CHECK: return [[EXTRACTED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
|
|
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp_undefined'
|
|
|
|
// MARK: `convert_function` hoisting
|
|
|
|
// This should optimize down to a single partial_apply that escapes
|
|
// Test input: %thunk is partially applied to the thick form of %orig, the
// closure is converted to a substituted function type, wrapped in a
// `differentiable_function` with two undef derivatives, converted again to
// an unsubstituted @differentiable type, and finally the [original]
// component is extracted and returned.
// Apart from the `debug_value`, %diff_fn is used only by the second
// `convert_function`.
sil @differential_function_convert_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
|
|
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
|
|
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
|
|
|
|
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
|
|
%conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
|
|
|
|
%diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
|
|
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
|
|
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
|
|
}
|
|
|
|
debug_value %diff_fn, let, name "f", argno 1
|
|
|
|
%conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
|
|
%conv_orig = differentiable_function_extract [original] %conv_diff
|
|
return %conv_orig
|
|
}
|
|
|
|
// CHECK-LABEL: sil @differential_function_convert_single_use
|
|
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
|
|
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
|
|
// CHECK: %[[PA:.*]] = partial_apply [callee_guaranteed] %[[THUNK]](%[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
|
|
// CHECK: return %[[PA]] : $@callee_guaranteed (@in_guaranteed Float) -> Float
|
|
|
|
// Body-less declaration. differential_function_convert_multiple_use applies
// it to %diff_fn so that the value has a use besides `convert_function`.
sil @blackhole : $(@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
|
|
|
|
// differentiable_function has multiple uses, so we cannot commute it with convert_function, check that all instructions are there
|
|
|
|
// Test input: the same partial_apply / convert_function /
// differentiable_function / convert_function / extract sequence as
// differential_function_convert_single_use, plus an `apply` of @blackhole
// that takes %diff_fn as its argument. That apply gives %diff_fn a second
// non-debug use.
sil @differential_function_convert_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
|
|
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
|
|
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
|
|
|
|
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
|
|
%conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
|
|
|
|
%diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
|
|
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
|
|
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
|
|
}
|
|
|
|
debug_value %diff_fn, let, name "f", argno 1
|
|
|
|
%conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
|
|
%conv_orig = differentiable_function_extract [original] %conv_diff
|
|
|
|
%blackhole = function_ref @blackhole : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
|
|
apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
|
|
|
|
return %conv_orig : $@callee_guaranteed (@in_guaranteed Float) -> Float
|
|
}
|
|
|
|
// CHECK-LABEL: sil @differential_function_convert_multiple_use
|
|
// CHECK: convert_function
|
|
// CHECK: differentiable_function
|
|
// CHECK: convert_function
|
|
// CHECK: differentiable_function_extract
|
|
|
|
// MARK: `convert_escape_to_noescape` hoisting
|
|
|
|
// Body-less declaration. differential_function_noescape_multiple_use applies
// it to %diff_fn so that the value has a use besides
// `convert_escape_to_noescape`.
sil @blackhole2 : $(@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
|
|
|
|
// Here we should be able to unfold partial_apply down to a direct function call
|
|
|
|
// Test input: %thunk is partially applied to the thick form of %orig, the
// closure is wrapped in a `differentiable_function` with two undef
// derivatives, passed through `convert_escape_to_noescape`, and the
// [original] component is extracted and applied to the stack slot %arg.
// Apart from the `debug_value`, %diff_fn is used only by the
// `convert_escape_to_noescape`. The expectations that follow this function
// look for a direct `apply` of %thunk with %arg and the thick %orig.
sil @differential_function_noescape_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
|
|
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
|
|
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
|
|
|
|
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
|
|
|
|
%diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
|
|
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
|
|
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
|
|
}
|
|
|
|
debug_value %diff_fn, let, name "f", argno 1
|
|
|
|
%conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
|
|
%conv_orig = differentiable_function_extract [original] %conv_diff
|
|
|
|
%arg = alloc_stack $Float
|
|
apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
|
|
|
|
dealloc_stack %arg : $*Float
|
|
strong_release %pa
|
|
|
|
%res = tuple ()
|
|
return %res : $()
|
|
}
|
|
|
|
// CHECK-LABEL: sil @differential_function_noescape_single_use
|
|
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
|
|
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
|
|
// CHECK: %[[ARG:.*]] = alloc_stack $Float
|
|
// CHECK: apply %[[THUNK]](%[[ARG]], %[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
|
|
|
|
|
|
// differentiable_function has multiple uses, so we cannot commute it with convert_escape_to_noescape, check that all instructions are there
|
|
|
|
// Test input: the same partial_apply / differentiable_function /
// convert_escape_to_noescape / extract / apply sequence as
// differential_function_noescape_single_use, plus an `apply` of @blackhole2
// that takes %diff_fn as its argument. That apply gives %diff_fn a second
// non-debug use.
sil @differential_function_noescape_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
|
|
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
|
|
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
|
|
|
|
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
|
|
|
|
%diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
|
|
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
|
|
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
|
|
}
|
|
|
|
debug_value %diff_fn, let, name "f", argno 1
|
|
|
|
%conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
|
|
%conv_orig = differentiable_function_extract [original] %conv_diff
|
|
|
|
%arg = alloc_stack $Float
|
|
apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
|
|
|
|
%blackhole = function_ref @blackhole2 : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
|
|
apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
|
|
|
|
dealloc_stack %arg : $*Float
|
|
strong_release %pa
|
|
|
|
%res = tuple ()
|
|
return %res : $()
|
|
}
|
|
|
|
// CHECK-LABEL: sil @differential_function_noescape_multiple_use
|
|
// CHECK: differentiable_function
|
|
// CHECK: convert_escape_to_noescape
|
|
// CHECK: differentiable_function_extract
|