Files
Daniil Kovalev 5528cf1cc4 [AutoDiff] Run AutoDiff closure spec pass for all VJPs (#81548)
Previously, the AutoDiff closure specialization pass was triggered only on
VJPs containing a single basic block. However, the pass logic supports
running on arbitrary VJPs. This PR enables the pass for all VJPs
unconditionally. As a result, if the pullback corresponding to a multiple-BB
VJP accepts some closures directly as arguments, these closures may now be
specialized by the pass. Closures passed via the payload of a branch tracing
enum are not specialized; this is the subject of future changes.

The PR contains several commits.
1. What the code calls a "call site" is the partial_apply of the pullback
corresponding to the VJP. It can appear only once, so we drop
support for multiple "call sites".
2. Enhance existing SILOptimizer tests for the pass.
3. Add validation tests for the single basic block case.
4. The change itself: delete the check for a single basic block.
5. Add validation tests for the multiple basic block case.
6. Add SILOptimizer tests for the multiple basic block case.
2025-07-07 13:00:14 +00:00

690 lines
46 KiB
Plaintext

/// Multi-basic-block VJPs whose pullbacks accept a branch tracing enum argument.
// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK
// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK
// REQUIRES: swift_in_compiler
sil_stage canonical
import Builtin
import Swift
import SwiftShims
import _Differentiation
///////////////////
/// Test case 1 ///
///////////////////
/// This SIL corresponds to the following Swift:
///
/// @differentiable(reverse)
/// func mul42(_ a: Float?) -> Float {
/// let b = 42 * a!
/// return b
/// }
///
/// The VJP has several basic blocks (it switches on the optional argument).
/// Its pullback receives the branch tracing enum defined below plus one
/// closure passed directly as an argument; that closure is the one the pass
/// is expected to specialize.
// Branch tracing enum for mul42. Its only case carries an empty tuple, so no
// closures reach the pullback through this enum.
enum _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0 {
case bb0(())
}
// Declaration only: closure #1 in static Float._vjpMultiply(lhs:rhs:), the
// multiply pullback closure that the VJP partially applies.
sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// Declaration only: autodiff subset parameters thunk. Partially applying it to
// a (Float) -> (Float, Float) closure yields a (Float) -> Float closure.
sil [transparent] [thunk] @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// pullback of mul42(_:)
//
// Arguments: %0 is the incoming adjoint, %1 is the branch tracing enum (not
// read by this body), and %2 is the closure the VJP passes directly. The body
// applies %2 to the adjoint, releases it, and wraps the result as the tangent
// of `Float?`. Because %2 is a direct argument, the specialized copy produced
// by the pass takes the closure's captured Float values instead.
sil private [signature_optimized_thunk] [always_inline] @$s4test5mul42yS2fSgFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float) -> Optional<Float>.TangentVector {
bb0(%0 : $Float, %1 : $_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, %2 : $@callee_guaranteed (Float) -> Float):
// Apply the closure argument to the adjoint, then release it (it is @owned).
%4 = apply %2(%0) : $@callee_guaranteed (Float) -> Float
strong_release %2
// Wrap the derivative as `.some` inside Optional<Float>.TangentVector.
%6 = enum $Optional<Float>, #Optional.some!enumelt, %4
%7 = struct $Optional<Float>.TangentVector (%6)
return %7
} // end sil function '$s4test5mul42yS2fSgFTJpSpSr'
// reverse-mode derivative of mul42(_:)
//
// Caller of the pullback for test case 1. bb0 switches on the optional
// argument: `.some` continues in bb1 and `.none` goes to bb2, which ends in
// `unreachable`. bb1 computes 42 * a, partially applies the multiply pullback
// closure (%12), wraps it in the subset parameters thunk (%14), and passes that
// thunk directly to the pullback partial_apply (%17) together with the branch
// tracing enum value (%7).
//
// The three specify_test directives in bb0 select the test-runner sub-tests;
// the FileCheck patterns after each one describe the expected output. Value
// numbers in those patterns (e.g. `bb1(%2 ...)`) refer to the printed output,
// not to the value names used in this input.
sil hidden @$s4test5mul42yS2fSgFTJrSpSr : $@convention(thin) (Optional<Float>) -> (Float, @owned @callee_guaranteed (Float) -> Optional<Float>.TangentVector) {
bb0(%0 : $Optional<Float>):
//=========== Test callsite and closure gathering logic ===========//
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
// TRUNNER-LABEL: Specializing closures in function: $s4test5mul42yS2fSgFTJrSpSr
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float) -> Optional<Float>.TangentVector
// TRUNNER-NEXT: Passed in closures:
// TRUNNER-NEXT: 1. %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER-EMPTY:
//=========== Test specialized function signature and body ===========//
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
// TRUNNER-LABEL: Generated specialized function: $s4test5mul42yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n
// CHECK: sil private [signature_optimized_thunk] [always_inline] @$s4test5mul42yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional<Float>.TangentVector {
// CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, %2 : $Float, %3 : $Float):
// CHECK: %[[#A4:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: %[[#A5:]] = partial_apply [callee_guaranteed] %[[#A4]](%2, %3) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: %[[#A6:]] = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER: %[[#A7:]] = partial_apply [callee_guaranteed] %[[#A6]](%[[#A5]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER: %[[#A8:]] = apply %[[#A7]](%0) : $@callee_guaranteed (Float) -> Float
// COMBINE-NOT: = partial_apply
// COMBINE: %[[#A8:]] = apply %[[#A6]](%0, %[[#A5]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER: strong_release %[[#A7]] : $@callee_guaranteed (Float) -> Float
// CHECK: %[[#A10:]] = enum $Optional<Float>, #Optional.some!enumelt, %[[#A8]] : $Float
// CHECK: %[[#A11:]] = struct $Optional<Float>.TangentVector (%[[#A10]] : $Optional<Float>)
// CHECK: return %[[#A11]] : $Optional<Float>.TangentVector
//=========== Test rewritten body ===========//
specify_test "autodiff_closure_specialize_rewritten_caller_body"
// TRUNNER-LABEL: Rewritten caller body for: $s4test5mul42yS2fSgFTJrSpSr:
// CHECK: sil hidden @$s4test5mul42yS2fSgFTJrSpSr : $@convention(thin) (Optional<Float>) -> (Float, @owned @callee_guaranteed (Float) -> Optional<Float>.TangentVector) {
// CHECK: bb1(%2 : $Float):
// CHECK: %[[#B4:]] = struct $Float (%[[#]] : $Builtin.FPIEEE32)
// TRUNNER: %[[#B10:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#B11:]] = partial_apply [callee_guaranteed] %[[#B10]](%2, %[[#B4]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#B12:]] = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER: %[[#B13:]] = partial_apply [callee_guaranteed] %[[#B12]](%[[#B11]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER: %[[#B14:]] = function_ref @$s4test5mul42yS2fSgFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float) -> Optional<Float>.TangentVector
// COMBINE-NOT: = partial_apply
// COMBINE-NOT: = function_ref
// CHECK: %[[#B15:]] = function_ref @$s4test5mul42yS2fSgFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktj1_k5FZSf_K6SfcfU_S2fTf1nnc_n : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional<Float>.TangentVector
// CHECK: %[[#B16:]] = partial_apply [callee_guaranteed] %[[#B15]](%[[#]], %2, %[[#B4]]) : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, Float, Float) -> Optional<Float>.TangentVector
// TRUNNER: release_value %[[#B11]] : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: %[[#B18:]] = tuple (%[[#]] : $Float, %[[#B16]] : $@callee_guaranteed (Float) -> Optional<Float>.TangentVector)
// CHECK: return %[[#B18]]
// `.some` continues in bb1; `.none` branches to bb2, which is unreachable.
switch_enum %0, case #Optional.some!enumelt: bb1, case #Optional.none!enumelt: bb2
bb1(%3 : $Float):
%4 = float_literal $Builtin.FPIEEE32, 0x42280000 // 42
%5 = struct $Float (%4)
// Branch tracing enum value for the pullback; its payload is an empty tuple.
%6 = tuple ()
%7 = enum $_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, #_AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %6
%8 = struct_extract %3, #Float._value
%9 = builtin "fmul_FPIEEE32"(%4, %8) : $Builtin.FPIEEE32
%10 = struct $Float (%9)
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%11 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%12 = partial_apply [callee_guaranteed] %11(%3, %5) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// function_ref autodiff subset parameters thunk for pullback from @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
%13 = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%14 = partial_apply [callee_guaranteed] %13(%12) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// The thunk %14 is passed directly as a pullback argument, next to the enum %7.
// function_ref pullback of mul42(_:)
%16 = function_ref @$s4test5mul42yS2fSgFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float) -> Optional<Float>.TangentVector
%17 = partial_apply [callee_guaranteed] %16(%7, %14) : $@convention(thin) (Float, @owned _AD__$s4test5mul42yS2fSgF_bb2__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float) -> Optional<Float>.TangentVector
%18 = tuple (%10, %17)
return %18
bb2:
unreachable
} // end sil function '$s4test5mul42yS2fSgFTJrSpSr'
///////////////////
/// Test case 2 ///
///////////////////
/// This SIL corresponds to the following Swift:
///
/// struct Class: Differentiable {
/// var stored: Float
/// var optional: Float?
///
/// init(stored: Float, optional: Float?) {
/// self.stored = stored
/// self.optional = optional
/// }
///
/// @differentiable(reverse)
/// func method() -> Float {
/// let c: Class
/// do {
/// let tmp = Class(stored: 1 * stored, optional: optional)
/// let tuple = (tmp, tmp)
/// c = tuple.0
/// }
/// var ret : Float = 0
/// if let x = c.optional {
/// ret = x * c.stored
/// } else {
/// ret = 1 * c.stored
/// }
/// return 1 * ret * ret
/// }
/// }
///
/// The pullback of `method()` receives two closures directly as arguments and
/// further closures inside the payloads of the branch tracing enums below.
// Printed declaration of the differentiable type used by this test case
// (named `Class`, but declared as a struct).
struct Class : Differentiable {
@_hasStorage var stored: Float { get set }
@_hasStorage @_hasInitialValue var optional: Float? { get set }
init(stored: Float, optional: Float?)
@differentiable(reverse, wrt: self)
func method() -> Float
struct TangentVector : AdditiveArithmetic, Differentiable {
@_hasStorage var stored: Float { get set }
@_hasStorage var optional: Optional<Float>.TangentVector { get set }
static func + (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector
static func - (lhs: Class.TangentVector, rhs: Class.TangentVector) -> Class.TangentVector
typealias TangentVector = Class.TangentVector
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Class.TangentVector, _ b: Class.TangentVector) -> Bool
init(stored: Float, optional: Optional<Float>.TangentVector)
static var zero: Class.TangentVector { get }
}
mutating func move(by offset: Class.TangentVector)
}
// Branch tracing enums for Class.method(). The bb1 and bb2 enums carry a
// tuple of a (Float) -> Float closure and the Class.init pullback. The bb3
// enum records which of bb1/bb2 was the predecessor, together with one more
// pullback closure. The caseless enums in this group are not referenced by the
// test case 2 functions.
enum _AD__$s4test5ClassV6methodSfyF_bb0__Pred__src_0_wrt_0 {
}
enum _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0 {
case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)))
}
enum _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0 {
case bb0(((Float) -> Float, (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)))
}
enum _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0 {
case bb2((predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, (Float) -> Float))
case bb1((predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float)))
}
enum _AD__$s4test13methodWrapperySfAA5ClassVF_bb0__Pred__src_0_wrt_0 {
}
enum _AD__$s4test5ClassV6stored8optionalACSf_SfSgtcfC_bb0__Pred__src_0_wrt_0_1 {
}
// Declaration only: pullback of Class.init(stored:optional:), referenced by the
// VJP of Class.method() below.
sil @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)
// pullback of Class.method()
//
// Arguments: %0 is the incoming adjoint, %1 is the bb3 branch tracing enum,
// and %2/%3 are the closures the VJP passes directly. bb0 applies %3 and then
// %2. The switch on %1 then pulls further closures out of the enum payloads:
// bb1 handles the `.bb2` case ((Float) -> Float closure) and bb2 handles the
// `.bb1` case ((Float) -> (Float, Float) closure). Both forward the closures
// found in the nested predecessor enum to bb3 as %79/%80. %4 is the float
// literal 0 used as an operand of many of the fadd builtins below.
// In the specialized copy produced by the pass, only %2 and %3 are replaced by
// their captured values; the enum-carried closures are left as they are.
sil private @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Class.TangentVector {
bb0(%0 : $Float, %1 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, %2 : $@callee_guaranteed (Float) -> Float, %3 : $@callee_guaranteed (Float) -> (Float, Float)):
%4 = float_literal $Builtin.FPIEEE32, 0x0 // 0
// Apply the directly passed closures: %3 to the adjoint first, then %2 to the
// first element of %3's result.
%8 = apply %3(%0) : $@callee_guaranteed (Float) -> (Float, Float)
strong_release %3
%10 = tuple_extract %8, 0
%11 = tuple_extract %8, 1
%12 = struct_extract %11, #Float._value
%13 = builtin "fadd_FPIEEE32"(%4, %12) : $Builtin.FPIEEE32
%15 = apply %2(%10) : $@callee_guaranteed (Float) -> Float
strong_release %2
%17 = struct_extract %15, #Float._value
%18 = builtin "fadd_FPIEEE32"(%13, %17) : $Builtin.FPIEEE32
// Dispatch on the predecessor recorded in the bb3 enum; each case payload holds
// the predecessor enum plus one more closure.
switch_enum %1, case #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb2
bb1(%37 : $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float)):
%38 = tuple_extract %37, 0
%39 = tuple_extract %37, 1
%40 = builtin "fadd_FPIEEE32"(%18, %4) : $Builtin.FPIEEE32
%41 = struct $Float (%40)
%42 = apply %39(%41) : $@callee_guaranteed (Float) -> Float
strong_release %39
%44 = struct_extract %42, #Float._value
%45 = builtin "fadd_FPIEEE32"(%44, %4) : $Builtin.FPIEEE32
%46 = builtin "fadd_FPIEEE32"(%4, %45) : $Builtin.FPIEEE32
%50 = unchecked_enum_data %38, #_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0.bb0!enumelt
%51 = tuple_extract %50, 1
%52 = tuple_extract %50, 0
br bb3(%4, %46, %52, %51)
bb2(%54 : $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float))):
%55 = tuple_extract %54, 0
%56 = tuple_extract %54, 1
%57 = builtin "fadd_FPIEEE32"(%18, %4) : $Builtin.FPIEEE32
%58 = struct $Float (%57)
%59 = apply %56(%58) : $@callee_guaranteed (Float) -> (Float, Float)
strong_release %56
%61 = tuple_extract %59, 0
%62 = tuple_extract %59, 1
%63 = struct_extract %61, #Float._value
%64 = builtin "fadd_FPIEEE32"(%63, %4) : $Builtin.FPIEEE32
%65 = struct_extract %62, #Float._value
%66 = builtin "fadd_FPIEEE32"(%65, %4) : $Builtin.FPIEEE32
%67 = builtin "fadd_FPIEEE32"(%4, %66) : $Builtin.FPIEEE32
%69 = builtin "fadd_FPIEEE32"(%64, %4) : $Builtin.FPIEEE32
%70 = builtin "fadd_FPIEEE32"(%69, %4) : $Builtin.FPIEEE32
%73 = unchecked_enum_data %55, #_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0.bb0!enumelt
%74 = tuple_extract %73, 1
%75 = tuple_extract %73, 0
br bb3(%70, %67, %75, %74)
// Join point: %79 and %80 are the closures taken from the nested predecessor
// enum payload (a (Float) -> Float closure and the Class.init pullback).
bb3(%77 : $Builtin.FPIEEE32, %78 : $Builtin.FPIEEE32, %79 : $@callee_guaranteed (Float) -> Float, %80 : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)):
%81 = builtin "fadd_FPIEEE32"(%4, %77) : $Builtin.FPIEEE32
%85 = builtin "fadd_FPIEEE32"(%78, %4) : $Builtin.FPIEEE32
%87 = builtin "fadd_FPIEEE32"(%81, %4) : $Builtin.FPIEEE32
%89 = builtin "fadd_FPIEEE32"(%85, %4) : $Builtin.FPIEEE32
%91 = builtin "fadd_FPIEEE32"(%87, %4) : $Builtin.FPIEEE32
%93 = builtin "fadd_FPIEEE32"(%89, %4) : $Builtin.FPIEEE32
%95 = builtin "fadd_FPIEEE32"(%91, %4) : $Builtin.FPIEEE32
%97 = builtin "fadd_FPIEEE32"(%93, %4) : $Builtin.FPIEEE32
%99 = builtin "fadd_FPIEEE32"(%95, %4) : $Builtin.FPIEEE32
%100 = builtin "fadd_FPIEEE32"(%4, %97) : $Builtin.FPIEEE32
%102 = builtin "fadd_FPIEEE32"(%4, %99) : $Builtin.FPIEEE32
%104 = builtin "fadd_FPIEEE32"(%100, %4) : $Builtin.FPIEEE32
%105 = struct $Float (%104)
%106 = builtin "fadd_FPIEEE32"(%102, %4) : $Builtin.FPIEEE32
%107 = struct $Float (%106)
%108 = enum $Optional<Float>, #Optional.some!enumelt, %107
%109 = struct $Optional<Float>.TangentVector (%108)
%110 = struct $Class.TangentVector (%105, %109)
%111 = apply %80(%110) : $@callee_guaranteed (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)
strong_release %80
%113 = tuple_extract %111, 1
%114 = struct_extract %113, #Optional.TangentVector.value
// bb4/bb5 handle the `.none`/`.some` forms of the Optional tangent returned by
// %80; both merge into bb6.
switch_enum %114, case #Optional.none!enumelt: bb4, case #Optional.some!enumelt: bb5
bb4:
br bb6(%4)
bb5(%117 : $Float):
%118 = unchecked_enum_data %114, #Optional.some!enumelt
%119 = struct_extract %118, #Float._value
%120 = builtin "fadd_FPIEEE32"(%119, %4) : $Builtin.FPIEEE32
br bb6(%120)
bb6(%122 : $Builtin.FPIEEE32):
%123 = tuple_extract %111, 0
%124 = struct_extract %123, #Float._value
%125 = builtin "fadd_FPIEEE32"(%124, %4) : $Builtin.FPIEEE32
%126 = struct $Float (%125)
%127 = apply %79(%126) : $@callee_guaranteed (Float) -> Float
strong_release %79
%129 = struct_extract %127, #Float._value
%130 = builtin "fadd_FPIEEE32"(%129, %4) : $Builtin.FPIEEE32
%131 = builtin "fadd_FPIEEE32"(%4, %130) : $Builtin.FPIEEE32
%132 = builtin "fadd_FPIEEE32"(%4, %122) : $Builtin.FPIEEE32
%133 = struct $Float (%132)
%134 = enum $Optional<Float>, #Optional.some!enumelt, %133
%135 = struct $Float (%131)
%136 = struct $Optional<Float>.TangentVector (%134)
%137 = struct $Class.TangentVector (%135, %136)
return %137
} // end sil function '$s4test5ClassV6methodSfyFTJpSpSr'
// reverse-mode derivative of Class.method()
//
// Caller of the pullback for test case 2. The pullback partial_apply (%58)
// receives the bb3 branch tracing enum (%47) and two closures directly: %53,
// the subset parameters thunk over the multiply closure %52, and the multiply
// closure %56. Other closures reach the pullback only inside branch tracing
// enum payloads: the tuple %27 of %11 and %26, plus %35 in bb1 and %43 in bb2.
// Accordingly, the expected "Passed in closures" output lists exactly two
// entries: the multiply closure underneath the thunk (%52) and %56.
sil hidden @$s4test5ClassV6methodSfyFTJrSpSr : $@convention(method) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
bb0(%0 : $Class):
//=========== Test callsite and closure gathering logic ===========//
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
// TRUNNER-LABEL: Specializing closures in function: $s4test5ClassV6methodSfyFTJrSpSr
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]], %[[#C42:]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Class.TangentVector
// TRUNNER-NEXT: Passed in closures:
// TRUNNER-NEXT: 1. %[[#]] = partial_apply [callee_guaranteed] %[[#C7:]](%[[#C34:]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER-NEXT: 2. %[[#C42]] = partial_apply [callee_guaranteed] %[[#C7]](%[[#C34]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER-EMPTY:
//=========== Test specialized function signature and body ===========//
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
// TRUNNER-LABEL: Generated specialized function: $s4test5ClassV6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n
// CHECK: sil private @$s4test5ClassV6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector {
// CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, %2 : $Float, %3 : $Float, %4 : $Float, %5 : $Float):
// CHECK: %[[#D6:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: %[[#D7:]] = partial_apply [callee_guaranteed] %[[#D6]](%2, %3) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: %[[#D8:]] = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER: %[[#D9:]] = partial_apply [callee_guaranteed] %[[#D8]](%[[#D7]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// COMBINE-NOT: = partial_apply
// CHECK: %[[#D10:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#D11:]] = partial_apply [callee_guaranteed] %[[#D10]](%4, %5) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// COMBINE-NOT: = partial_apply
// TRUNNER: %[[#D13:]] = apply %[[#D11]](%0) : $@callee_guaranteed (Float) -> (Float, Float)
// COMBINE: %[[#D13:]] = apply %[[#D10]](%0, %4, %5) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: strong_release %[[#D11]] : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: %[[#D15:]] = tuple_extract %[[#D13]] : $(Float, Float), 0
// TRUNNER: %[[#]] = apply %[[#D9]](%[[#D15]]) : $@callee_guaranteed (Float) -> Float
// COMBINE: %[[#]] = apply %[[#D8]](%[[#D15]], %[[#D7]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER: strong_release %[[#D9]] : $@callee_guaranteed (Float) -> Float
//=========== Test rewritten body ===========//
specify_test "autodiff_closure_specialize_rewritten_caller_body"
// TRUNNER-LABEL: Rewritten caller body for: $s4test5ClassV6methodSfyFTJrSpSr:
// CHECK: sil hidden @$s4test5ClassV6methodSfyFTJrSpSr : $@convention(method) (Class) -> (Float, @owned @callee_guaranteed (Float) -> Class.TangentVector) {
// CHECK: bb0(%0 : $Class):
// CHECK: %[[#E2:]] = struct $Float
// CHECK: %[[#E7:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: %[[#E9:]] = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// CHECK: bb1(%[[#]] : $Float):
// CHECK: bb2:
// CHECK: bb3(%[[#E33:]] : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, %[[#E34:]] : $Float):
// CHECK: %[[#E37:]] = struct $Float
// TRUNNER: %[[#E38:]] = partial_apply [callee_guaranteed] %[[#E7]](%[[#E34]], %[[#E2]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#]] = partial_apply [callee_guaranteed] %[[#E9]](%[[#E38]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER: %[[#E42:]] = partial_apply [callee_guaranteed] %[[#E7]](%[[#E34]], %[[#E37]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Class.TangentVector
// COMBINE-NOT: = partial_apply
// COMBINE-NOT: = function_ref @$s4test5ClassV6methodSfyFTJpSpSr
// CHECK: %[[#E44:]] = function_ref @$s4test5ClassV6methodSfyFTJpSpSr073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktk1_l5FZSf_L6SfcfU_S2fAES2fTf1nncc_n : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector
// CHECK: %[[#E45:]] = partial_apply [callee_guaranteed] %[[#E44]](%[[#E33]], %[[#E34]], %[[#E2]], %[[#E34]], %[[#E37]]) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, Float, Float, Float, Float) -> Class.TangentVector
// TRUNNER: release_value %[[#E38]] : $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER: release_value %[[#E42]] : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: %[[#E48:]] = tuple (%[[#]] : $Float, %[[#E45]] : $@callee_guaranteed (Float) -> Class.TangentVector)
// CHECK: return %[[#E48]]
%2 = float_literal $Builtin.FPIEEE32, 0x3F800000 // 1
%3 = struct $Float (%2)
%4 = struct_extract %0, #Class.stored
%5 = struct_extract %4, #Float._value
%6 = builtin "fmul_FPIEEE32"(%2, %5) : $Builtin.FPIEEE32
%7 = struct $Float (%6)
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%8 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%9 = partial_apply [callee_guaranteed] %8(%4, %3) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// function_ref autodiff subset parameters thunk for pullback from @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
%10 = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%12 = struct_extract %0, #Class.optional
// function_ref pullback of Class.init(stored:optional:)
%25 = function_ref @$s4test5ClassV6stored8optionalACSf_SfSgtcfCTJpSSUpSr : $@convention(thin) (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)
%26 = thin_to_thick_function %25 to $@callee_guaranteed (Class.TangentVector) -> (Float, Optional<Float>.TangentVector)
// %27 is stored in the bb1/bb2 predecessor enum payloads (%30, %39) rather than
// passed to the pullback directly.
%27 = tuple (%11, %26)
switch_enum %12, case #Optional.some!enumelt: bb1, case #Optional.none!enumelt: bb2
bb1(%29 : $Float):
%30 = enum $_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %27
%32 = struct_extract %29, #Float._value
%33 = builtin "fmul_FPIEEE32"(%32, %6) : $Builtin.FPIEEE32
%34 = struct $Float (%33)
// %35 reaches the pullback only through the bb3 enum payload (%36, %37).
%35 = partial_apply [callee_guaranteed] %8(%7, %29) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%36 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float)) (%30, %35)
%37 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %36
br bb3(%37, %34)
bb2:
%39 = enum $_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %27
%40 = builtin "fmul_FPIEEE32"(%2, %6) : $Builtin.FPIEEE32
%41 = struct $Float (%40)
// %42, wrapped by the thunk %43, reaches the pullback only through the bb3
// enum payload (%44, %45).
%42 = partial_apply [callee_guaranteed] %8(%7, %3) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%43 = partial_apply [callee_guaranteed] %10(%42) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%44 = tuple $(predecessor: _AD__$s4test5ClassV6methodSfyF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%39, %43)
%45 = enum $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, #_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %44
br bb3(%45, %41)
bb3(%47 : $_AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, %48 : $Float):
%49 = struct_extract %48, #Float._value
%50 = builtin "fmul_FPIEEE32"(%2, %49) : $Builtin.FPIEEE32
%51 = struct $Float (%50)
// %53 (the thunk over %52) and %56 are the closures passed directly to the
// pullback partial_apply %58.
%52 = partial_apply [callee_guaranteed] %8(%48, %3) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%53 = partial_apply [callee_guaranteed] %10(%52) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%54 = builtin "fmul_FPIEEE32"(%50, %49) : $Builtin.FPIEEE32
%55 = struct $Float (%54)
%56 = partial_apply [callee_guaranteed] %8(%48, %51) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// function_ref pullback of Class.method()
%57 = function_ref @$s4test5ClassV6methodSfyFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Class.TangentVector
%58 = partial_apply [callee_guaranteed] %57(%47, %53, %56) : $@convention(thin) (Float, @owned _AD__$s4test5ClassV6methodSfyF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Class.TangentVector
%59 = tuple (%55, %58)
return %59
} // end sil function '$s4test5ClassV6methodSfyFTJrSpSr'
///////////////////
/// Test case 3 ///
///////////////////
/// This SIL corresponds to the following Swift:
///
/// @differentiable(reverse)
/// func cond_tuple_var(_ x: Float) -> Float {
/// // Convoluted function returning `x + x`.
/// var y: (Float, Float) = (x, x)
/// var z: (Float, Float) = (x + x, x - x)
/// if x > 0 {
/// let w = (x, x)
/// y.0 = w.1
/// y.1 = w.0
/// z.0 = z.0 - y.0
/// z.1 = z.1 + y.0
/// } else {
/// z = (1 * x, x)
/// }
/// return y.0 + y.1 - z.0 + z.1
/// }
// Branch tracing enums for cond_tuple_var. The bb1 and bb2 enums carry a pair of
// (Float) -> (Float, Float) closures. The bb3 enum records which of bb1/bb2 was
// the predecessor: the `.bb2` case adds one (Float) -> Float closure and the
// `.bb1` case adds two (Float) -> (Float, Float) closures.
enum _AD__$s4test14cond_tuple_varyS2fF_bb0__Pred__src_0_wrt_0 {
}
enum _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0 {
case bb0(((Float) -> (Float, Float), (Float) -> (Float, Float)))
}
enum _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0 {
case bb0(((Float) -> (Float, Float), (Float) -> (Float, Float)))
}
enum _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0 {
case bb2((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, (Float) -> Float))
case bb1((predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, (Float) -> (Float, Float), (Float) -> (Float, Float)))
}
// Declarations only: pullback closures of Float._vjpAdd and
// Float._vjpSubtract (per their mangled names).
sil @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float)
sil @$sSf16_DifferentiationE12_vjpSubtract3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float)
// pullback of cond_tuple_var(_:)
//
// Hand-written multi-basic-block pullback used as the specialization target.
// Arguments:
//   %0      - incoming adjoint of the result.
//   %1      - bb3 branch tracing enum. Its case tells which VJP branch ran, and
//             its payload holds that branch's pullback closures.
//   %2..%4  - closures passed directly as arguments. These three are the ones
//             the closure specialization pass is expected to fold into the
//             specialized pullback, whose signature keeps only %0 and %1.
// The direct closures are applied in the order %4, %3, %2, and each is
// released right after its use. %5 is the literal 0.0, used as the zero
// operand in the adjoint additions.
sil private @$s4test14cond_tuple_varyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float {
bb0(%0 : $Float, %1 : $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, %2 : $@callee_guaranteed (Float) -> (Float, Float), %3 : $@callee_guaranteed (Float) -> (Float, Float), %4 : $@callee_guaranteed (Float) -> (Float, Float)):
%5 = float_literal $Builtin.FPIEEE32, 0x0 // 0
%10 = apply %4(%0) : $@callee_guaranteed (Float) -> (Float, Float)
strong_release %4
%12 = tuple_extract %10, 0
%13 = tuple_extract %10, 1
%14 = struct_extract %13, #Float._value
%15 = builtin "fadd_FPIEEE32"(%5, %14) : $Builtin.FPIEEE32
%17 = apply %3(%12) : $@callee_guaranteed (Float) -> (Float, Float)
strong_release %3
%19 = tuple_extract %17, 0
%20 = tuple_extract %17, 1
%21 = struct_extract %20, #Float._value
%22 = builtin "fadd_FPIEEE32"(%5, %21) : $Builtin.FPIEEE32
%24 = apply %2(%19) : $@callee_guaranteed (Float) -> (Float, Float)
strong_release %2
%26 = tuple_extract %24, 0
%27 = tuple_extract %24, 1
%28 = struct_extract %27, #Float._value
%29 = builtin "fadd_FPIEEE32"(%5, %28) : $Builtin.FPIEEE32
%31 = struct_extract %26, #Float._value
%32 = builtin "fadd_FPIEEE32"(%5, %31) : $Builtin.FPIEEE32
// Dispatch on the recorded predecessor. The block numbers cross over:
// pullback bb1 handles the enum's .bb2 case (the VJP's bb2 branch), and
// pullback bb2 handles the .bb1 case (the VJP's bb1 branch).
switch_enum %1, case #_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb2
bb1(%44 : $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float)):
// Payload: the bb2 predecessor enum (%45) and the single branch closure (%46).
%45 = tuple_extract %44, 0
%46 = tuple_extract %44, 1
%47 = builtin "fadd_FPIEEE32"(%15, %5) : $Builtin.FPIEEE32
%49 = builtin "fadd_FPIEEE32"(%22, %5) : $Builtin.FPIEEE32
%50 = struct $Float (%49)
%52 = apply %46(%50) : $@callee_guaranteed (Float) -> Float
strong_release %46
%54 = struct_extract %52, #Float._value
%55 = builtin "fadd_FPIEEE32"(%54, %47) : $Builtin.FPIEEE32
// Unwrap the bb2 predecessor enum to reach the closure pair built in VJP bb0,
// and forward both closures to bb3 (element 0 as %123, element 1 as %124).
%61 = unchecked_enum_data %45, #_AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt
%62 = tuple_extract %61, 1
%63 = tuple_extract %61, 0
br bb3(%55, %32, %29, %5, %5, %63, %62)
bb2(%65 : $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float), @callee_guaranteed (Float) -> (Float, Float))):
// Payload: the bb1 predecessor enum (%66) and two branch closures (%67, %68).
%66 = tuple_extract %65, 0
%67 = tuple_extract %65, 1
%68 = tuple_extract %65, 2
%69 = builtin "fadd_FPIEEE32"(%15, %5) : $Builtin.FPIEEE32
%70 = struct $Float (%69)
// The branch closures run in reverse payload order: %68 before %67.
%72 = apply %68(%70) : $@callee_guaranteed (Float) -> (Float, Float)
strong_release %68
%74 = tuple_extract %72, 0
%75 = tuple_extract %72, 1
%76 = struct_extract %74, #Float._value
%77 = builtin "fadd_FPIEEE32"(%76, %5) : $Builtin.FPIEEE32
%78 = struct_extract %75, #Float._value
%79 = builtin "fadd_FPIEEE32"(%78, %5) : $Builtin.FPIEEE32
%80 = builtin "fadd_FPIEEE32"(%32, %79) : $Builtin.FPIEEE32
%82 = builtin "fadd_FPIEEE32"(%5, %77) : $Builtin.FPIEEE32
%84 = builtin "fadd_FPIEEE32"(%22, %5) : $Builtin.FPIEEE32
%85 = struct $Float (%84)
%87 = apply %67(%85) : $@callee_guaranteed (Float) -> (Float, Float)
strong_release %67
%89 = tuple_extract %87, 0
%90 = tuple_extract %87, 1
%91 = struct_extract %89, #Float._value
%92 = builtin "fadd_FPIEEE32"(%91, %5) : $Builtin.FPIEEE32
%93 = struct_extract %90, #Float._value
%94 = builtin "fadd_FPIEEE32"(%93, %5) : $Builtin.FPIEEE32
%95 = builtin "fadd_FPIEEE32"(%80, %94) : $Builtin.FPIEEE32
%97 = builtin "fadd_FPIEEE32"(%5, %92) : $Builtin.FPIEEE32
%99 = builtin "fadd_FPIEEE32"(%29, %5) : $Builtin.FPIEEE32
%101 = builtin "fadd_FPIEEE32"(%99, %5) : $Builtin.FPIEEE32
%102 = builtin "fadd_FPIEEE32"(%95, %5) : $Builtin.FPIEEE32
%104 = builtin "fadd_FPIEEE32"(%102, %5) : $Builtin.FPIEEE32
%105 = builtin "fadd_FPIEEE32"(%101, %5) : $Builtin.FPIEEE32
%106 = builtin "fadd_FPIEEE32"(%104, %5) : $Builtin.FPIEEE32
%107 = builtin "fadd_FPIEEE32"(%105, %5) : $Builtin.FPIEEE32
%108 = builtin "fadd_FPIEEE32"(%106, %107) : $Builtin.FPIEEE32
// Unwrap the bb1 predecessor enum to reach the closure pair built in VJP bb0,
// and forward both closures to bb3 (element 0 as %123, element 1 as %124).
%114 = unchecked_enum_data %66, #_AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt
%115 = tuple_extract %114, 1
%116 = tuple_extract %114, 0
br bb3(%108, %5, %5, %97, %82, %116, %115)
bb3(%118 : $Builtin.FPIEEE32, %119 : $Builtin.FPIEEE32, %120 : $Builtin.FPIEEE32, %121 : $Builtin.FPIEEE32, %122 : $Builtin.FPIEEE32, %123 : $@callee_guaranteed (Float) -> (Float, Float), %124 : $@callee_guaranteed (Float) -> (Float, Float)):
// Join: apply the bb0 closure pair in reverse order (%124, then %123), fold
// the partial adjoints into a single Float, and return it.
%125 = builtin "fadd_FPIEEE32"(%122, %5) : $Builtin.FPIEEE32
%126 = struct $Float (%125)
%127 = apply %124(%126) : $@callee_guaranteed (Float) -> (Float, Float)
strong_release %124
%129 = tuple_extract %127, 0
%130 = tuple_extract %127, 1
%131 = struct_extract %129, #Float._value
%132 = builtin "fadd_FPIEEE32"(%131, %118) : $Builtin.FPIEEE32
%133 = struct_extract %130, #Float._value
%134 = builtin "fadd_FPIEEE32"(%133, %132) : $Builtin.FPIEEE32
%135 = builtin "fadd_FPIEEE32"(%121, %5) : $Builtin.FPIEEE32
%136 = struct $Float (%135)
%137 = apply %123(%136) : $@callee_guaranteed (Float) -> (Float, Float)
strong_release %123
%139 = tuple_extract %137, 0
%140 = tuple_extract %137, 1
%141 = struct_extract %139, #Float._value
%142 = builtin "fadd_FPIEEE32"(%141, %134) : $Builtin.FPIEEE32
%143 = struct_extract %140, #Float._value
%144 = builtin "fadd_FPIEEE32"(%143, %142) : $Builtin.FPIEEE32
%145 = builtin "fadd_FPIEEE32"(%120, %144) : $Builtin.FPIEEE32
%146 = builtin "fadd_FPIEEE32"(%119, %145) : $Builtin.FPIEEE32
%147 = struct $Float (%146)
return %147
} // end sil function '$s4test14cond_tuple_varyS2fFTJpSpSr'
// reverse-mode derivative of cond_tuple_var(_:)
//
// Multi-basic-block VJP for test case 3, with a diamond CFG:
//   - bb0 branches on 0 < x to bb1 or bb2, and both rejoin at bb3.
//   - Each branch builds a bb3 branch tracing enum value whose payload holds
//     that branch's pullback closures.
//   - bb3 forms the pullback partial_apply (%54) from that enum (%44) plus
//     three closures (%47, %49, %52) created there with thin_to_thick_function.
// Only those three directly passed closures are expected to be specialized.
// Closures reachable through the enum payload are not specialized, so the
// rewritten caller still passes the enum to the specialized pullback.
sil hidden @$s4test14cond_tuple_varyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
[global: ]
bb0(%0 : $Float):
//=========== Test callsite and closure gathering logic ===========//
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
// TRUNNER-LABEL: Specializing closures in function: $s4test14cond_tuple_varyS2fFTJrSpSr
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#F1:]], %[[#F2:]], %[[#F3:]]) : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER-NEXT: Passed in closures:
// TRUNNER-NEXT: 1. %[[#F1]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER-NEXT: 2. %[[#F2]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER-NEXT: 3. %[[#F3]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER-EMPTY:
//=========== Test specialized function signature and body ===========//
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
// TRUNNER-LABEL: Generated specialized function: $s4test14cond_tuple_varyS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n
// CHECK: sil private @$s4test14cond_tuple_varyS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0) -> Float {
// CHECK: bb0(%0 : $Float, %1 : $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0):
// CHECK: %[[#F2:]] = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float)
// TRUNNER: %[[#F3:]] = thin_to_thick_function %[[#F2]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float)
// COMBINE-NOT: = thin_to_thick_function
// CHECK: %[[#F4:]] = function_ref @$sSf16_DifferentiationE12_vjpSubtract3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float)
// TRUNNER: %[[#F5:]] = thin_to_thick_function %[[#F4]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float)
// COMBINE-NOT: = thin_to_thick_function
// CHECK: %[[#F6:]] = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float)
// TRUNNER: %[[#F7:]] = thin_to_thick_function %[[#F6]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float)
// COMBINE-NOT: = thin_to_thick_function
// CHECK: %[[#F8:]] = float_literal $Builtin.FPIEEE32, 0x0 // 0
// TRUNNER: %[[#F9:]] = apply %[[#F7]](%0) : $@callee_guaranteed (Float) -> (Float, Float)
// COMBINE: %[[#F9:]] = apply %[[#F6]](%0) : $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER: strong_release %[[#F7]] : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: %[[#F11:]] = tuple_extract %[[#F9]] : $(Float, Float), 0
// TRUNNER: %[[#F15:]] = apply %[[#F5]](%[[#F11]]) : $@callee_guaranteed (Float) -> (Float, Float)
// COMBINE: %[[#F15:]] = apply %[[#F4]](%[[#F11]]) : $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER: strong_release %[[#F5]] : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: %[[#F17:]] = tuple_extract %[[#F15]] : $(Float, Float), 0
// TRUNNER: %[[#]] = apply %[[#F3]](%[[#F17]]) : $@callee_guaranteed (Float) -> (Float, Float)
// COMBINE: %[[#]] = apply %[[#F2]](%[[#F17]]) : $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER: strong_release %[[#F3]] : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: {{^}}bb1{{.*}}:
// CHECK: {{^}}bb2{{.*}}:
// CHECK: {{^}}bb3{{.*}}:
//=========== Test rewritten body ===========//
specify_test "autodiff_closure_specialize_rewritten_caller_body"
// TRUNNER-LABEL: Rewritten caller body for: $s4test14cond_tuple_varyS2fFTJrSpSr:
// CHECK: sil hidden @$s4test14cond_tuple_varyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0(%0 : $Float):
// CHECK: %[[#G7:]] = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float)
// CHECK: %[[#G11:]] = function_ref @$sSf16_DifferentiationE12_vjpSubtract3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float)
// CHECK: bb1:
// CHECK: bb2:
// CHECK: bb3(%[[#G31:]] : $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, %[[#]] : $Builtin.FPIEEE32, %[[#]] : $Builtin.FPIEEE32):
// COMBINE-NOT: = thin_to_thick_function
// COMBINE-NOT: = function_ref @$s4test14cond_tuple_varyS2fFTJpSpSr
// TRUNNER: %[[#G33:]] = thin_to_thick_function %[[#G7]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER: %[[#G34:]] = thin_to_thick_function %[[#G11]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER: %[[#G35:]] = thin_to_thick_function %[[#G7]] : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER: %[[#]] = function_ref @$s4test14cond_tuple_varyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
// CHECK: %[[#G41:]] = function_ref @$s4test14cond_tuple_varyS2fFTJpSpSr067$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktl1_m5FZSf_M6SfcfU_0ef1_g4E12_i16Subtract3lhs3rhsk1_l1_mnl1_mo1_mP2U_ACTf1nnccc_n : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0) -> Float
// CHECK: %[[#G42:]] = partial_apply [callee_guaranteed] %[[#G41]](%[[#G31]]) : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0) -> Float
// TRUNNER: release_value %[[#G33]] : $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER: release_value %[[#G34]] : $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER: release_value %[[#G35]] : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: %[[#G46:]] = tuple (%[[#]] : $Float, %[[#G42]] : $@callee_guaranteed (Float) -> Float)
// CHECK: return %[[#G46]]
// Entry block of the input VJP. It computes:
//   - %5  = x + x
//   - %9  = x - x
//   - %14 = (0 < x), the branch condition
// %15 pairs the add and subtract pullback closures. Both branches store this
// pair in their predecessor enum.
%4 = struct_extract %0, #Float._value
%5 = builtin "fadd_FPIEEE32"(%4, %4) : $Builtin.FPIEEE32
// function_ref closure #1 in static Float._vjpAdd(lhs:rhs:)
%7 = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float)
%8 = thin_to_thick_function %7 to $@callee_guaranteed (Float) -> (Float, Float)
%9 = builtin "fsub_FPIEEE32"(%4, %4) : $Builtin.FPIEEE32
// function_ref closure #1 in static Float._vjpSubtract(lhs:rhs:)
%11 = function_ref @$sSf16_DifferentiationE12_vjpSubtract3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float)
%12 = thin_to_thick_function %11 to $@callee_guaranteed (Float) -> (Float, Float)
%13 = float_literal $Builtin.FPIEEE32, 0x0 // 0
%14 = builtin "fcmp_olt_FPIEEE32"(%13, %4) : $Builtin.Int1
%15 = tuple (%8, %12)
cond_br %14, bb1, bb2
bb1:
// Branch taken when 0 < x. It computes z.0 - y.0 (%22) and z.1 + y.0 (%25).
// The bb0 pair and this branch's two closures (%23, %27) go into the .bb1
// case of the bb3 tracing enum.
%17 = enum $_AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %15
%22 = builtin "fsub_FPIEEE32"(%5, %4) : $Builtin.FPIEEE32
%23 = thin_to_thick_function %11 to $@callee_guaranteed (Float) -> (Float, Float)
%25 = builtin "fadd_FPIEEE32"(%9, %4) : $Builtin.FPIEEE32
%27 = thin_to_thick_function %7 to $@callee_guaranteed (Float) -> (Float, Float)
%28 = tuple $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float), @callee_guaranteed (Float) -> (Float, Float)) (%17, %23, %27)
%29 = enum $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %28
br bb3(%29, %25, %22)
bb2:
// Other branch: z = (1 * x, x). The multiply pullback (%37) is wrapped by a
// subset-parameters thunk (%39). The thunk is stored, together with the bb0
// pair, in the .bb2 case of the bb3 tracing enum.
%31 = enum $_AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %15
%32 = float_literal $Builtin.FPIEEE32, 0x3F800000 // 1
%33 = struct $Float (%32)
%34 = builtin "fmul_FPIEEE32"(%32, %4) : $Builtin.FPIEEE32
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%36 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%37 = partial_apply [callee_guaranteed] %36(%0, %33) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// function_ref autodiff subset parameters thunk for pullback from @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
%38 = function_ref @$sS3fIegydd_TJSpSSUpSrUSUP : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%39 = partial_apply [callee_guaranteed] %38(%37) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%41 = tuple $(predecessor: _AD__$s4test14cond_tuple_varyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float) (%31, %39)
%42 = enum $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %41
br bb3(%42, %4, %34)
bb3(%44 : $_AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, %45 : $Builtin.FPIEEE32, %46 : $Builtin.FPIEEE32):
// Join block. %45 is z.1 and %46 is z.0 from whichever branch ran; the result
// is (x + x) - z.0 + z.1.
// The pullback is partially applied to:
//   - the tracing enum %44;
//   - three closures created here (%47 add, %49 subtract, %52 add), which are
//     the direct closure arguments the specialization pass targets.
%47 = thin_to_thick_function %7 to $@callee_guaranteed (Float) -> (Float, Float)
%48 = builtin "fsub_FPIEEE32"(%5, %46) : $Builtin.FPIEEE32
%49 = thin_to_thick_function %11 to $@callee_guaranteed (Float) -> (Float, Float)
%50 = builtin "fadd_FPIEEE32"(%48, %45) : $Builtin.FPIEEE32
%51 = struct $Float (%50)
%52 = thin_to_thick_function %7 to $@callee_guaranteed (Float) -> (Float, Float)
// function_ref pullback of cond_tuple_var(_:)
%53 = function_ref @$s4test14cond_tuple_varyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
%54 = partial_apply [callee_guaranteed] %53(%44, %47, %49, %52) : $@convention(thin) (Float, @owned _AD__$s4test14cond_tuple_varyS2fF_bb3__Pred__src_0_wrt_0, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
%55 = tuple (%51, %54)
return %55
} // end sil function '$s4test14cond_tuple_varyS2fFTJrSpSr'