Files
swift-mirror/test/AutoDiff/sil_combine.sil
Anton Korobeynikov 42e530f9ab Add tests
2025-02-24 15:20:46 -08:00

198 lines
14 KiB
Plaintext

// RUN: %target-sil-opt -sil-print-types -enable-sil-verify-all %s -sil-combine | %FileCheck %s
// SILCombine tests for differentiation-related instructions.
sil_stage canonical
import Swift
import _Differentiation
// MARK: `differentiable_function_extract` folding
// Test input: wraps %orig in a `differentiable_function` that carries no
// derivative operands and immediately extracts the [original] component.
// The expectations that follow this function require the `return` to use the
// bb0 argument directly.
sil @differentiable_function_extract_orig : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> (@callee_guaranteed (Float) -> Float) {
bb0(%orig : $@callee_guaranteed (Float) -> Float):
// Bundle the original function; no `with_derivative` clause is given.
%diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float
// Pull the original function back out of the bundle.
%extracted_orig = differentiable_function_extract [original] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
return %extracted_orig : $@callee_guaranteed (Float) -> Float
}
// CHECK-LABEL: sil @differentiable_function_extract_orig : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> @callee_guaranteed (Float) -> Float {
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float):
// CHECK: return [[ORIG_FN]] : $@callee_guaranteed (Float) -> Float
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_orig'
// Test input: builds a `differentiable_function` whose first derivative
// operand is the %jvp argument and whose second is `undef`, then extracts the
// [jvp] component. The expectations that follow this function require the
// `return` to use the %jvp block argument directly.
sil @differentiable_function_extract_jvp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
bb0(%orig : $@callee_guaranteed (Float) -> Float, %jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
// Derivative operand list is {%jvp, undef}: only the first slot is a real value.
%diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float with_derivative {%jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), undef : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
// Extract the component that was supplied as %jvp above.
%extracted_jvp = differentiable_function_extract [jvp] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
return %extracted_jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}
// CHECK-LABEL: sil @differentiable_function_extract_jvp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float, [[JVP_FN:%.*]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
// CHECK: return [[JVP_FN]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_jvp'
// Test input: builds a `differentiable_function` whose first derivative
// operand is `undef` and whose second is the %vjp argument, then extracts the
// [vjp] component. The expectations that follow this function require the
// `return` to use the %vjp block argument directly.
sil @differentiable_function_extract_vjp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
bb0(%orig : $@callee_guaranteed (Float) -> Float, %vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
// Derivative operand list is {undef, %vjp}: only the second slot is a real value.
%diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float with_derivative {undef : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
// Extract the component that was supplied as %vjp above.
%extracted_vjp = differentiable_function_extract [vjp] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
return %extracted_vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}
// CHECK-LABEL: sil @differentiable_function_extract_vjp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float, [[VJP_FN:%.*]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
// CHECK: return [[VJP_FN]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp'
// Test `differentiable_function_extract [vjp] %diff_fn` where `%diff_fn` does not have a VJP function operand.
// Test input: the `differentiable_function` below has no `with_derivative`
// clause, yet the [vjp] component is extracted from it. The expectations that
// follow this function require both the `differentiable_function` and the
// `differentiable_function_extract` instructions to still be present in the
// output.
sil @differentiable_function_extract_vjp_undefined : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> (@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
bb0(%orig : $@callee_guaranteed (Float) -> Float):
// No derivative operands are attached to the bundle.
%diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float
// Asks for the VJP even though none was supplied above.
%extracted_vjp = differentiable_function_extract [vjp] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
return %extracted_vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}
// CHECK-LABEL: sil @differentiable_function_extract_vjp_undefined : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float):
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]] : $@callee_guaranteed (Float) -> Float
// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK: return [[EXTRACTED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp_undefined'
// MARK: `convert_function` hoisting
// This should optimize down to a single partial_apply that escapes.
// Test input: a thunk is partially applied to the thick form of %orig, the
// closure is converted to a @substituted function type, wrapped in a
// `differentiable_function` (both derivative operands are `undef`), converted
// back to an unsubstituted @differentiable(reverse) type, and finally the
// [original] component is extracted and returned.
// Apart from the `debug_value`, the only use of %diff_fn is the second
// `convert_function`.
// The expectations after the function pin the resulting shape:
// `thin_to_thick_function`, `partial_apply`, and a `return` of the
// `partial_apply` result.
sil @differential_function_convert_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
// Closure over the thunk capturing the thick original function.
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
// First conversion: to the @substituted form used by the bundle below.
%conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
%diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
}
debug_value %diff_fn, let, name "f", argno 1
// Second conversion: back to the unsubstituted differentiable function type.
%conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
%conv_orig = differentiable_function_extract [original] %conv_diff
return %conv_orig
}
// CHECK-LABEL: sil @differential_function_convert_single_use
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
// CHECK: %[[PA:.*]] = partial_apply [callee_guaranteed] %[[THUNK]](%[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
// CHECK: return %[[PA]] : $@callee_guaranteed (@in_guaranteed Float) -> Float
// External declaration (no body) of a function that takes a
// @differentiable(reverse) function value with a @substituted type and returns
// `()`. It is referenced from @differential_function_convert_multiple_use,
// where it is applied to %diff_fn to give that value an additional use.
sil @blackhole : $(@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
// differentiable_function has multiple uses, so we cannot commute it with convert_function, check that all instructions are there
// Test input: same instruction sequence as the single-use variant, except that
// %diff_fn is additionally passed to @blackhole. The expectations that follow
// this function look for `convert_function`, `differentiable_function`,
// `convert_function`, and `differentiable_function_extract` appearing in that
// order in the output.
sil @differential_function_convert_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
// Closure over the thunk capturing the thick original function.
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
// First conversion: to the @substituted form used by the bundle below.
%conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
%diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
}
debug_value %diff_fn, let, name "f", argno 1
// Second conversion: back to the unsubstituted differentiable function type.
%conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
%conv_orig = differentiable_function_extract [original] %conv_diff
// Extra use of %diff_fn: hand the bundle itself to an external function.
%blackhole = function_ref @blackhole : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
return %conv_orig : $@callee_guaranteed (@in_guaranteed Float) -> Float
}
// CHECK-LABEL: sil @differential_function_convert_multiple_use
// CHECK: convert_function
// CHECK: differentiable_function
// CHECK: convert_function
// CHECK: differentiable_function_extract
// MARK: `convert_escape_to_noescape` hoisting
// External declaration (no body) of a function that takes a
// @differentiable(reverse) `(@in_guaranteed Float) -> Float` value and returns
// `()`. It is referenced from @differential_function_noescape_multiple_use,
// where it is applied to %diff_fn to give that value an additional use.
sil @blackhole2 : $(@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
// Here we should be able to unfold partial_apply down to direct function call
// Test input: a thunk is partially applied to the thick form of %orig, the
// closure is wrapped in a `differentiable_function` (both derivative operands
// are `undef`), converted to a @noescape differentiable function, and the
// [original] component is extracted and applied to a stack slot.
// Apart from the `debug_value`, the only use of %diff_fn is the
// `convert_escape_to_noescape`.
// The expectations that follow this function look for an `alloc_stack`
// followed by an `apply` whose callee is the %thunk argument itself, with the
// stack slot and the thick original function as arguments.
sil @differential_function_noescape_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
// Closure over the thunk capturing the thick original function.
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
%diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
}
debug_value %diff_fn, let, name "f", argno 1
%conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
%conv_orig = differentiable_function_extract [original] %conv_diff
// Stack slot passed as the @in_guaranteed argument; nothing is stored to it
// in this function.
%arg = alloc_stack $Float
apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
dealloc_stack %arg : $*Float
// Release the closure created by the partial_apply above.
strong_release %pa
%res = tuple ()
return %res : $()
}
// CHECK-LABEL: sil @differential_function_noescape_single_use
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
// CHECK: %[[ARG:.*]] = alloc_stack $Float
// CHECK: apply %[[THUNK]](%[[ARG]], %[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
// differentiable_function has multiple uses, so we cannot commute it with convert_escape_to_noescape, check that all instructions are there
// Test input: same instruction sequence as the single-use noescape variant,
// except that %diff_fn is additionally passed to @blackhole2. The expectations
// that follow this function look for `differentiable_function`,
// `convert_escape_to_noescape`, and `differentiable_function_extract`
// appearing in that order in the output.
sil @differential_function_noescape_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
// Closure over the thunk capturing the thick original function.
%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
%diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
}
debug_value %diff_fn, let, name "f", argno 1
%conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
%conv_orig = differentiable_function_extract [original] %conv_diff
// Stack slot passed as the @in_guaranteed argument; nothing is stored to it
// in this function.
%arg = alloc_stack $Float
apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
// Extra use of %diff_fn: hand the escaping bundle to an external function.
%blackhole = function_ref @blackhole2 : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
dealloc_stack %arg : $*Float
// Release the closure created by the partial_apply above.
strong_release %pa
%res = tuple ()
return %res : $()
}
// CHECK-LABEL: sil @differential_function_noescape_multiple_use
// CHECK: differentiable_function
// CHECK: convert_escape_to_noescape
// CHECK: differentiable_function_extract