// RUN: %target-sil-opt -sil-print-types -enable-sil-verify-all %s -sil-combine | %FileCheck %s

// SILCombine tests for differentiation-related instructions.

sil_stage canonical

import Swift
import _Differentiation

// MARK: `differentiable_function_extract` folding

// Extracting `[original]` from a `differentiable_function` formed in the same
// function: the expected output returns the original function argument directly.
sil @differentiable_function_extract_orig : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> (@callee_guaranteed (Float) -> Float) {
bb0(%orig : $@callee_guaranteed (Float) -> Float):
  %diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float
  %extracted_orig = differentiable_function_extract [original] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
  return %extracted_orig : $@callee_guaranteed (Float) -> Float
}

// CHECK-LABEL: sil @differentiable_function_extract_orig : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> @callee_guaranteed (Float) -> Float {
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float):
// CHECK: return [[ORIG_FN]] : $@callee_guaranteed (Float) -> Float
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_orig'

// Extracting `[jvp]` when the JVP operand is present (the VJP operand is
// `undef`): the expected output returns the JVP function argument directly.
sil @differentiable_function_extract_jvp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
bb0(%orig : $@callee_guaranteed (Float) -> Float, %jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
  %diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float with_derivative {%jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), undef : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
  %extracted_jvp = differentiable_function_extract [jvp] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
  return %extracted_jvp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}

// CHECK-LABEL: sil @differentiable_function_extract_jvp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float, [[JVP_FN:%.*]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
// CHECK: return [[JVP_FN]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_jvp'

// Extracting `[vjp]` when the VJP operand is present (the JVP operand is
// `undef`): the expected output returns the VJP function argument directly.
sil @differentiable_function_extract_vjp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
bb0(%orig : $@callee_guaranteed (Float) -> Float, %vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
  %diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float with_derivative {undef : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
  %extracted_vjp = differentiable_function_extract [vjp] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
  return %extracted_vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}

// CHECK-LABEL: sil @differentiable_function_extract_vjp : $@convention(thin) (@callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float, [[VJP_FN:%.*]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)):
// CHECK: return [[VJP_FN]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp'

// Test `differentiable_function_extract [vjp] %diff_fn` where `%diff_fn` does
// not have a VJP function operand. The expected output keeps both the
// `differentiable_function` and the `differentiable_function_extract`.
sil @differentiable_function_extract_vjp_undefined : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> (@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
bb0(%orig : $@callee_guaranteed (Float) -> Float):
  %diff_fn = differentiable_function [parameters 0] [results 0] %orig : $@callee_guaranteed (Float) -> Float
  %extracted_vjp = differentiable_function_extract [vjp] %diff_fn : $@callee_guaranteed @differentiable(reverse) (Float) -> Float
  return %extracted_vjp : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}

// CHECK-LABEL: sil @differentiable_function_extract_vjp_undefined : $@convention(thin) (@callee_guaranteed (Float) -> Float) -> @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0([[ORIG_FN:%.*]] : $@callee_guaranteed (Float) -> Float):
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]] : $@callee_guaranteed (Float) -> Float
// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK: return [[EXTRACTED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp_undefined'

// MARK: `convert_function` hoisting

// This should optimize down to a single escaping partial_apply: the expected
// output returns the partial_apply result itself.
// NOTE(review): the generic argument lists after `for` below (`<Float>`,
// `<Float, Float>`) were reconstructed from the surrounding types -- confirm.
sil @differential_function_convert_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
  %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
  %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
  %conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
  %diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
    undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
    undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
  }
  debug_value %diff_fn, let, name "f", argno 1
  %conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
  %conv_orig = differentiable_function_extract [original] %conv_diff
  return %conv_orig
}

// CHECK-LABEL: sil @differential_function_convert_single_use
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
// CHECK: %[[PA:.*]] = partial_apply [callee_guaranteed] %[[THUNK]](%[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
// CHECK: return %[[PA]] : $@callee_guaranteed (@in_guaranteed Float) -> Float

// Opaque consumer used to give `%diff_fn` an additional use in the
// multiple-use test below.
sil @blackhole : $(@differentiable(reverse) @callee_guaranteed @substituted <T> (@in_guaranteed T) -> Float for <Float>) -> ()

// differentiable_function has multiple uses, so we cannot commute it with
// convert_function; check that all instructions are still there.
sil @differential_function_convert_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
  %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
  %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
  %conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
  %diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
    undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
    undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
  }
  debug_value %diff_fn, let, name "f", argno 1
  %conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
  %conv_orig = differentiable_function_extract [original] %conv_diff
  %blackhole = function_ref @blackhole : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted <T> (@in_guaranteed T) -> Float for <Float>) -> ()
  apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted <T> (@in_guaranteed T) -> Float for <Float>) -> ()
  return %conv_orig : $@callee_guaranteed (@in_guaranteed Float) -> Float
}

// CHECK-LABEL: sil @differential_function_convert_multiple_use
// CHECK: convert_function
// CHECK: differentiable_function
// CHECK: convert_function
// CHECK: differentiable_function_extract

// MARK: `convert_escape_to_noescape` hoisting

// Opaque consumer used to give `%diff_fn` an additional use in the
// multiple-use test below.
sil @blackhole2 : $(@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()

// Here we should be able to unfold partial_apply down to a direct function
// call: the expected output applies the thunk to the stack slot and the
// thin_to_thick_function result.
sil @differential_function_noescape_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
  %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
  %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
  %diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
    undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
    undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
  }
  debug_value %diff_fn, let, name "f", argno 1
  %conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
  %conv_orig = differentiable_function_extract [original] %conv_diff
  %arg = alloc_stack $Float
  apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
  dealloc_stack %arg : $*Float
  strong_release %pa
  %res = tuple ()
  return %res : $()
}

// CHECK-LABEL: sil @differential_function_noescape_single_use
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
// CHECK: %[[ARG:.*]] = alloc_stack $Float
// CHECK: apply %[[THUNK]](%[[ARG]], %[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float

// differentiable_function has multiple uses, so we cannot commute it with
// convert_escape_to_noescape; check that all instructions are still there.
sil @differential_function_noescape_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
  %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
  %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
  %diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
    undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
    undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
  }
  debug_value %diff_fn, let, name "f", argno 1
  %conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
  %conv_orig = differentiable_function_extract [original] %conv_diff
  %arg = alloc_stack $Float
  apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
  %blackhole = function_ref @blackhole2 : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
  apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
  dealloc_stack %arg : $*Float
  strong_release %pa
  %res = tuple ()
  return %res : $()
}

// CHECK-LABEL: sil @differential_function_noescape_multiple_use
// CHECK: differentiable_function
// CHECK: convert_escape_to_noescape
// CHECK: differentiable_function_extract