Files
Daniil Kovalev 5528cf1cc4 [AutoDiff] Run AutoDiff closure spec pass for all VJPs (#81548)
Previously, the AutoDiff closure specialization pass was triggered only on
VJPs containing a single basic block. However, the pass logic allows
running on arbitrary VJPs. This PR enables the pass for all VJPs
unconditionally. So, if the pullback corresponding to a multiple-BB VJP
accepts some closures directly as arguments, these closures might become
specialized by the pass. Closures passed via the payload of a branch tracing
enum are not specialized - this is left for future changes.

The PR contains several commits.
1. The thing named "call site" in the code is partial_apply of pullback
corresponding to the VJP. This might appear only once, so we drop
support for multiple "call sites".
2. Enhance existing SILOptimizer tests for the pass.
3. Add validation-tests for single basic block case.
4. The change itself - delete check against single basic block.
5. Add validation-tests for multiple basic block case.
6. Add SILOptimizer tests for multiple basic block case.
2025-07-07 13:00:14 +00:00

491 lines
38 KiB
Plaintext

// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK
// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK
// REQUIRES: swift_in_compiler
sil_stage canonical
import Builtin
import Swift
import SwiftShims
import _Differentiation
////////////////////////////////////////////////////////////////
// Single closure call site where closure is passed as @owned //
////////////////////////////////////////////////////////////////
// External declaration (no body in this file) of the closure function that the
// test functions below partially apply and hand to their pullbacks.
sil @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// Pullback that receives its closure as an @owned argument: it applies the
// closure to %0, releases the closure, and returns the sum of the two Floats
// in the closure's result tuple.
sil private @$pullback_f : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> (Float, Float)):
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> (Float, Float)
// The closure argument is @owned, so it is released here after its only use.
strong_release %1 : $@callee_guaranteed (Float) -> (Float, Float)
%4 = tuple_extract %2 : $(Float, Float), 0
%5 = tuple_extract %2 : $(Float, Float), 1
%6 = struct_extract %5 : $Float, #Float._value
%7 = struct_extract %4 : $Float, #Float._value
%8 = builtin "fadd_FPIEEE32"(%6 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
%9 = struct $Float (%8 : $Builtin.FPIEEE32)
return %9 : $Float
}
// reverse-mode derivative of f(_:)
//
// Test input: builds a partial_apply of @$vjpMultiply capturing (%0, %0) and
// captures that closure in a partial_apply of @$pullback_f. The FileCheck
// patterns expect the pass to replace this with a partial_apply of a
// specialized pullback that takes the two captured Floats directly.
sil hidden @$s4test1fyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
//=========== Test callsite and closure gathering logic ===========//
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
// TRUNNER-LABEL: Specializing closures in function: $s4test1fyS2fFTJrSpSr
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#A1:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER-NEXT: Passed in closures:
// TRUNNER-NEXT: 1. %[[#A1]] = partial_apply [callee_guaranteed] %[[#]](%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER-EMPTY:
//=========== Test specialized function signature and body ===========//
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
// TRUNNER-LABEL: Generated specialized function: $s11$pullback_f12$vjpMultiplyS2fTf1nc_n
// CHECK: sil private @$s11$pullback_f12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float {
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float):
// CHECK: %[[#A2:]] = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#A3:]] = partial_apply [callee_guaranteed] %[[#A2]](%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#]] = apply %[[#A3]](%0) : $@callee_guaranteed (Float) -> (Float, Float)
// COMBINE: %[[#]] = apply %[[#A2]](%0, %1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: strong_release %[[#A3]] : $@callee_guaranteed (Float) -> (Float, Float)
//=========== Test rewritten body ===========//
specify_test "autodiff_closure_specialize_rewritten_caller_body"
// TRUNNER-LABEL: Rewritten caller body for: $s4test1fyS2fFTJrSpSr
// CHECK: sil hidden @$s4test1fyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0(%0 : $Float):
// CHECK: %[[#A4:]] = struct_extract %0 : $Float, #Float._value
// CHECK: %[[#A5:]] = builtin "fmul_FPIEEE32"(%[[#A4]] : $Builtin.FPIEEE32, %[[#A4]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %[[#A6:]] = struct $Float (%[[#A5]] : $Builtin.FPIEEE32)
// COMBINE-NOT: function_ref @$vjpMultiply
// CHECK: %[[#A7:]] = function_ref @$s11$pullback_f12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: %[[#A8:]] = partial_apply [callee_guaranteed] %[[#A7]](%0, %0) : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: %[[#A9:]] = tuple (%[[#A6]] : $Float, %[[#A8]] : $@callee_guaranteed (Float) -> Float)
// CHECK: return %[[#A9]]
// Original (unspecialized) body that the pass is run on starts here.
%2 = struct_extract %0 : $Float, #Float._value
%3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
%4 = struct $Float (%3 : $Builtin.FPIEEE32)
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// function_ref pullback of f(_:)
%7 = function_ref @$pullback_f : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
%8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
%9 = tuple (%4 : $Float, %8 : $@callee_guaranteed (Float) -> Float)
return %9 : $(Float, @callee_guaranteed (Float) -> Float)
}
/////////////////////////////////////////////////////////////////////
// Single closure call site where closure is passed as @guaranteed //
/////////////////////////////////////////////////////////////////////
// Pullback that receives its closure as a @guaranteed argument: it applies
// the closure to %0 and returns the sum of the two Floats in the result tuple.
// Unlike @$pullback_f, the body contains no release of the closure.
sil private @$pullback_k : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> (Float, Float)):
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> (Float, Float)
%3 = tuple_extract %2 : $(Float, Float), 0
%4 = tuple_extract %2 : $(Float, Float), 1
%5 = struct_extract %4 : $Float, #Float._value
%6 = struct_extract %3 : $Float, #Float._value
%7 = builtin "fadd_FPIEEE32"(%5 : $Builtin.FPIEEE32, %6 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
%8 = struct $Float (%7 : $Builtin.FPIEEE32)
return %8 : $Float
} // end sil function '$pullback_k'
// reverse-mode derivative of k(_:)
//
// Same shape as the @owned test, but the pullback (@$pullback_k) takes the
// closure @guaranteed, so this caller releases the closure itself after
// forming the pullback's partial_apply. The patterns expect the specialized
// pullback to contain a release_value of the rebuilt closure.
sil hidden @$s4test1kyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
//=========== Test callsite and closure gathering logic ===========//
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
// TRUNNER-LABEL: Specializing closures in function: $s4test1kyS2fFTJrSpSr
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#B1:]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER: Passed in closures:
// TRUNNER: 1. %[[#B1]] = partial_apply [callee_guaranteed] %[[#]](%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER-EMPTY:
//=========== Test specialized function signature and body ===========//
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
// TRUNNER-LABEL: Generated specialized function: $s11$pullback_k12$vjpMultiplyS2fTf1nc_n
// CHECK: sil private @$s11$pullback_k12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float {
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float):
// CHECK: %[[#B2:]] = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#B3:]] = partial_apply [callee_guaranteed] %[[#B2]](%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#]] = apply %[[#B3]](%0) : $@callee_guaranteed (Float) -> (Float, Float)
// COMBINE: %[[#]] = apply %[[#B2]](%0, %1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: release_value %[[#B3]] : $@callee_guaranteed (Float) -> (Float, Float)
//=========== Test rewritten body ===========//
specify_test "autodiff_closure_specialize_rewritten_caller_body"
// TRUNNER-LABEL: Rewritten caller body for: $s4test1kyS2fFTJrSpSr
// CHECK: sil hidden @$s4test1kyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0(%0 : $Float):
// CHECK: %[[#B4:]] = struct_extract %0 : $Float, #Float._value
// CHECK: %[[#B5:]] = builtin "fmul_FPIEEE32"(%[[#B4]] : $Builtin.FPIEEE32, %[[#B4]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %[[#B6:]] = struct $Float (%[[#B5]] : $Builtin.FPIEEE32)
// COMBINE-NOT: function_ref @$vjpMultiply
// CHECK: %[[#B7:]] = function_ref @$s11$pullback_k12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: %[[#B8:]] = partial_apply [callee_guaranteed] %[[#B7]](%0, %0) : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: %[[#B9:]] = tuple (%[[#B6]] : $Float, %[[#B8]] : $@callee_guaranteed (Float) -> Float)
// CHECK: return %[[#B9]]
// Original (unspecialized) body that the pass is run on starts here.
%2 = struct_extract %0 : $Float, #Float._value
%3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
%4 = struct $Float (%3 : $Builtin.FPIEEE32)
// function_ref $vjpMultiply
%5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// function_ref $pullback_k
%7 = function_ref @$pullback_k : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
%8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
// The pullback takes %6 as @guaranteed, so the caller releases it here.
strong_release %6 : $@callee_guaranteed (Float) -> (Float, Float)
%10 = tuple (%4 : $Float, %8 : $@callee_guaranteed (Float) -> Float)
return %10 : $(Float, @callee_guaranteed (Float) -> Float)
} // end sil function '$s4test1kyS2fFTJrSpSr'
///////////////////////////////
// Multiple closure callsite //
///////////////////////////////
// External declarations (no bodies in this file) of the two single-capture
// closure functions used by the multiple-closure test below.
sil @$vjpSin : $@convention(thin) (Float, Float) -> Float
sil @$vjpCos : $@convention(thin) (Float, Float) -> Float
// pullback of g(_:)
//
// Takes three @owned closures. It applies %3 to the incoming value, feeds the
// two tuple elements to %1 (element 0) and %2 (element 1), releases each
// closure after its use, and returns the sum of the two results.
sil private @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float, %2 : $@callee_guaranteed (Float) -> Float, %3 : $@callee_guaranteed (Float) -> (Float, Float)):
%4 = apply %3(%0) : $@callee_guaranteed (Float) -> (Float, Float)
strong_release %3 : $@callee_guaranteed (Float) -> (Float, Float)
%6 = tuple_extract %4 : $(Float, Float), 0
%7 = tuple_extract %4 : $(Float, Float), 1
%8 = apply %2(%7) : $@callee_guaranteed (Float) -> Float
strong_release %2 : $@callee_guaranteed (Float) -> Float
%10 = apply %1(%6) : $@callee_guaranteed (Float) -> Float
strong_release %1 : $@callee_guaranteed (Float) -> Float
%12 = struct_extract %8 : $Float, #Float._value
%13 = struct_extract %10 : $Float, #Float._value
%14 = builtin "fadd_FPIEEE32"(%13 : $Builtin.FPIEEE32, %12 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
%15 = struct $Float (%14 : $Builtin.FPIEEE32)
return %15 : $Float
}
// reverse-mode derivative of g(_:)
//
// Test input: captures three closures (partial_applies of @$vjpSin, @$vjpCos
// and @$vjpMultiply) in a single partial_apply of @$pullback_g. The patterns
// expect a specialized pullback whose signature takes the four captured
// Floats (1 + 1 + 2) in place of the three closure arguments.
sil hidden @$s4test1gyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
//=========== Test callsite and closure gathering logic ===========//
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
// TRUNNER-LABEL: Specializing closures in function: $s4test1gyS2fFTJrSpSr
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#C1:]], %[[#C2:]], %[[#C3:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
// TRUNNER: Passed in closures:
// TRUNNER: 1. %[[#C1]] = partial_apply [callee_guaranteed] %[[#]](%0) : $@convention(thin) (Float, Float) -> Float
// TRUNNER: 2. %[[#C2]] = partial_apply [callee_guaranteed] %[[#]](%0) : $@convention(thin) (Float, Float) -> Float
// TRUNNER: 3. %[[#C3]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER-EMPTY:
//=========== Test specialized function signature and body ===========//
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
// TRUNNER-LABEL: Generated specialized function: $s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n
// CHECK: sil private @$s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float {
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float, %3 : $Float, %4 : $Float):
// CHECK: %[[#C4:]] = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float
// TRUNNER: %[[#C5:]] = partial_apply [callee_guaranteed] %[[#C4]](%1) : $@convention(thin) (Float, Float) -> Float
// CHECK: %[[#C6:]] = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float
// TRUNNER: %[[#C7:]] = partial_apply [callee_guaranteed] %[[#C6]](%2) : $@convention(thin) (Float, Float) -> Float
// CHECK: %[[#C8:]] = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#C9:]] = partial_apply [callee_guaranteed] %[[#C8]](%3, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: %[[#]] = apply %[[#C9]](%0) : $@callee_guaranteed (Float) -> (Float, Float)
// COMBINE: %[[#]] = apply %[[#C8]](%0, %3, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER: strong_release %[[#C9]] : $@callee_guaranteed (Float) -> (Float, Float)
// TRUNNER: %[[#]] = apply %[[#C7]](%[[#]]) : $@callee_guaranteed (Float) -> Float
// COMBINE: %[[#]] = apply %[[#C6]](%[[#]], %2) : $@convention(thin) (Float, Float) -> Float
// TRUNNER: strong_release %[[#C7]] : $@callee_guaranteed (Float) -> Float
// TRUNNER: %[[#]] = apply %[[#C5]](%[[#]]) : $@callee_guaranteed (Float) -> Float
// COMBINE: %[[#]] = apply %[[#C4]](%[[#]], %1) : $@convention(thin) (Float, Float) -> Float
// TRUNNER: strong_release %[[#C5]] : $@callee_guaranteed (Float) -> Float
//=========== Test rewritten body ===========//
specify_test "autodiff_closure_specialize_rewritten_caller_body"
// TRUNNER-LABEL: Rewritten caller body for: $s4test1gyS2fFTJrSpSr
// CHECK: sil hidden @$s4test1gyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0(%0 : $Float):
// CHECK: %[[#C10:]] = struct_extract %0 : $Float, #Float._value
// CHECK: %[[#C11:]] = builtin "int_sin_FPIEEE32"(%[[#C10]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %[[#C12:]] = struct $Float (%[[#C11]] : $Builtin.FPIEEE32)
// COMBINE-NOT: function_ref @$vjpSin
// CHECK: %[[#C13:]] = builtin "int_cos_FPIEEE32"(%[[#C10]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// COMBINE-NOT: function_ref @$vjpCos
// COMBINE-NOT: function_ref @$vjpMultiply
// CHECK: %[[#C14:]] = struct $Float (%[[#C13]] : $Builtin.FPIEEE32)
// CHECK: %[[#C15:]] = function_ref @$s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float
// CHECK: %[[#C16:]] = partial_apply [callee_guaranteed] %[[#C15]](%0, %0, %[[#C14]], %[[#C12]]) : $@convention(thin) (Float, Float, Float, Float, Float) -> Float
// CHECK: %[[#C17:]] = tuple (%[[#]] : $Float, %[[#C16]] : $@callee_guaranteed (Float) -> Float)
// CHECK: return %[[#C17]]
// Original (unspecialized) body that the pass is run on starts here.
%2 = struct_extract %0 : $Float, #Float._value
%3 = builtin "int_sin_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
%4 = struct $Float (%3 : $Builtin.FPIEEE32)
// function_ref closure #1 in _vjpSin(_:)
%5 = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float
%6 = partial_apply [callee_guaranteed] %5(%0) : $@convention(thin) (Float, Float) -> Float
%7 = builtin "int_cos_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
%8 = struct $Float (%7 : $Builtin.FPIEEE32)
// function_ref closure #1 in _vjpCos(_:)
%9 = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float
%10 = partial_apply [callee_guaranteed] %9(%0) : $@convention(thin) (Float, Float) -> Float
%11 = builtin "fmul_FPIEEE32"(%3 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
%12 = struct $Float (%11 : $Builtin.FPIEEE32)
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%13 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%14 = partial_apply [callee_guaranteed] %13(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// function_ref pullback of g(_:)
%15 = function_ref @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
%16 = partial_apply [callee_guaranteed] %15(%6, %10, %14) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
%17 = tuple (%12 : $Float, %16 : $@callee_guaranteed (Float) -> Float)
return %17 : $(Float, @callee_guaranteed (Float) -> Float)
}
///////////////////////////////
/// Parameter subset thunks ///
///////////////////////////////
// Differentiable struct with a Float field `a` and a Double field `b`; its
// nested TangentVector declares the same two stored properties.
// X.TangentVector is the result type of the pullbacks in the
// parameter-subset-thunk test below.
struct X : Differentiable {
@_hasStorage var a: Float { get set }
@_hasStorage var b: Double { get set }
struct TangentVector : AdditiveArithmetic, Differentiable {
@_hasStorage var a: Float { get set }
@_hasStorage var b: Double { get set }
static func + (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector
static func - (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X.TangentVector, _ b: X.TangentVector) -> Bool
typealias TangentVector = X.TangentVector
init(a: Float, b: Double)
static var zero: X.TangentVector { get }
}
init(a: Float, b: Double)
mutating func move(by offset: X.TangentVector)
}
// External declarations (no bodies in this file): the thunk that wraps a
// (Float, Double) closure into a (Float) closure, and the function that the
// test below converts with thin_to_thick_function and passes to that thunk.
sil [transparent] [thunk] @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
sil @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector
// Pullback taking one @owned closure: applies it to %0, releases it, and
// returns the closure's result unchanged.
sil shared @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> X.TangentVector):
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> X.TangentVector
strong_release %1 : $@callee_guaranteed (Float) -> X.TangentVector
return %2 : $X.TangentVector
}
// Test input for the parameter-subset-thunk case: the closure handed to
// @pullback_g is a partial_apply of @subset_parameter_thunk over a
// thin_to_thick_function of @pullback_f. The patterns expect the specialized
// pullback to keep the signature (Float) -> X.TangentVector, i.e. no extra
// arguments are added, and to rebuild the thunk chain in its own body.
sil hidden @$s5test21g1xSfAA1XV_tFTJrSpSr : $@convention(thin) (X) -> (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) {
bb0(%0 : $X):
//=========== Test callsite and closure gathering logic ===========//
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
// TRUNNER-LABEL: Specializing closures in function: $s5test21g1xSfAA1XV_tFTJrSpSr
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector
// TRUNNER: Passed in closures:
// TRUNNER: 1. %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector
// TRUNNER-EMPTY:
//=========== Test specialized function signature and body ===========//
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
// TRUNNER-LABEL: Generated specialized function: $s10pullback_g0A2_fTf1nc_n
// CHECK: sil shared @$s10pullback_g0A2_fTf1nc_n : $@convention(thin) (Float) -> X.TangentVector {
// CHECK: bb0(%0 : $Float):
// CHECK: %[[#D1:]] = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector
// CHECK: %[[#D2:]] = thin_to_thick_function %[[#D1]] : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector
// CHECK: %[[#D3:]] = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
// TRUNNER: %[[#D4:]] = partial_apply [callee_guaranteed] %[[#D3]](%[[#D2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
// TRUNNER: %[[#D5:]] = apply %[[#D4]](%0) : $@callee_guaranteed (Float) -> X.TangentVector
// COMBINE: %[[#D5:]] = apply %[[#D3]](%0, %[[#D2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
// TRUNNER: strong_release %[[#D4]] : $@callee_guaranteed (Float) -> X.TangentVector
// CHECK: return %[[#D5]]
//=========== Test rewritten body ===========//
specify_test "autodiff_closure_specialize_rewritten_caller_body"
// TRUNNER-LABEL: Rewritten caller body for: $s5test21g1xSfAA1XV_tFTJrSpSr
// CHECK: sil hidden @$s5test21g1xSfAA1XV_tFTJrSpSr : $@convention(thin) (X) -> (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) {
// CHECK: bb0(%0 : $X):
// CHECK: %[[#D6:]] = struct_extract %0 : $X, #X.a
// COMBINE-NOT: function_ref @pullback_f
// COMBINE-NOT: function_ref @subset_parameter_thunk
// CHECK: %[[#D7:]] = function_ref @$s10pullback_g0A2_fTf1nc_n : $@convention(thin) (Float) -> X.TangentVector
// TRUNNER: %[[#D8:]] = partial_apply [callee_guaranteed] %[[#D7]]() : $@convention(thin) (Float) -> X.TangentVector
// COMBINE: %[[#D8:]] = thin_to_thick_function %[[#D7]] : $@convention(thin) (Float) -> X.TangentVector to $@callee_guaranteed (Float) -> X.TangentVector
// CHECK: %[[#D9:]] = tuple (%[[#D6]] : $Float, %[[#D8]] : $@callee_guaranteed (Float) -> X.TangentVector)
// CHECK: return %[[#D9]]
// Original (unspecialized) body that the pass is run on starts here.
%1 = struct_extract %0 : $X, #X.a
// function_ref pullback_f
%2 = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector
%3 = thin_to_thick_function %2 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector
// function_ref subset_parameter_thunk
%4 = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
%5 = partial_apply [callee_guaranteed] %4(%3) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
// function_ref pullback_g
%6 = function_ref @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector
%7 = partial_apply [callee_guaranteed] %6(%5) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector
%8 = tuple (%1 : $Float, %7 : $@callee_guaranteed (Float) -> X.TangentVector)
return %8 : $(Float, @callee_guaranteed (Float) -> X.TangentVector)
}
///////////////////////////////////////////////////////////////////////
///////// Specialized generic closures - PartialApply Closure /////////
///////////////////////////////////////////////////////////////////////
// The four functions below are external declarations (no bodies in this
// file). The test that follows chains them, in this order, to build the
// closure it passes to @pullback_h.
// closure #1 in static Float._vjpMultiply(lhs:rhs:)
sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
sil [transparent] [reabstraction_thunk] @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
// function_ref specialized pullback of f<A>(a:)
sil [transparent] [thunk] @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
// thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float)
sil [transparent] [reabstraction_thunk] @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
// Pullback taking one @owned (Float) -> Float closure: applies it to %0,
// releases it, and returns the result. The expected specialized function in
// the test below carries the same [signature_optimized_thunk] and
// [always_inline] attributes.
sil private [signature_optimized_thunk] [always_inline] @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float):
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> Float
strong_release %1 : $@callee_guaranteed (Float) -> Float
return %2 : $Float
}
// reverse-mode derivative of h(x:)
//
// Test input: the closure reaching @pullback_h is built from a chain of
// partial_apply(vjpMultiply closure) -> reabstraction thunk ->
// convert_function -> @pullback_f_specialized -> reabstraction thunk.
// The patterns expect the whole chain to be rebuilt inside the specialized
// pullback, which receives only the two Floats captured by the innermost
// partial_apply.
sil hidden @$s5test21h1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
//=========== Test callsite and closure gathering logic ===========//
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
// TRUNNER-LABEL: Specializing closures in function: $s5test21h1xS2f_tFTJrSpSr
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
// TRUNNER: Passed in closures:
// TRUNNER: 1. %[[#]] = partial_apply [callee_guaranteed] %[[#]](%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// TRUNNER-EMPTY:
//=========== Test specialized function signature and body ===========//
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
// TRUNNER-LABEL: Generated specialized function: $s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n
// CHECK: sil private [signature_optimized_thunk] [always_inline] @$s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float {
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float):
// CHECK: %[[#E1:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: %[[#E2:]] = partial_apply [callee_guaranteed] %[[#E1]](%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: %[[#E3:]] = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
// CHECK: %[[#E4:]] = partial_apply [callee_guaranteed] %[[#E3]](%[[#E2]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
// CHECK: %[[#E5:]] = convert_function %[[#E4]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>
// CHECK: %[[#E6:]] = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
// CHECK: %[[#E7:]] = partial_apply [callee_guaranteed] %[[#E6]](%[[#E5]]) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
// CHECK: %[[#E8:]] = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
// TRUNNER: %[[#E9:]] = partial_apply [callee_guaranteed] %[[#E8]](%[[#E7]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
// TRUNNER: %[[#E10:]] = apply %[[#E9]](%0) : $@callee_guaranteed (Float) -> Float
// COMBINE: %[[#E10:]] = apply %[[#E8]](%0, %[[#E7]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
// TRUNNER: strong_release %[[#E9]] : $@callee_guaranteed (Float) -> Float
// COMBINE: strong_release %[[#E7]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float
// CHECK: return %[[#E10]]
//=========== Test rewritten body ===========//
specify_test "autodiff_closure_specialize_rewritten_caller_body"
// TRUNNER-LABEL: Rewritten caller body for: $s5test21h1xS2f_tFTJrSpSr
// CHECK: sil hidden @$s5test21h1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0(%0 : $Float):
// CHECK: %[[#E11:]] = struct_extract %0 : $Float, #Float._value
// Both fmul operands are the struct_extract result, so both are matched
// through the E11 capture instead of a literal SSA number.
// CHECK: %[[#E12:]] = builtin "fmul_FPIEEE32"(%[[#E11]] : $Builtin.FPIEEE32, %[[#E11]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// COMBINE-NOT: function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_
// COMBINE-NOT: function_ref @$sS3fIegydd_S3fIegnrr_TR
// COMBINE-NOT: function_ref @pullback_f_specialized
// COMBINE-NOT: function_ref @$sS2fIegnr_S2fIegyd_TR
// CHECK: %[[#E13:]] = struct $Float (%[[#E12]] : $Builtin.FPIEEE32)
// CHECK: %[[#E14:]] = function_ref @$s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: %[[#E15:]] = partial_apply [callee_guaranteed] %[[#E14]](%0, %0) : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: %[[#E16:]] = tuple (%[[#E13]] : $Float, %[[#E15]] : $@callee_guaranteed (Float) -> Float)
// CHECK: return %[[#E16]]
// Original (unspecialized) body that the pass is run on starts here.
%1 = struct_extract %0 : $Float, #Float._value
%2 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%3 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
%4 = partial_apply [callee_guaranteed] %3(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// function_ref thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
%5 = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
%6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
%7 = convert_function %6 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>
// function_ref pullback_f_specialized
%8 = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
%9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
// function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float)
%10 = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
%11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
%12 = struct $Float (%2 : $Builtin.FPIEEE32)
// function_ref pullback_h
%13 = function_ref @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
%14 = partial_apply [callee_guaranteed] %13(%11) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
%15 = tuple (%12 : $Float, %14 : $@callee_guaranteed (Float) -> Float)
return %15 : $(Float, @callee_guaranteed (Float) -> Float)
}
//////////////////////////////////////////////////////////////////////////////
///////// Specialized generic closures - ThinToThickFunction closure /////////
//////////////////////////////////////////////////////////////////////////////
// Body-less declaration of the thin function that the test function below
// turns into a closure with thin_to_thick_function. It takes its Float
// argument @in_guaranteed and produces its result through an @out parameter.
sil [transparent] [thunk] @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float
// Body-less declaration of the thunk that the test function below partially
// applies over that closure. It takes a direct Float plus a @guaranteed
// closure of the indirect form (@in_guaranteed Float) -> @out Float and
// returns a direct Float.
sil [transparent] [reabstraction_thunk] @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
// Pullback used by the test function below: it calls the closure argument %1
// with %0, releases the closure, and returns the result of the call.
// The expectations in the test function name the specialized clone of this
// function $s10pullback_z0A14_y_specializedTf1nc_n and expect the same
// [signature_optimized_thunk] [always_inline] attributes on that clone.
sil private [signature_optimized_thunk] [always_inline] @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float {
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float):
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> Float
// The closure is received @owned, so it is released here after its only use.
strong_release %1 : $@callee_guaranteed (Float) -> Float
return %2 : $Float
}
// Test input for the thin_to_thick_function closure case. The function takes
// a Float and returns (Float, pullback closure). Its body builds a
// thin_to_thick_function closure over @pullback_y_specialized, wraps it in a
// partial_apply of @reabstraction_thunk, and captures that in the
// partial_apply of @pullback_z.
// The expectations are split into three stages: closure gathering, the
// generated specialized pullback, and the rewritten body of this function.
sil hidden @$s5test21z1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%0 : $Float):
//=========== Test callsite and closure gathering logic ===========//
// Expected: the partial_apply of @pullback_z is reported as the pullback
// partial_apply, and exactly one passed-in closure is listed, namely the
// thin_to_thick_function.
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
// TRUNNER-LABEL: Specializing closures in function: $s5test21z1xS2f_tFTJrSpSr
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
// TRUNNER: Passed in closures:
// TRUNNER: 1. %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float
// TRUNNER-EMPTY:
//=========== Test specialized function signature and body ===========//
// Expected: the specialized pullback no longer takes a closure parameter and
// instead re-creates the closure chain inside its own body.
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
// TRUNNER-LABEL: Generated specialized function: $s10pullback_z0A14_y_specializedTf1nc_n
// CHECK: sil private [signature_optimized_thunk] [always_inline] @$s10pullback_z0A14_y_specializedTf1nc_n : $@convention(thin) (Float) -> Float {
// CHECK: bb0(%0 : $Float):
// CHECK: %[[#F1:]] = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float
// CHECK: %[[#F2:]] = thin_to_thick_function %[[#F1]] : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float
// CHECK: %[[#F3:]] = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
// In the test-runner run the thunk is still partially applied, called through
// the resulting closure, and that closure is then released. In the
// sil-combine run the thunk is expected to be applied directly instead.
// TRUNNER: %[[#F4:]] = partial_apply [callee_guaranteed] %[[#F3]](%[[#F2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
// TRUNNER: %[[#F5:]] = apply %[[#F4]](%0) : $@callee_guaranteed (Float) -> Float
// COMBINE: %[[#F5:]] = apply %[[#F3]](%0, %[[#F2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
// TRUNNER: strong_release %[[#F4]] : $@callee_guaranteed (Float) -> Float
// CHECK: return %[[#F5]]
//=========== Test rewritten body ===========//
// Expected: this function references the specialized pullback and captures
// nothing. That is an argument-less partial_apply in the test-runner run and
// a thin_to_thick_function in the sil-combine run, where the original
// closure-building function_refs must also be gone from this body.
specify_test "autodiff_closure_specialize_rewritten_caller_body"
// TRUNNER-LABEL: Rewritten caller body for: $s5test21z1xS2f_tFTJrSpSr
// CHECK: sil hidden @$s5test21z1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK: bb0(%0 : $Float):
// COMBINE-NOT: function_ref @pullback_y_specialized
// COMBINE-NOT: function_ref @reabstraction_thunk
// CHECK: %[[#F6:]] = function_ref @$s10pullback_z0A14_y_specializedTf1nc_n : $@convention(thin) (Float) -> Float
// TRUNNER: %[[#F7:]] = partial_apply [callee_guaranteed] %[[#F6]]() : $@convention(thin) (Float) -> Float
// COMBINE: %[[#F7:]] = thin_to_thick_function %[[#F6]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
// CHECK: %[[#F8:]] = tuple (%0 : $Float, %[[#F7]] : $@callee_guaranteed (Float) -> Float)
// CHECK: return %[[#F8]]
// Input body starts here; everything above in this function is expectations.
// function_ref pullback_y_specialized
%1 = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float
// %2 is the closure the gathering stage above is expected to report.
%2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float
// function_ref reabstraction_thunk
%3 = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
// The thunk captures %2, giving a closure of the direct (Float) -> Float form
// that @pullback_z accepts.
%4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
// function_ref pullback_z
%5 = function_ref @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
// This partial_apply of the pullback is the one the gathering stage reports
// and the rewriting stage replaces.
%6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
// Return the unchanged input value together with the pullback closure.
%7 = tuple (%0 : $Float, %6 : $@callee_guaranteed (Float) -> Float)
return %7 : $(Float, @callee_guaranteed (Float) -> Float)
}