Files
swift-mirror/test/AutoDiff/SILOptimizer/derivative_sil.swift
Dan Zheng 2eb460de4d [AutoDiff upstream] Add forward-mode differentiation. (#30878)
JVP functions are forward-mode derivative functions. They take original
arguments and return original results and a differential function. Differential
functions take derivatives wrt arguments and return derivatives wrt results.

`JVPEmitter` is a cloner that emits JVP and differential functions at the same
time. In JVP functions, function applications are replaced with JVP function
applications. In differential functions, function applications are replaced
with differential function applications.

In JVP functions, each basic block takes a differential struct containing callee
differentials. These structs are consumed by differential functions.
2020-04-08 11:29:21 -07:00

91 lines
6.4 KiB
Swift

// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
// REQUIRES: asserts
// Simple generated derivative code FileCheck tests.
import _Differentiation
extension Float {
@_silgen_name("add")
static func add(_ x: Float, _ y: Float) -> Float {
return x + y
}
@derivative(of: add)
static func addVJP(_ x: Float, _ y: Float) -> (
value: Float, pullback: (Float) -> (Float, Float)
) {
return (add(x, y), { v in (v, v) })
}
}
@_silgen_name("foo")
@differentiable
func foo(_ x: Float) -> Float {
let y = Float.add(x, x)
return y
}
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK-SIL: bb0([[X:%.*]] : $Float):
// CHECK-SIL: [[ADD_ORIG_REF:%.*]] = function_ref @add : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK-SIL: [[ADD_JVP_REF:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @add
// CHECK-SIL: [[ADD_VJP_REF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @add
// CHECK-SIL: [[ADD_DIFF_FN:%.*]] = differentiable_function [parameters 0 1] [[ADD_ORIG_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[ADD_JVP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[ADD_VJP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
// CHECK-SIL: [[ADD_JVP_FN:%.*]] = differentiable_function_extract [jvp] [[ADD_DIFF_FN]]
// CHECK-SIL: end_borrow [[ADD_DIFF_FN]]
// CHECK-SIL: [[ADD_RESULT:%.*]] = apply [[ADD_JVP_FN]]([[X]], [[X]], {{.*}})
// CHECK-SIL: ([[ORIG_RES:%.*]], [[ADD_DF:%.*]]) = destructure_tuple [[ADD_RESULT]]
// CHECK-SIL: [[DF_STRUCT:%.*]] = struct $_AD__foo_bb0__DF__src_0_wrt_0 ([[ADD_DF]] : $@callee_guaranteed (Float, Float) -> Float)
// CHECK-SIL: [[DF_REF:%.*]] = function_ref @AD__foo__differential_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__foo_bb0__DF__src_0_wrt_0) -> Float
// CHECK-SIL: [[DF_FN:%.*]] = partial_apply [callee_guaranteed] [[DF_REF]]([[DF_STRUCT]])
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[DF_FN]] : $@callee_guaranteed (Float) -> Float)
// CHECK-SIL: return [[VJP_RESULT]] : $(Float, @callee_guaranteed (Float) -> Float)
// CHECK-SIL: }
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__foo__differential_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__foo_bb0__DF__src_0_wrt_0) -> Float {
// CHECK-SIL: bb0([[DX:%.*]] : $Float, [[DF_STRUCT:%.*]] : @owned $_AD__foo_bb0__DF__src_0_wrt_0):
// CHECK-SIL: [[ADD_DF:%.*]] = destructure_struct [[DF_STRUCT]] : $_AD__foo_bb0__DF__src_0_wrt_0
// CHECK-SIL: [[DY:%.*]] = apply [[ADD_DF]]([[DX]], [[DX]]) : $@callee_guaranteed (Float, Float) -> Float
// CHECK-SIL: destroy_value [[ADD_DF]] : $@callee_guaranteed (Float, Float) -> Float
// CHECK-SIL: return [[DY]] : $Float
// CHECK-SIL: }
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// CHECK-SIL: bb0([[X:%.*]] : $Float):
// CHECK-SIL: [[ADD_ORIG_REF:%.*]] = function_ref @add : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK-SIL: [[ADD_JVP_REF:%.*]] = differentiability_witness_function [jvp] [parameters 0 1] [results 0] @add
// CHECK-SIL: [[ADD_VJP_REF:%.*]] = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @add
// CHECK-SIL: [[ADD_DIFF_FN:%.*]] = differentiable_function [parameters 0 1] [[ADD_ORIG_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with_derivative {[[ADD_JVP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[ADD_VJP_REF]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))}
// CHECK-SIL: [[ADD_VJP_FN:%.*]] = differentiable_function_extract [vjp] [[ADD_DIFF_FN]]
// CHECK-SIL: end_borrow [[ADD_DIFF_FN]]
// CHECK-SIL: [[ADD_RESULT:%.*]] = apply [[ADD_VJP_FN]]([[X]], [[X]], {{.*}})
// CHECK-SIL: ([[ORIG_RES:%.*]], [[ADD_PB:%.*]]) = destructure_tuple [[ADD_RESULT]]
// CHECK-SIL: [[PB_STRUCT:%.*]] = struct $_AD__foo_bb0__PB__src_0_wrt_0 ([[ADD_PB]] : $@callee_guaranteed (Float) -> (Float, Float))
// CHECK-SIL: [[PB_REF:%.*]] = function_ref @AD__foo__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__foo_bb0__PB__src_0_wrt_0) -> Float
// CHECK-SIL: [[PB_FN:%.*]] = partial_apply [callee_guaranteed] [[PB_REF]]([[PB_STRUCT]])
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB_FN]] : $@callee_guaranteed (Float) -> Float)
// CHECK-SIL: return [[VJP_RESULT]] : $(Float, @callee_guaranteed (Float) -> Float)
// CHECK-SIL: }
// CHECK-SIL-LABEL: sil private [ossa] @AD__foo__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__foo_bb0__PB__src_0_wrt_0) -> Float {
// CHECK-SIL: bb0([[DY:%.*]] : $Float, [[PB_STRUCT:%.*]] : @owned $_AD__foo_bb0__PB__src_0_wrt_0):
// CHECK-SIL: [[ADD_PB:%.*]] = destructure_struct [[PB_STRUCT]] : $_AD__foo_bb0__PB__src_0_wrt_0
// CHECK-SIL: [[ADD_PB_RES:%.*]] = apply [[ADD_PB]]([[DY]]) : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK-SIL: ([[DX_1:%.*]], [[DX_2:%.*]]) = destructure_tuple [[ADD_PB_RES]] : $(Float, Float)
// CHECK-SIL: [[TMP_BUF_RES:%.*]] = alloc_stack $Float
// CHECK-SIL: [[TMP_BUF_LHS:%.*]] = alloc_stack $Float
// CHECK-SIL: [[TMP_BUF_RHS:%.*]] = alloc_stack $Float
// CHECK-SIL: store [[DX_1]] to [trivial] [[TMP_BUF_LHS]] : $*Float
// CHECK-SIL: store [[DX_2]] to [trivial] [[TMP_BUF_RHS]] : $*Float
// CHECK-SIL: [[PLUS_FN:%.*]] = witness_method $Float, #AdditiveArithmetic."+"
// CHECK-SIL: apply [[PLUS_FN]]<Float>([[TMP_BUF_RES]], [[TMP_BUF_RHS]], [[TMP_BUF_LHS]], {{.*}})
// CHECK-SIL: destroy_addr [[TMP_BUF_LHS]] : $*Float
// CHECK-SIL: destroy_addr [[TMP_BUF_RHS]] : $*Float
// CHECK-SIL: dealloc_stack [[TMP_BUF_RHS]] : $*Float
// CHECK-SIL: dealloc_stack [[TMP_BUF_LHS]] : $*Float
// CHECK-SIL: [[DX:%.*]] = load [trivial] [[TMP_BUF_RES]] : $*Float
// CHECK-SIL: dealloc_stack [[TMP_BUF_RES]] : $*Float
// CHECK-SIL: return [[DX]] : $Float
// CHECK-SIL: }