mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Previously, the AutoDiff closure specialization pass was triggered only on VJPs containing a single basic block. However, the pass logic allows running on arbitrary VJPs. This PR enables the pass for all VJPs unconditionally. So, if the pullback corresponding to a multiple-BB VJP accepts some closures directly as arguments, these closures might become specialized by the pass. Closures passed via the payload of a branch tracing enum are not specialized - this is a subject for future changes. The PR contains several commits. 1. The thing named "call site" in the code is the partial_apply of the pullback corresponding to the VJP. This can appear only once, so we drop support for multiple "call sites". 2. Enhance existing SILOptimizer tests for the pass. 3. Add validation-tests for the single basic block case. 4. The change itself - delete the check against a single basic block. 5. Add validation-tests for the multiple basic block case. 6. Add SILOptimizer tests for the multiple basic block case.
491 lines
38 KiB
Plaintext
491 lines
38 KiB
Plaintext
// RUN: %target-sil-opt -sil-print-types -test-runner %s -o /dev/null 2>&1 | %FileCheck %s --check-prefixes=TRUNNER,CHECK
|
|
// RUN: %target-sil-opt -sil-print-types -autodiff-closure-specialization -sil-combine %s -o - | %FileCheck %s --check-prefixes=COMBINE,CHECK
|
|
|
|
// REQUIRES: swift_in_compiler
|
|
|
|
sil_stage canonical
|
|
|
|
import Builtin
|
|
import Swift
|
|
import SwiftShims
|
|
|
|
import _Differentiation
|
|
|
|
////////////////////////////////////////////////////////////////
|
|
// Single closure call site where closure is passed as @owned //
|
|
////////////////////////////////////////////////////////////////
|
|
// External declaration (no body in this file) of a three-Float function
// returning a pair of Floats. The VJPs below wrap it in a partial_apply and
// pass the resulting closure to their pullbacks.
sil @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
|
|
// Pullback for the @owned test case. It applies the closure %1 to %0,
// releases the closure, and returns the fadd of the two Float results
// (tuple element 1 is the first fadd operand, element 0 the second).
sil private @$pullback_f : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float {
|
|
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> (Float, Float)):
|
|
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> (Float, Float)
|
|
// The closure parameter is @owned; it is released here after its only use.
strong_release %1 : $@callee_guaranteed (Float) -> (Float, Float)
|
|
%4 = tuple_extract %2 : $(Float, Float), 0
|
|
%5 = tuple_extract %2 : $(Float, Float), 1
|
|
%6 = struct_extract %5 : $Float, #Float._value
|
|
%7 = struct_extract %4 : $Float, #Float._value
|
|
%8 = builtin "fadd_FPIEEE32"(%6 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
%9 = struct $Float (%8 : $Builtin.FPIEEE32)
|
|
return %9 : $Float
|
|
}
|
|
|
|
// reverse-mode derivative of f(_:)
//
// VJP under test for the @owned case. The body squares the value of %0 with
// fmul, wraps @$vjpMultiply in a closure capturing (%0, %0), and hands that
// closure to a partial_apply of @$pullback_f. The FileCheck directives in the
// body describe, in order: the gathered closure info, the generated
// specialized pullback, and the rewritten body of this function.
|
|
sil hidden @$s4test1fyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
bb0(%0 : $Float):
|
|
//=========== Test callsite and closure gathering logic ===========//
|
|
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
|
|
// TRUNNER-LABEL: Specializing closures in function: $s4test1fyS2fFTJrSpSr
|
|
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#A1:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
|
|
// TRUNNER-NEXT: Passed in closures:
|
|
// TRUNNER-NEXT: 1. %[[#A1]] = partial_apply [callee_guaranteed] %[[#]](%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER-EMPTY:
|
|
|
|
//=========== Test specialized function signature and body ===========//
|
|
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
|
|
// TRUNNER-LABEL: Generated specialized function: $s11$pullback_f12$vjpMultiplyS2fTf1nc_n
|
|
// CHECK: sil private @$s11$pullback_f12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float {
|
|
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float):
|
|
// CHECK: %[[#A2:]] = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER: %[[#A3:]] = partial_apply [callee_guaranteed] %[[#A2]](%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER: %[[#]] = apply %[[#A3]](%0) : $@callee_guaranteed (Float) -> (Float, Float)
|
|
// COMBINE: %[[#]] = apply %[[#A2]](%0, %1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER: strong_release %[[#A3]] : $@callee_guaranteed (Float) -> (Float, Float)
|
|
|
|
//=========== Test rewritten body ===========//
|
|
specify_test "autodiff_closure_specialize_rewritten_caller_body"
|
|
// TRUNNER-LABEL: Rewritten caller body for: $s4test1fyS2fFTJrSpSr
|
|
// CHECK: sil hidden @$s4test1fyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
// CHECK: bb0(%0 : $Float):
|
|
// CHECK: %[[#A4:]] = struct_extract %0 : $Float, #Float._value
|
|
// CHECK: %[[#A5:]] = builtin "fmul_FPIEEE32"(%[[#A4]] : $Builtin.FPIEEE32, %[[#A4]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
// CHECK: %[[#A6:]] = struct $Float (%[[#A5]] : $Builtin.FPIEEE32)
|
|
// COMBINE-NOT: function_ref @$vjpMultiply
|
|
// CHECK: %[[#A7:]] = function_ref @$s11$pullback_f12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float
|
|
// CHECK: %[[#A8:]] = partial_apply [callee_guaranteed] %[[#A7]](%0, %0) : $@convention(thin) (Float, Float, Float) -> Float
|
|
// CHECK: %[[#A9:]] = tuple (%[[#A6]] : $Float, %[[#A8]] : $@callee_guaranteed (Float) -> Float)
|
|
// CHECK: return %[[#A9]]
|
|
|
|
// Input body of the VJP (the instructions the expectations above refer to).
%2 = struct_extract %0 : $Float, #Float._value
|
|
%3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
%4 = struct $Float (%3 : $Builtin.FPIEEE32)
|
|
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
|
|
%5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
%6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// function_ref pullback of f(_:)
|
|
%7 = function_ref @$pullback_f : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
|
|
%8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
|
|
%9 = tuple (%4 : $Float, %8 : $@callee_guaranteed (Float) -> Float)
|
|
return %9 : $(Float, @callee_guaranteed (Float) -> Float)
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////
|
|
// Single closure call site where closure is passed as @guaranteed //
|
|
/////////////////////////////////////////////////////////////////////
|
|
// Pullback for the @guaranteed test case. It has the same computation as
// @$pullback_f (apply the closure to %0, fadd the two results), but the
// closure parameter is @guaranteed and is not released in this body.
sil private @$pullback_k : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float {
|
|
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> (Float, Float)):
|
|
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> (Float, Float)
|
|
%3 = tuple_extract %2 : $(Float, Float), 0
|
|
%4 = tuple_extract %2 : $(Float, Float), 1
|
|
%5 = struct_extract %4 : $Float, #Float._value
|
|
%6 = struct_extract %3 : $Float, #Float._value
|
|
%7 = builtin "fadd_FPIEEE32"(%5 : $Builtin.FPIEEE32, %6 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
%8 = struct $Float (%7 : $Builtin.FPIEEE32)
|
|
return %8 : $Float
|
|
} // end sil function '$pullback_k'
|
|
|
|
// reverse-mode derivative of k(_:)
//
// VJP under test for the @guaranteed case. The body mirrors the VJP of f(_:)
// but targets @$pullback_k, whose closure parameter is @guaranteed; this
// function therefore releases its closure %6 itself after the partial_apply.
// The expected specialized pullback ends with release_value on the re-created
// closure (see the directives below).
|
|
sil hidden @$s4test1kyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
bb0(%0 : $Float):
|
|
//=========== Test callsite and closure gathering logic ===========//
|
|
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
|
|
// TRUNNER-LABEL: Specializing closures in function: $s4test1kyS2fFTJrSpSr
|
|
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#B1:]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
|
|
// TRUNNER: Passed in closures:
|
|
// TRUNNER: 1. %[[#B1]] = partial_apply [callee_guaranteed] %[[#]](%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER-EMPTY:
|
|
|
|
//=========== Test specialized function signature and body ===========//
|
|
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
|
|
// TRUNNER-LABEL: Generated specialized function: $s11$pullback_k12$vjpMultiplyS2fTf1nc_n
|
|
// CHECK: sil private @$s11$pullback_k12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float {
|
|
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float):
|
|
// CHECK: %[[#B2:]] = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER: %[[#B3:]] = partial_apply [callee_guaranteed] %[[#B2]](%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER: %[[#]] = apply %[[#B3]](%0) : $@callee_guaranteed (Float) -> (Float, Float)
|
|
// COMBINE: %[[#]] = apply %[[#B2]](%0, %1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER: release_value %[[#B3]] : $@callee_guaranteed (Float) -> (Float, Float)
|
|
|
|
//=========== Test rewritten body ===========//
|
|
specify_test "autodiff_closure_specialize_rewritten_caller_body"
|
|
// TRUNNER-LABEL: Rewritten caller body for: $s4test1kyS2fFTJrSpSr
|
|
// CHECK: sil hidden @$s4test1kyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
// CHECK: bb0(%0 : $Float):
|
|
// CHECK: %[[#B4:]] = struct_extract %0 : $Float, #Float._value
|
|
// CHECK: %[[#B5:]] = builtin "fmul_FPIEEE32"(%[[#B4]] : $Builtin.FPIEEE32, %[[#B4]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
// CHECK: %[[#B6:]] = struct $Float (%[[#B5]] : $Builtin.FPIEEE32)
|
|
// COMBINE-NOT: function_ref @$vjpMultiply
|
|
// CHECK: %[[#B7:]] = function_ref @$s11$pullback_k12$vjpMultiplyS2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float
|
|
// CHECK: %[[#B8:]] = partial_apply [callee_guaranteed] %[[#B7]](%0, %0) : $@convention(thin) (Float, Float, Float) -> Float
|
|
// CHECK: %[[#B9:]] = tuple (%[[#B6]] : $Float, %[[#B8]] : $@callee_guaranteed (Float) -> Float)
|
|
// CHECK: return %[[#B9]]
|
|
|
|
// Input body of the VJP (the instructions the expectations above refer to).
%2 = struct_extract %0 : $Float, #Float._value
|
|
%3 = builtin "fmul_FPIEEE32"(%2 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
%4 = struct $Float (%3 : $Builtin.FPIEEE32)
|
|
// function_ref $vjpMultiply
|
|
%5 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
%6 = partial_apply [callee_guaranteed] %5(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// function_ref $pullback_k
|
|
%7 = function_ref @$pullback_k : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
|
|
%8 = partial_apply [callee_guaranteed] %7(%6) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> Float
|
|
// The pullback takes the closure @guaranteed, so this caller drops its own
// reference to %6 once the partial_apply has been formed.
strong_release %6 : $@callee_guaranteed (Float) -> (Float, Float)
|
|
%10 = tuple (%4 : $Float, %8 : $@callee_guaranteed (Float) -> Float)
|
|
return %10 : $(Float, @callee_guaranteed (Float) -> Float)
|
|
} // end sil function '$s4test1kyS2fFTJrSpSr'
|
|
|
|
///////////////////////////////
|
|
// Multiple closure callsite //
|
|
///////////////////////////////
|
|
// External declaration (no body in this file); captured with one Float in the
// VJP of g(_:) and passed as the first closure argument of @$pullback_g.
sil @$vjpSin : $@convention(thin) (Float, Float) -> Float
|
|
// External declaration (no body in this file); captured with one Float in the
// VJP of g(_:) and passed as the second closure argument of @$pullback_g.
sil @$vjpCos : $@convention(thin) (Float, Float) -> Float
|
|
|
|
// pullback of g(_:)
//
// Takes three @owned closures. %3 is applied to %0 first and yields a pair;
// %2 is applied to element 1 of that pair and %1 to element 0. Each closure
// is released right after its call, and the two Float results are combined
// with fadd.
|
|
sil private @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float {
|
|
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float, %2 : $@callee_guaranteed (Float) -> Float, %3 : $@callee_guaranteed (Float) -> (Float, Float)):
|
|
%4 = apply %3(%0) : $@callee_guaranteed (Float) -> (Float, Float)
|
|
strong_release %3 : $@callee_guaranteed (Float) -> (Float, Float)
|
|
%6 = tuple_extract %4 : $(Float, Float), 0
|
|
%7 = tuple_extract %4 : $(Float, Float), 1
|
|
%8 = apply %2(%7) : $@callee_guaranteed (Float) -> Float
|
|
strong_release %2 : $@callee_guaranteed (Float) -> Float
|
|
%10 = apply %1(%6) : $@callee_guaranteed (Float) -> Float
|
|
strong_release %1 : $@callee_guaranteed (Float) -> Float
|
|
%12 = struct_extract %8 : $Float, #Float._value
|
|
%13 = struct_extract %10 : $Float, #Float._value
|
|
%14 = builtin "fadd_FPIEEE32"(%13 : $Builtin.FPIEEE32, %12 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
%15 = struct $Float (%14 : $Builtin.FPIEEE32)
|
|
return %15 : $Float
|
|
}
|
|
|
|
// reverse-mode derivative of g(_:)
//
// VJP under test with three closures passed to a single pullback
// partial_apply. The body computes the sin and cos builtins of %0's value and
// their fmul product, and builds closures over @$vjpSin (capturing %0),
// @$vjpCos (capturing %0) and @$vjpMultiply (capturing the cos and sin
// structs). All three are passed to @$pullback_g. The expected specialized
// pullback takes the four captured Floats as extra parameters.
|
|
sil hidden @$s4test1gyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
bb0(%0 : $Float):
|
|
//=========== Test callsite and closure gathering logic ===========//
|
|
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
|
|
// TRUNNER-LABEL: Specializing closures in function: $s4test1gyS2fFTJrSpSr
|
|
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#C1:]], %[[#C2:]], %[[#C3:]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
|
|
// TRUNNER: Passed in closures:
|
|
// TRUNNER: 1. %[[#C1]] = partial_apply [callee_guaranteed] %[[#]](%0) : $@convention(thin) (Float, Float) -> Float
|
|
// TRUNNER: 2. %[[#C2]] = partial_apply [callee_guaranteed] %[[#]](%0) : $@convention(thin) (Float, Float) -> Float
|
|
// TRUNNER: 3. %[[#C3]] = partial_apply [callee_guaranteed] %[[#]](%[[#]], %[[#]]) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER-EMPTY:
|
|
|
|
//=========== Test specialized function signature and body ===========//
|
|
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
|
|
// TRUNNER-LABEL: Generated specialized function: $s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n
|
|
// CHECK: sil private @$s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float {
|
|
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float, %3 : $Float, %4 : $Float):
|
|
// CHECK: %[[#C4:]] = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float
|
|
// TRUNNER: %[[#C5:]] = partial_apply [callee_guaranteed] %[[#C4]](%1) : $@convention(thin) (Float, Float) -> Float
|
|
// CHECK: %[[#C6:]] = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float
|
|
// TRUNNER: %[[#C7:]] = partial_apply [callee_guaranteed] %[[#C6]](%2) : $@convention(thin) (Float, Float) -> Float
|
|
// CHECK: %[[#C8:]] = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER: %[[#C9:]] = partial_apply [callee_guaranteed] %[[#C8]](%3, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER: %[[#]] = apply %[[#C9]](%0) : $@callee_guaranteed (Float) -> (Float, Float)
|
|
// COMBINE: %[[#]] = apply %[[#C8]](%0, %3, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER: strong_release %[[#C9]] : $@callee_guaranteed (Float) -> (Float, Float)
|
|
// TRUNNER: %[[#]] = apply %[[#C7]](%[[#]]) : $@callee_guaranteed (Float) -> Float
|
|
// COMBINE: %[[#]] = apply %[[#C6]](%[[#]], %2) : $@convention(thin) (Float, Float) -> Float
|
|
// TRUNNER: strong_release %[[#C7]] : $@callee_guaranteed (Float) -> Float
|
|
// TRUNNER: %[[#]] = apply %[[#C5]](%[[#]]) : $@callee_guaranteed (Float) -> Float
|
|
// COMBINE: %[[#]] = apply %[[#C4]](%[[#]], %1) : $@convention(thin) (Float, Float) -> Float
|
|
// TRUNNER: strong_release %[[#C5]] : $@callee_guaranteed (Float) -> Float
|
|
|
|
//=========== Test rewritten body ===========//
|
|
specify_test "autodiff_closure_specialize_rewritten_caller_body"
|
|
// TRUNNER-LABEL: Rewritten caller body for: $s4test1gyS2fFTJrSpSr
|
|
// CHECK: sil hidden @$s4test1gyS2fFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
// CHECK: bb0(%0 : $Float):
|
|
// CHECK: %[[#C10:]] = struct_extract %0 : $Float, #Float._value
|
|
// CHECK: %[[#C11:]] = builtin "int_sin_FPIEEE32"(%[[#C10]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
// CHECK: %[[#C12:]] = struct $Float (%[[#C11]] : $Builtin.FPIEEE32)
|
|
// COMBINE-NOT: function_ref @$vjpSin
|
|
// CHECK: %[[#C13:]] = builtin "int_cos_FPIEEE32"(%[[#C10]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
// COMBINE-NOT: function_ref @$vjpCos
|
|
// COMBINE-NOT: function_ref @$vjpMultiply
|
|
// CHECK: %[[#C14:]] = struct $Float (%[[#C13]] : $Builtin.FPIEEE32)
|
|
// CHECK: %[[#C15:]] = function_ref @$s11$pullback_g7$vjpSinSf0B3CosSf0B8MultiplyS2fTf1nccc_n : $@convention(thin) (Float, Float, Float, Float, Float) -> Float
|
|
// CHECK: %[[#C16:]] = partial_apply [callee_guaranteed] %[[#C15]](%0, %0, %[[#C14]], %[[#C12]]) : $@convention(thin) (Float, Float, Float, Float, Float) -> Float
|
|
// CHECK: %[[#C17:]] = tuple (%[[#]] : $Float, %[[#C16]] : $@callee_guaranteed (Float) -> Float)
|
|
// CHECK: return %[[#C17]]
|
|
|
|
// Input body of the VJP (the instructions the expectations above refer to).
%2 = struct_extract %0 : $Float, #Float._value
|
|
%3 = builtin "int_sin_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
%4 = struct $Float (%3 : $Builtin.FPIEEE32)
|
|
// function_ref closure #1 in _vjpSin(_:)
|
|
%5 = function_ref @$vjpSin : $@convention(thin) (Float, Float) -> Float
|
|
%6 = partial_apply [callee_guaranteed] %5(%0) : $@convention(thin) (Float, Float) -> Float
|
|
%7 = builtin "int_cos_FPIEEE32"(%2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
%8 = struct $Float (%7 : $Builtin.FPIEEE32)
|
|
// function_ref closure #1 in _vjpCos(_:)
|
|
%9 = function_ref @$vjpCos : $@convention(thin) (Float, Float) -> Float
|
|
%10 = partial_apply [callee_guaranteed] %9(%0) : $@convention(thin) (Float, Float) -> Float
|
|
%11 = builtin "fmul_FPIEEE32"(%3 : $Builtin.FPIEEE32, %7 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
%12 = struct $Float (%11 : $Builtin.FPIEEE32)
|
|
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
|
|
%13 = function_ref @$vjpMultiply : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// Captures the cos struct (%8) first and the sin struct (%4) second.
%14 = partial_apply [callee_guaranteed] %13(%8, %4) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// function_ref pullback of g(_:)
|
|
%15 = function_ref @$pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
|
|
%16 = partial_apply [callee_guaranteed] %15(%6, %10, %14) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> Float, @owned @callee_guaranteed (Float) -> (Float, Float)) -> Float
|
|
%17 = tuple (%12 : $Float, %16 : $@callee_guaranteed (Float) -> Float)
|
|
return %17 : $(Float, @callee_guaranteed (Float) -> Float)
|
|
}
|
|
|
|
///////////////////////////////
|
|
/// Parameter subset thunks ///
|
|
///////////////////////////////
|
|
// Differentiable struct with a Float field `a` and a Double field `b`. Its
// nested TangentVector declares the same two stored fields. The functions
// below use X as the VJP argument type and X.TangentVector as the pullback
// result type.
struct X : Differentiable {
|
|
@_hasStorage var a: Float { get set }
|
|
@_hasStorage var b: Double { get set }
|
|
struct TangentVector : AdditiveArithmetic, Differentiable {
|
|
@_hasStorage var a: Float { get set }
|
|
@_hasStorage var b: Double { get set }
|
|
static func + (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector
|
|
static func - (lhs: X.TangentVector, rhs: X.TangentVector) -> X.TangentVector
|
|
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X.TangentVector, _ b: X.TangentVector) -> Bool
|
|
typealias TangentVector = X.TangentVector
|
|
init(a: Float, b: Double)
|
|
static var zero: X.TangentVector { get }
|
|
}
|
|
init(a: Float, b: Double)
|
|
mutating func move(by offset: X.TangentVector)
|
|
}
|
|
|
|
// External declaration (no body in this file) of a thunk that takes a Float
// and a @guaranteed (Float, Double) -> X.TangentVector closure. The VJP below
// partially applies it over the thin_to_thick_function of @pullback_f.
sil [transparent] [thunk] @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
|
|
|
|
// External declaration (no body in this file). Note this is @pullback_f
// without a leading `$`, a different symbol from @$pullback_f defined earlier.
sil @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector
|
|
|
|
// Pullback for the parameter-subset-thunk test case: applies the @owned
// closure %1 to %0, releases the closure, and returns the result unchanged.
// Note this is @pullback_g without a leading `$`, a different symbol from
// @$pullback_g defined earlier.
sil shared @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector {
|
|
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> X.TangentVector):
|
|
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> X.TangentVector
|
|
strong_release %1 : $@callee_guaranteed (Float) -> X.TangentVector
|
|
return %2 : $X.TangentVector
|
|
}
|
|
|
|
// VJP under test for the parameter-subset-thunk case. The body extracts X.a,
// converts @pullback_f to a thick function, wraps it with a partial_apply of
// @subset_parameter_thunk, and passes that closure to @pullback_g. The
// expected closure info lists the thin_to_thick_function as the passed-in
// closure, and the expected specialized pullback takes only the Float
// argument (there are no captured values to add as parameters).
sil hidden @$s5test21g1xSfAA1XV_tFTJrSpSr : $@convention(thin) (X) -> (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) {
|
|
bb0(%0 : $X):
|
|
//=========== Test callsite and closure gathering logic ===========//
|
|
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
|
|
// TRUNNER-LABEL: Specializing closures in function: $s5test21g1xSfAA1XV_tFTJrSpSr
|
|
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector
|
|
// TRUNNER: Passed in closures:
|
|
// TRUNNER: 1. %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector
|
|
// TRUNNER-EMPTY:
|
|
|
|
//=========== Test specialized function signature and body ===========//
|
|
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
|
|
// TRUNNER-LABEL: Generated specialized function: $s10pullback_g0A2_fTf1nc_n
|
|
// CHECK: sil shared @$s10pullback_g0A2_fTf1nc_n : $@convention(thin) (Float) -> X.TangentVector {
|
|
// CHECK: bb0(%0 : $Float):
|
|
// CHECK: %[[#D1:]] = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector
|
|
// CHECK: %[[#D2:]] = thin_to_thick_function %[[#D1]] : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector
|
|
// CHECK: %[[#D3:]] = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
|
|
// TRUNNER: %[[#D4:]] = partial_apply [callee_guaranteed] %[[#D3]](%[[#D2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
|
|
// TRUNNER: %[[#D5:]] = apply %[[#D4]](%0) : $@callee_guaranteed (Float) -> X.TangentVector
|
|
// COMBINE: %[[#D5:]] = apply %[[#D3]](%0, %[[#D2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
|
|
// TRUNNER: strong_release %[[#D4]] : $@callee_guaranteed (Float) -> X.TangentVector
|
|
// CHECK: return %[[#D5]]
|
|
|
|
//=========== Test rewritten body ===========//
|
|
specify_test "autodiff_closure_specialize_rewritten_caller_body"
|
|
// TRUNNER-LABEL: Rewritten caller body for: $s5test21g1xSfAA1XV_tFTJrSpSr
|
|
// CHECK: sil hidden @$s5test21g1xSfAA1XV_tFTJrSpSr : $@convention(thin) (X) -> (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) {
|
|
// CHECK: bb0(%0 : $X):
|
|
// CHECK: %[[#D6:]] = struct_extract %0 : $X, #X.a
|
|
// COMBINE-NOT: function_ref @pullback_f
|
|
// COMBINE-NOT: function_ref @subset_parameter_thunk
|
|
// CHECK: %[[#D7:]] = function_ref @$s10pullback_g0A2_fTf1nc_n : $@convention(thin) (Float) -> X.TangentVector
|
|
// TRUNNER: %[[#D8:]] = partial_apply [callee_guaranteed] %[[#D7]]() : $@convention(thin) (Float) -> X.TangentVector
|
|
// COMBINE: %[[#D8:]] = thin_to_thick_function %[[#D7]] : $@convention(thin) (Float) -> X.TangentVector to $@callee_guaranteed (Float) -> X.TangentVector
|
|
// CHECK: %[[#D9:]] = tuple (%[[#D6]] : $Float, %[[#D8]] : $@callee_guaranteed (Float) -> X.TangentVector)
|
|
// CHECK: return %[[#D9]]
|
|
|
|
// Input body of the VJP (the instructions the expectations above refer to).
%1 = struct_extract %0 : $X, #X.a
|
|
// function_ref pullback_f
|
|
%2 = function_ref @pullback_f : $@convention(thin) (Float, Double) -> X.TangentVector
|
|
%3 = thin_to_thick_function %2 : $@convention(thin) (Float, Double) -> X.TangentVector to $@callee_guaranteed (Float, Double) -> X.TangentVector
|
|
// function_ref subset_parameter_thunk
|
|
%4 = function_ref @subset_parameter_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
|
|
%5 = partial_apply [callee_guaranteed] %4(%3) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (Float, Double) -> X.TangentVector) -> X.TangentVector
|
|
// function_ref pullback_g
|
|
%6 = function_ref @pullback_g : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector
|
|
%7 = partial_apply [callee_guaranteed] %6(%5) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> X.TangentVector) -> X.TangentVector
|
|
%8 = tuple (%1 : $Float, %7 : $@callee_guaranteed (Float) -> X.TangentVector)
|
|
return %8 : $(Float, @callee_guaranteed (Float) -> X.TangentVector)
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////
|
|
///////// Specialized generic closures - PartialApply Closure /////////
|
|
///////////////////////////////////////////////////////////////////////
|
|
|
|
// closure #1 in static Float._vjpMultiply(lhs:rhs:)
// External declaration (no body in this file); it is the innermost closure in
// the chain of partial_applies built by the VJP of h(x:).
|
|
sil @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
|
|
// thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
// External declaration (no body in this file). Per its signature it takes the
// Float @in_guaranteed plus a @guaranteed closure and returns two @out Floats.
|
|
sil [transparent] [reabstraction_thunk] @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
|
|
|
|
// specialized pullback of f<A>(a:)
// External declaration (no body in this file). Its closure parameter is
// @owned and uses a @substituted function type instantiated for
// <Float, Float, Float>.
|
|
sil [transparent] [thunk] @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
|
|
|
|
// thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float)
// External declaration (no body in this file). Per its signature it takes a
// direct Float plus a @guaranteed indirect-convention closure and returns a
// direct Float.
|
|
sil [transparent] [reabstraction_thunk] @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
|
|
// Pullback for the specialized-generic-closure test case: applies the @owned
// closure %1 to %0, releases the closure, and returns the Float result.
// The [signature_optimized_thunk] [always_inline] attributes are also present
// on the expected specialized function in the directives of the VJP below.
sil private [signature_optimized_thunk] [always_inline] @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float {
|
|
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float):
|
|
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> Float
|
|
strong_release %1 : $@callee_guaranteed (Float) -> Float
|
|
return %2 : $Float
|
|
}
|
|
|
|
// reverse-mode derivative of h(x:)
|
|
sil hidden @$s5test21h1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
bb0(%0 : $Float):
|
|
//=========== Test callsite and closure gathering logic ===========//
|
|
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
|
|
// TRUNNER-LABEL: Specializing closures in function: $s5test21h1xS2f_tFTJrSpSr
|
|
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
|
|
// TRUNNER: Passed in closures:
|
|
// TRUNNER: 1. %[[#]] = partial_apply [callee_guaranteed] %[[#]](%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// TRUNNER-EMPTY:
|
|
|
|
//=========== Test specialized function signature and body ===========//
|
|
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
|
|
// TRUNNER-LABEL: Generated specialized function: $s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n
|
|
// CHECK: sil private [signature_optimized_thunk] [always_inline] @$s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float {
|
|
// CHECK: bb0(%0 : $Float, %1 : $Float, %2 : $Float):
|
|
// CHECK: %[[#E1:]] = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// CHECK: %[[#E2:]] = partial_apply [callee_guaranteed] %[[#E1]](%1, %2) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
// CHECK: %[[#E3:]] = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
|
|
// CHECK: %[[#E4:]] = partial_apply [callee_guaranteed] %[[#E3]](%[[#E2]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
|
|
// CHECK: %[[#E5:]] = convert_function %[[#E4]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>
|
|
// CHECK: %[[#E6:]] = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
|
|
// CHECK: %[[#E7:]] = partial_apply [callee_guaranteed] %[[#E6]](%[[#E5]]) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
|
|
// CHECK: %[[#E8:]] = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
// TRUNNER: %[[#E9:]] = partial_apply [callee_guaranteed] %[[#E8]](%[[#E7]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
// TRUNNER: %[[#E10:]] = apply %[[#E9]](%0) : $@callee_guaranteed (Float) -> Float
|
|
// COMBINE: %[[#E10:]] = apply %[[#E8]](%0, %[[#E7]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
// TRUNNER: strong_release %[[#E9]] : $@callee_guaranteed (Float) -> Float
|
|
// COMBINE: strong_release %[[#E7]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float
|
|
// CHECK: return %[[#E10]]
|
|
|
|
//=========== Test rewritten body ===========//
|
|
specify_test "autodiff_closure_specialize_rewritten_caller_body"
|
|
// TRUNNER-LABEL: Rewritten caller body for: $s5test21h1xS2f_tFTJrSpSr
|
|
// CHECK: sil hidden @$s5test21h1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
// CHECK: bb0(%0 : $Float):
|
|
// CHECK: %[[#E11:]] = struct_extract %0 : $Float, #Float._value
|
|
// CHECK: %[[#E12:]] = builtin "fmul_FPIEEE32"(%[[#E11]] : $Builtin.FPIEEE32, %[[#E11]] : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
// COMBINE-NOT: function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_
|
|
// COMBINE-NOT: function_ref @$sS3fIegydd_S3fIegnrr_TR
|
|
// COMBINE-NOT: function_ref @pullback_f_specialized
|
|
// COMBINE-NOT: function_ref @$sS2fIegnr_S2fIegyd_TR
|
|
// CHECK: %[[#E13:]] = struct $Float (%[[#E12]] : $Builtin.FPIEEE32)
|
|
// CHECK: %[[#E14:]] = function_ref @$s10pullback_h073$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbackti1_j5FZSf_J6SfcfU_S2fTf1nc_n : $@convention(thin) (Float, Float, Float) -> Float
|
|
// CHECK: %[[#E15:]] = partial_apply [callee_guaranteed] %[[#E14]](%0, %0) : $@convention(thin) (Float, Float, Float) -> Float
|
|
// CHECK: %[[#E16:]] = tuple (%[[#E13]] : $Float, %[[#E15]] : $@callee_guaranteed (Float) -> Float)
|
|
// CHECK: return %[[#E16]]
|
|
|
|
%1 = struct_extract %0 : $Float, #Float._value
|
|
%2 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
|
|
|
|
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
|
|
%3 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
%4 = partial_apply [callee_guaranteed] %3(%0, %0) : $@convention(thin) (Float, Float, Float) -> (Float, Float)
|
|
|
|
// function_ref thunk for @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float, @unowned Float)
|
|
%5 = function_ref @$sS3fIegydd_S3fIegnrr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
|
|
%6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, Float)) -> (@out Float, @out Float)
|
|
%7 = convert_function %6 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @out Float) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>
|
|
|
|
// function_ref pullback_f_specialized
|
|
%8 = function_ref @pullback_f_specialized : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
|
|
%9 = partial_apply [callee_guaranteed] %8(%7) : $@convention(thin) (@in_guaranteed Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @out τ_0_2) for <Float, Float, Float>) -> @out Float
|
|
|
|
// function_ref thunk for @escaping @callee_guaranteed (@in_guaranteed Float) -> (@out Float)
|
|
%10 = function_ref @$sS2fIegnr_S2fIegyd_TR : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
%11 = partial_apply [callee_guaranteed] %10(%9) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
%12 = struct $Float (%2 : $Builtin.FPIEEE32)
|
|
|
|
// function_ref pullback_h
|
|
%13 = function_ref @pullback_h : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
|
|
%14 = partial_apply [callee_guaranteed] %13(%11) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
|
|
%15 = tuple (%12 : $Float, %14 : $@callee_guaranteed (Float) -> Float)
|
|
return %15 : $(Float, @callee_guaranteed (Float) -> Float)
|
|
}
|
|
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
///////// Specialized generic closures - ThinToThickFunction closure /////////
|
|
//////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Bodiless declaration of the function that the VJP test below references via
// function_ref and wraps with thin_to_thick_function. It takes a Float
// @in_guaranteed and produces a Float through an @out result.
sil [transparent] [thunk] @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float
|
|
|
|
// Bodiless declaration marked [reabstraction_thunk]. Its signature takes a
// direct Float plus a @guaranteed closure of the indirect form
// (@in_guaranteed Float) -> @out Float and returns a direct Float. The VJP
// test below partially applies it over the thick closure it builds.
sil [transparent] [reabstraction_thunk] @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
|
|
// Pullback partially applied by the VJP test below. It calls the closure %1
// on %0, releases %1 (the closure is received @owned), and returns the result
// of the call.
sil private [signature_optimized_thunk] [always_inline] @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float {
|
|
bb0(%0 : $Float, %1 : $@callee_guaranteed (Float) -> Float):
|
|
// Invoke the closure argument on the incoming value.
%2 = apply %1(%0) : $@callee_guaranteed (Float) -> Float
|
|
// Balance the @owned convention of the closure parameter.
strong_release %1 : $@callee_guaranteed (Float) -> Float
|
|
return %2 : $Float
|
|
}
|
|
|
|
// VJP exercising the case where the closure reaching the pullback starts as a
// thin_to_thick_function rather than a partial_apply. The body wraps
// @pullback_y_specialized as a thick closure, partially applies
// @reabstraction_thunk over it, and partially applies @pullback_z over the
// result; the returned tuple pairs %0 with that pullback closure.
//
// The three specify_test directives in the body each drive one group of
// expectations: the gathered closure info, the generated specialized
// function, and the rewritten body of this VJP. The expectations require the
// specialized function to rebuild the closure chain internally, and the
// rewritten VJP to reference the specialized function with an empty capture
// list.
sil hidden @$s5test21z1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
bb0(%0 : $Float):
|
|
//=========== Test callsite and closure gathering logic ===========//
|
|
specify_test "autodiff_closure_specialize_get_pullback_closure_info"
|
|
// TRUNNER-LABEL: Specializing closures in function: $s5test21z1xS2f_tFTJrSpSr
|
|
// TRUNNER: PartialApply of pullback: %[[#]] = partial_apply [callee_guaranteed] %[[#]](%[[#]]) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
|
|
// TRUNNER: Passed in closures:
|
|
// TRUNNER: 1. %[[#]] = thin_to_thick_function %[[#]] : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float
|
|
// TRUNNER-EMPTY:
|
|
|
|
//=========== Test specialized function signature and body ===========//
|
|
specify_test "autodiff_closure_specialize_specialized_function_signature_and_body"
|
|
// TRUNNER-LABEL: Generated specialized function: $s10pullback_z0A14_y_specializedTf1nc_n
|
|
// CHECK: sil private [signature_optimized_thunk] [always_inline] @$s10pullback_z0A14_y_specializedTf1nc_n : $@convention(thin) (Float) -> Float {
|
|
// CHECK: bb0(%0 : $Float):
|
|
// CHECK: %[[#F1:]] = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float
|
|
// CHECK: %[[#F2:]] = thin_to_thick_function %[[#F1]] : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float
|
|
// CHECK: %[[#F3:]] = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
// TRUNNER: %[[#F4:]] = partial_apply [callee_guaranteed] %[[#F3]](%[[#F2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
// TRUNNER: %[[#F5:]] = apply %[[#F4]](%0) : $@callee_guaranteed (Float) -> Float
|
|
// COMBINE: %[[#F5:]] = apply %[[#F3]](%0, %[[#F2]]) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
// TRUNNER: strong_release %[[#F4]] : $@callee_guaranteed (Float) -> Float
|
|
// CHECK: return %[[#F5]]
|
|
|
|
//=========== Test rewritten body ===========//
|
|
specify_test "autodiff_closure_specialize_rewritten_caller_body"
|
|
// TRUNNER-LABEL: Rewritten caller body for: $s5test21z1xS2f_tFTJrSpSr
|
|
// CHECK: sil hidden @$s5test21z1xS2f_tFTJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
|
|
// CHECK: bb0(%0 : $Float):
|
|
// COMBINE-NOT: function_ref @pullback_y_specialized
|
|
// COMBINE-NOT: function_ref @reabstraction_thunk
|
|
// CHECK: %[[#F6:]] = function_ref @$s10pullback_z0A14_y_specializedTf1nc_n : $@convention(thin) (Float) -> Float
|
|
// TRUNNER: %[[#F7:]] = partial_apply [callee_guaranteed] %[[#F6]]() : $@convention(thin) (Float) -> Float
|
|
// COMBINE: %[[#F7:]] = thin_to_thick_function %[[#F6]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
|
|
// CHECK: %[[#F8:]] = tuple (%0 : $Float, %[[#F7]] : $@callee_guaranteed (Float) -> Float)
|
|
// CHECK: return %[[#F8]]
|
|
|
|
// function_ref pullback_y_specialized
|
|
%1 = function_ref @pullback_y_specialized : $@convention(thin) (@in_guaranteed Float) -> @out Float
|
|
// Thick closure built from the thin function; the closure-info expectations
// above list a thin_to_thick_function of this type as the passed-in closure.
%2 = thin_to_thick_function %1 : $@convention(thin) (@in_guaranteed Float) -> @out Float to $@callee_guaranteed (@in_guaranteed Float) -> @out Float
|
|
// function_ref reabstraction_thunk
|
|
%3 = function_ref @reabstraction_thunk : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
// Binds the thick closure as the thunk's closure argument, yielding a
// (Float) -> Float closure.
%4 = partial_apply [callee_guaranteed] %3(%2) : $@convention(thin) (Float, @guaranteed @callee_guaranteed (@in_guaranteed Float) -> @out Float) -> Float
|
|
// function_ref pullback_z
|
|
%5 = function_ref @pullback_z : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
|
|
// Partial apply of the pullback; its type matches the instruction that the
// closure-info expectations above report as the pullback partial_apply.
%6 = partial_apply [callee_guaranteed] %5(%4) : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> Float) -> Float
|
|
%7 = tuple (%0 : $Float, %6 : $@callee_guaranteed (Float) -> Float)
|
|
return %7 : $(Float, @callee_guaranteed (Float) -> Float)
|
|
}
|