Files
swift-mirror/test/AutoDiff/SILOptimizer/sil_combine_differentiable_function_extract.sil
T
Daniil Kovalev aa8f1d7efd [AutoDiff] Fix crash due to use after consume in differentiable_function (#88918)
The patch implements proper sil-combiner handling of
`differentiable_function` for cases when the extractee has non-trivial
ownership. In such cases, it is consumed by the `differentiable_function`
instruction. We must copy the extractee before the consumption point so
the copy remains live afterward.

Fixes #88816
2026-05-13 21:33:13 +00:00

235 lines
18 KiB
Plaintext

// Runs sil-combine over this file and verifies the printed SIL with FileCheck.
// The negative pattern is additionally passed as --implicit-check-not: a
// negative directive placed ahead of the first label only constrains the output
// that precedes that label, so on its own it would not cover the two function
// bodies whose differentiable_function_extract instructions must be folded away.
// RUN: %target-sil-opt -enable-sil-verify-all %s -sil-combine | %FileCheck %s --implicit-check-not="= differentiable_function_extract"
// CHECK-NOT: = differentiable_function_extract
sil_stage canonical
import Builtin
import Swift
import SwiftShims
import _Differentiation
/// SIL below corresponds to the following Swift:
///
/// @differentiable(reverse)
/// func foo<T: Differentiable & AdditiveArithmetic>(_ x: T, _ y: T) -> T {
/// if x == T.zero {
/// return y
/// }
/// return x
/// }
///
/// @differentiable(reverse)
/// func bar(_ x: Float, _ y: Float) -> Float {
/// return foo(y, x)
/// }
// The three entries below are declarations without bodies: the generic
// function foo and its forward-mode and reverse-mode derivatives. They exist so
// that the bar(_:_:) derivative further down can name them with function_ref.
// foo<A>(_:_:)
sil @$s3src3fooyxx_xts18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF : $@convention(thin) <T where T : AdditiveArithmetic, T : Differentiable> (@in_guaranteed T, @in_guaranteed T) -> @out T
// forward-mode derivative of foo<A>(_:_:)
sil @$s3src3fooyxx_xts18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlFsACRzAdERzlTJfSSpSr : $@convention(thin) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1) -> @out τ_0_2 for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>)
// reverse-mode derivative of foo<A>(_:_:)
sil @$s3src3fooyxx_xts18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlFsACRzAdERzlTJrSSpSr : $@convention(thin) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>)
// CHECK-LABEL: {{^}}// reverse-mode derivative of bar(_:_:){{$}}
// reverse-mode derivative of bar(_:_:)
// Test input for sil-combine. The body wraps foo and its two derivatives
// (specialised to Float) in partial_apply closures, bundles them into a
// differentiable_function, extracts the VJP through a borrow, copies it and
// applies it.
// The expectations interleaved with the instructions require that, after the
// pass, the apply targets the VJP function_ref directly, and that a copy_value
// of the VJP partial_apply (not the partial_apply itself) is the operand given
// to differentiable_function, with the partial_apply destroyed separately.
sil [ossa] @$s3src3baryS2f_SftFTJrSSpSr : $@convention(thin) (Float, Float) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) {
[global: read,write,copy,destroy,allocate,deinit_barrier]
// %0 // users: %8, %2
// %1 // users: %6, %3
bb0(%0 : $Float, %1 : $Float):
// %4 is the slot for the indirect result of the apply below; %5 and %7 hold
// the two arguments. %1 is stored first and %0 second, i.e. in swapped order,
// matching the call foo(y, x) in the Swift source quoted at the top.
%4 = alloc_stack $Float // users: %28, %27, %21
%5 = alloc_stack $Float // users: %26, %21, %6
store %1 to [trivial] %5 // id: %6
%7 = alloc_stack $Float // users: %25, %21, %8
store %0 to [trivial] %7 // id: %8
// function_ref foo<A>(_:_:)
%9 = function_ref @$s3src3fooyxx_xts18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF : $@convention(thin) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0 // user: %10
%10 = partial_apply [callee_guaranteed] %9<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0 // user: %15
// function_ref forward-mode derivative of foo<A>(_:_:)
%11 = function_ref @$s3src3fooyxx_xts18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlFsACRzAdERzlTJfSSpSr : $@convention(thin) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1) -> @out τ_0_2 for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) // user: %12
%12 = partial_apply [callee_guaranteed] %11<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_1) -> @out τ_0_2 for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) // user: %15
// function_ref reverse-mode derivative of foo<A>(_:_:)
%13 = function_ref @$s3src3fooyxx_xts18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlFsACRzAdERzlTJrSSpSr : $@convention(thin) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) // user: %14
// CHECK: %[[#T0:]] = function_ref @$s3src3fooyxx_xts18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlFsACRzAdERzlTJrSSpSr
%14 = partial_apply [callee_guaranteed] %13<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <τ_0_0.TangentVector, τ_0_0.TangentVector, τ_0_0.TangentVector>) // user: %15
// CHECK: %[[#T1:]] = partial_apply [callee_guaranteed] %[[#T0]]<Float>()
// CHECK: %[[#T2:]] = copy_value %[[#T1]]
// NOTE(review): presumably %10, %12 and %14 are consumed by this instruction,
// which would explain why a copy of the VJP closure is expected to be inserted
// ahead of it -- confirm against the SIL ownership rules.
%15 = differentiable_function [parameters 0 1] [results 0] %10 with_derivative {%12, %14} // users: %16, %20
// CHECK: %[[#T4:]] = differentiable_function [parameters 0 1] [results 0] %[[#]] with_derivative {%[[#]], %[[#T2]]} // user: %[[#T5:]]
// The borrow scope below exists only to extract and copy the VJP; the bundle
// itself is destroyed immediately afterwards.
%16 = begin_borrow %15 // users: %19, %17
%17 = differentiable_function_extract [vjp] %16 // user: %18
%18 = copy_value %17 // users: %22, %21
end_borrow %16 // id: %19
destroy_value %15 // id: %20
// CHECK: destroy_value %[[#T4]] // id: %[[#T5]]
// Call through the extracted VJP copy; the expectation that follows wants this
// rewritten into a direct apply of the VJP function_ref with <Float>.
%21 = apply %18(%4, %5, %7) : $@callee_guaranteed (@in_guaranteed Float, @in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float)) // user: %24
// CHECK: %[[#]] = apply %[[#T0]]<Float>(%[[#]], %[[#]], %[[#]])
destroy_value %18 // id: %22
// CHECK: destroy_value %[[#T1]]
// The pullback returned by the apply (%21) is captured by the reabstraction
// thunk, and that closure is in turn captured by bar's own pullback (%30).
// function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float)
%23 = function_ref @$sS3fIegnrr_S3fIegydd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float)) -> (Float, Float) // user: %24
%24 = partial_apply [callee_guaranteed] %23(%21) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float)) -> (Float, Float) // user: %30
dealloc_stack %7 // id: %25
dealloc_stack %5 // id: %26
%27 = load [trivial] %4 // user: %31
dealloc_stack %4 // id: %28
// function_ref pullback of bar(_:_:)
%29 = function_ref @$s3src3baryS2f_SftFTJpSSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> (Float, Float) // user: %30
%30 = partial_apply [callee_guaranteed] %29(%24) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> (Float, Float) // user: %31
// Result pair: the value loaded from %4 and the pullback closure.
%31 = tuple (%27, %30) // user: %32
return %31 // id: %32
} // end sil function '$s3src3baryS2f_SftFTJrSSpSr'
// Declarations without bodies for the two symbols that the reverse-mode
// derivative of bar(_:_:) references via function_ref when it builds its
// returned pullback closure.
// pullback of bar(_:_:)
sil @$s3src3baryS2f_SftFTJpSSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> (Float, Float)
// thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float)
sil [transparent] [reabstraction_thunk] @$sS3fIegnrr_S3fIegydd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float)) -> (Float, Float)
/// SIL below corresponds to the following Swift:
/// func compute<T>(_ f: (T) -> T, _ x: T) -> T {
/// return f(x)
/// }
///
/// func checkDerivative<T: Differentiable>(
/// _ f: @differentiable(reverse) (T) -> T,
/// _ x: T) {
/// let _ = compute(f, x)
/// }
///
/// func caller() { checkDerivative({ $0 }, 0.0) }
// Declarations without bodies for everything caller() references through
// function_ref: the closure, its forward-mode and reverse-mode derivatives, and
// the two reabstraction thunks.
// closure #1 in caller()
// Isolation: nonisolated
sil [ossa] @$s3src6callerSdyFS2dcfU_ : $@convention(thin) (Double) -> Double
// forward-mode derivative of closure #1 in caller()
// Isolation: nonisolated
sil [signature_optimized_thunk] [heuristic_always_inline] @$s3src6callerSdyFS2dcfU_TJfSpSr : $@convention(thin) (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)
// reverse-mode derivative of closure #1 in caller()
// Isolation: nonisolated
sil @$s3src6callerSdyFS2dcfU_TJrSpSr : $@convention(thin) (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)
// thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double, @owned @escaping @callee_guaranteed (@unowned Double) -> (@unowned Double))
sil [transparent] [serialized] [reabstraction_thunk] [ossa] @$sS4dIegyd_Igydo_S2dxq_Ri_zRi0_zRi__Ri0__r0_lyS2dIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>)
// thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double)
sil [transparent] [serialized] [reabstraction_thunk] [ossa] @$sS2dIgyd_S2dIegnr_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> Double) -> @out Double
// Debug scope table. Scopes 3 and 5 are the ones named by the debug_value
// instructions in caller(); scopes 1, 2 and 4 appear as their parent or
// inlined_at links.
sil_scope 1 { parent @$s3src6callerSdyF : $@convention(thin) () -> Double }
sil_scope 2 { parent 1 }
sil_scope 3 { parent @$s3src15checkDerivativeyyxxYjrXE_xt16_Differentiation14DifferentiableRzlFSd_Tg5 : $@convention(thin) (@guaranteed @differentiable(reverse) @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>, Double) -> () inlined_at 2 }
sil_scope 4 { parent 3 inlined_at 2 }
sil_scope 5 { parent 4 inlined_at 2 }
sil_scope 6 { parent @$s3src7computeyxxxXE_xtlFSd_Tg5 : $@convention(thin) (@guaranteed @noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>, Double) -> Double inlined_at 5 }
// CHECK-LABEL: {{^}}// caller(){{$}}
// caller()
// Isolation: unspecified
// Test input for sil-combine. The closure and its two derivatives are each
// captured by a reabstraction thunk, converted to the substituted function
// type, copied, converted to @noescape and bundled into a
// differentiable_function (%26). The original function is then extracted
// through a lexical borrow and applied.
// The expectations interleaved with the instructions require that a copy_value
// of the convert_escape_to_noescape result is what feeds
// differentiable_function as the original, and that the final apply calls the
// thunk function_ref directly with the captured closure as its last argument.
sil [ossa] @$s3src6callerSdyF : $@convention(thin) () -> Double {
[global: read,write,copy,destroy,allocate,deinit_barrier]
bb0:
// %0 is the slot that receives the indirect result of the apply near the end.
%0 = alloc_stack $Double // users: %37, %43, %42
// function_ref closure #1 in caller()
%1 = function_ref @$s3src6callerSdyFS2dcfU_ : $@convention(thin) (Double) -> Double // user: %4
// function_ref forward-mode derivative of closure #1 in caller()
%2 = function_ref @$s3src6callerSdyFS2dcfU_TJfSpSr : $@convention(thin) (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double) // user: %5
// function_ref reverse-mode derivative of closure #1 in caller()
%3 = function_ref @$s3src6callerSdyFS2dcfU_TJrSpSr : $@convention(thin) (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double) // user: %6
%4 = thin_to_thick_function %1 to $@noescape @callee_guaranteed (Double) -> Double // users: %9, %7; ownership: none
%5 = thin_to_thick_function %2 to $@noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double) // users: %15, %7; ownership: none
%6 = thin_to_thick_function %3 to $@noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double) // users: %21, %7; ownership: none
// %7 bundles the thin_to_thick values directly; its only use is the
// destroy_value near the end of the function.
%7 = differentiable_function [parameters 0] [results 0] %4 with_derivative {%5, %6} // user: %40
// Original function: thunk capturing %4, then convert / copy / noescape.
// function_ref thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double)
%8 = function_ref @$sS2dIgyd_S2dIegnr_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> Double) -> @out Double // user: %9
// CHECK: %[[#A0:]] = function_ref @$sS2dIgyd_S2dIegnr_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> Double) -> @out Double // users: %[[#]], %[[#]]{{$}}
%9 = partial_apply [callee_guaranteed] %8(%4) : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> Double) -> @out Double // user: %10
// CHECK: %[[#A1:]] = partial_apply [callee_guaranteed] %[[#A0]](%[[#A2:]]) : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> Double) -> @out Double // user: %[[#A3:]]
%10 = convert_function %9 to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double> // users: %12, %11
// CHECK: %[[#A3:]] = convert_function %[[#A1]] to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double> // users: %[[#]], %[[#]]{{$}}
%11 = copy_value %10 // users: %44, %13
// CHECK: %[[#A4:]] = copy_value %[[#A3]] // users: %[[#]], %[[#]]{{$}}
destroy_value %10 // id: %12
// CHECK: destroy_value %[[#A3]] // id: %11
%13 = convert_escape_to_noescape %11 to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double> // user: %26
// CHECK: %[[#A5:]] = convert_escape_to_noescape %[[#A4]] to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double> // users: %[[#]], %[[#]], %[[#]]{{$}}
// Forward-mode derivative: same thunk / convert / copy / noescape sequence,
// capturing %5.
// function_ref thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double, @owned @escaping @callee_guaranteed (@unowned Double) -> (@unowned Double))
%14 = function_ref @$sS4dIegyd_Igydo_S2dxq_Ri_zRi0_zRi__Ri0__r0_lyS2dIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>) // user: %15
%15 = partial_apply [callee_guaranteed] %14(%5) : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>) // user: %16
%16 = convert_function %15 to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Double, Double, Double, Double> // users: %18, %17
%17 = copy_value %16 // users: %45, %19
destroy_value %16 // id: %18
%19 = convert_escape_to_noescape %17 to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Double, Double, Double, Double> // user: %26
// Reverse-mode derivative: same sequence again, capturing %6.
// function_ref thunk for @callee_guaranteed (@unowned Double) -> (@unowned Double, @owned @escaping @callee_guaranteed (@unowned Double) -> (@unowned Double))
%20 = function_ref @$sS4dIegyd_Igydo_S2dxq_Ri_zRi0_zRi__Ri0__r0_lyS2dIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>) // user: %21
%21 = partial_apply [callee_guaranteed] %20(%6) : $@convention(thin) (@in_guaranteed Double, @guaranteed @noescape @callee_guaranteed (Double) -> (Double, @owned @callee_guaranteed (Double) -> Double)) -> (@out Double, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>) // user: %22
%22 = convert_function %21 to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Double, Double, Double, Double> // users: %24, %23
%23 = copy_value %22 // users: %46, %25
destroy_value %22 // id: %24
%25 = convert_escape_to_noescape %23 to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Double, Double, Double, Double> // user: %26
// CHECK: %[[#A6:]] = copy_value %[[#A5]]
// Bundle of the three @noescape closures; this is the value whose original
// function is extracted below.
%26 = differentiable_function [parameters 0] [results 0] %13 with_derivative {%19, %25} // users: %29, %41
// CHECK: %[[#A7:]] = differentiable_function [parameters 0] [results 0] %[[#A6]] with_derivative {%[[#]], %[[#]]} // users: %[[#]], %[[#]]{{$}}
// The argument value 0.0 as a Double.
%27 = float_literal $Builtin.FPIEEE64, 0x0 // user: %28
%28 = struct $Double (%27) // users: %30, %34
%29 = begin_borrow [lexical] %26 // users: %32, %31, %39
debug_value %28, let, name "x", argno 2, scope 3 // id: %30
debug_value %29, let, name "f", argno 1, scope 3 // id: %31
%32 = differentiable_function_extract [original] %29 // users: %37, %35
%33 = alloc_stack $Double // users: %38, %37, %36, %34
store %28 to [trivial] %33 // id: %34
debug_value %32, let, name "f", argno 1, scope 5 // id: %35
// CHECK: debug_value %[[#A5]]
// CHECK: debug_value %[[#A7]]
debug_value %33, let, name "x", argno 2, expr op_deref, scope 5 // id: %36
// CHECK: %[[#A8:]] = alloc_stack $Double
// Call through the extracted original function; the expectation that follows
// wants a direct apply of the thunk with the captured closure appended.
%37 = apply %32(%0, %33) : $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Double, Double>
// CHECK: %[[#]] = apply %[[#A0]](%0, %[[#A8]], %[[#A2]])
dealloc_stack %33 // id: %38
end_borrow %29 // id: %39
destroy_value %7 // id: %40
destroy_value %26 // id: %41
// CHECK: destroy_value %[[#A7]]
%42 = load [trivial] %0 // user: %47
dealloc_stack %0 // id: %43
// Release the three escaping copies that backed the @noescape conversions.
destroy_value %11 // id: %44
// CHECK: destroy_value %[[#A4]]
destroy_value %17 // id: %45
destroy_value %23 // id: %46
return %42 // id: %47
} // end sil function '$s3src6callerSdyF'