// Pullback generation tests written in SIL for features
// that may not be directly supported by the Swift frontend

// RUN: %target-sil-opt -sil-print-types --differentiation -emit-sorted-sil %s 2>&1 | %FileCheck %s

//===----------------------------------------------------------------------===//
// Pullback generation - `struct_extract`
// - Input to pullback has non-owned ownership semantics which requires copying
//   this value to stack before lifetime-ending uses.
//===----------------------------------------------------------------------===//

sil_stage raw

import Builtin
import Swift
import SwiftShims

import _Differentiation

// `X` has a trivial `Float` field and a non-trivial `String` field, and is its
// own tangent vector type.
struct X {
  @_hasStorage var a: Float { get set }
  @_hasStorage var b: String { get set }
  init(a: Float, b: String)
}

extension X : Differentiable, Equatable, AdditiveArithmetic {
  public typealias TangentVector = X
  mutating func move(by offset: X)
  public static var zero: X { get }
  public static func + (lhs: X, rhs: X) -> X
  public static func - (lhs: X, rhs: X) -> X
  @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X, _ b: X) -> Bool
}

// `Y` wraps an `X` plus a `String`, and is its own tangent vector type.
struct Y {
  @_hasStorage var a: X { get set }
  @_hasStorage var b: String { get set }
  init(a: X, b: String)
}

extension Y : Differentiable, Equatable, AdditiveArithmetic {
  public typealias TangentVector = Y
  mutating func move(by offset: Y)
  public static var zero: Y { get }
  public static func + (lhs: Y, rhs: Y) -> Y
  public static func - (lhs: Y, rhs: Y) -> Y
  @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Y, _ b: Y) -> Bool
}

// Differentiability witness ([reverse], parameter 0, result 0) for the
// function under test.
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
}

// Function under test: projects field `a` out of a guaranteed `Y` with
// `struct_extract` and returns an owned copy of it.
sil hidden [ossa] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
bb0(%0 : @guaranteed $Y):
  %1 = struct_extract %0 : $Y, #Y.a
  %2 = copy_value %1 : $X
  return %2 : $X
}

// Expected pullback for the function above.
// CHECK-LABEL: sil private [ossa] @$function_with_struct_extract_1TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned Y {
// CHECK: bb0(%0 : @guaranteed $X):
// CHECK: %1 = alloc_stack $Y
// CHECK: %2 = witness_method $Y, #AdditiveArithmetic.zero!getter : (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %3 = metatype $@thick Y.Type
// CHECK: %4 = apply %2(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %5 = struct_element_addr %1 : $*Y, #Y.a

// Since input parameter %0 has non-owned ownership semantics, it
// needs to be copied before a lifetime-ending use.
// CHECK: %6 = copy_value %0 : $X
// CHECK: %7 = alloc_stack $X
// CHECK: store %6 to [init] %7 : $*X
// CHECK: %9 = witness_method $X, #AdditiveArithmetic."+=" : (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %10 = metatype $@thick X.Type
// CHECK: %11 = apply %9(%5, %7, %10) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %7 : $*X
// CHECK: dealloc_stack %7 : $*X
// CHECK: %14 = load [take] %1 : $*Y
// CHECK: dealloc_stack %1 : $*Y
// CHECK: %16 = copy_value %14 : $Y
// CHECK: destroy_value %14 : $Y
// CHECK: return %16 : $Y
// CHECK: } // end sil function '$function_with_struct_extract_1TJpSpSr'

//===----------------------------------------------------------------------===//
// Pullback generation - `tuple_extract`
// - Tuples as differentiable input arguments are not supported yet, so creating
//   a basic test in SIL instead.
//===----------------------------------------------------------------------===//

// Differentiability witness ([reverse], parameter 0, result 0) for the
// function under test.
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
}

// Function under test: returns element 0 of a `(Float, Float)` tuple via
// `tuple_extract`.
sil hidden [ossa] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
bb0(%0 : $(Float, Float)):
  %1 = tuple_extract %0 : $(Float, Float), 0
  return %1 : $Float
}

// Expected pullback for the function above.
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_1TJpSpSr : $@convention(thin) (Float) -> (Float, Float) {
// CHECK: bb0(%0 : $Float):
// CHECK: %1 = alloc_stack $(Float, Float)
// CHECK: %2 = tuple_element_addr %1 : $*(Float, Float), 0
// CHECK: %3 = witness_method $Float, #AdditiveArithmetic.zero!getter : (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %4 = metatype $@thick Float.Type
// CHECK: %5 = apply %3(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %6 = tuple_element_addr %1 : $*(Float, Float), 1
// CHECK: %7 = witness_method $Float, #AdditiveArithmetic.zero!getter : (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %8 = metatype $@thick Float.Type
// CHECK: %9 = apply %7(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %10 = tuple_element_addr %1 : $*(Float, Float), 0
// CHECK: %11 = alloc_stack $Float
// CHECK: store %0 to [trivial] %11 : $*Float
// CHECK: %13 = witness_method $Float, #AdditiveArithmetic."+=" : (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %14 = metatype $@thick Float.Type
// CHECK: %15 = apply %13(%10, %11, %14) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %11 : $*Float
// CHECK: dealloc_stack %11 : $*Float
// CHECK: %18 = load [trivial] %1 : $*(Float, Float)
// CHECK: dealloc_stack %1 : $*(Float, Float)
// CHECK: return %18 : $(Float, Float)
// CHECK: } // end sil function 'function_with_tuple_extract_1TJpSpSr'

//===----------------------------------------------------------------------===//
// Pullback generation - Inner values of concrete adjoints must be copied
// during direct materialization.
// - If the input to pullback BB has non-owned ownership semantics we cannot
//   perform a lifetime-ending operation on it.
// - If the input to the pullback BB is an owned, non-trivial value we must
//   copy it or there will be a double consume when all owned parameters are
//   destroyed at the end of the basic block.
//===----------------------------------------------------------------------===//

// Differentiability witness ([reverse], parameter 0, result 0) for the
// function under test.
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
}

// Function under test: extracts element 0 of a guaranteed `(X, X)` tuple and
// returns an owned copy of it.
sil hidden [ossa] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
bb0(%0 : @guaranteed $(X, X)):
  %1 = tuple_extract %0 : $(X, X), 0
  %2 = copy_value %1: $X
  return %2 : $X
}

// Expected pullback for the function above; the guaranteed input `%0` is
// copied (`%11`) before being stored to the stack.
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_2TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned (X, X) {
// CHECK: bb0(%0 : @guaranteed $X):
// CHECK: %1 = alloc_stack $(X, X)
// CHECK: %2 = tuple_element_addr %1 : $*(X, X), 0
// CHECK: %3 = witness_method $X, #AdditiveArithmetic.zero!getter : (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %4 = metatype $@thick X.Type
// CHECK: %5 = apply %3(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %6 = tuple_element_addr %1 : $*(X, X), 1
// CHECK: %7 = witness_method $X, #AdditiveArithmetic.zero!getter : (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %8 = metatype $@thick X.Type
// CHECK: %9 = apply %7(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %10 = tuple_element_addr %1 : $*(X, X), 0
// CHECK: %11 = copy_value %0 : $X
// CHECK: %12 = alloc_stack $X
// CHECK: store %11 to [init] %12 : $*X
// CHECK: %14 = witness_method $X, #AdditiveArithmetic."+=" : (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %15 = metatype $@thick X.Type
// CHECK: %16 = apply %14(%10, %12, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %12 : $*X
// CHECK: dealloc_stack %12 : $*X
// CHECK: %19 = load [take] %1 : $*(X, X)
// CHECK: dealloc_stack %1 : $*(X, X)
// CHECK: %21 = copy_value %19 : $(X, X)
// CHECK: destroy_value %19 : $(X, X)
// CHECK: return %21 : $(X, X)
// CHECK: } // end sil function 'function_with_tuple_extract_2TJpSpSr'

//===----------------------------------------------------------------------===//
// Pullback generation - `tuple_extract`
// - Adjoint of extracted element can be `AddElement`
// - Just need to make sure that we are able to generate a pullback
//===----------------------------------------------------------------------===//

// Differentiability witness ([reverse], parameter 0, result 0) for the
// function under test.
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
}

// Function under test: two chained `tuple_extract`s select element 0 of the
// inner tuple of `((Float, Float), Float)`.
sil hidden [ossa] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
bb0(%0 : $((Float, Float), Float)):
  %1 = tuple_extract %0 : $((Float, Float), Float), 0
  %2 = tuple_extract %1 : $(Float, Float), 0
  return %2 : $Float
}

// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_3TJpSpSr : $@convention(thin) (Float) -> ((Float, Float), Float) {