//===--- FloatingPointDifferentiation.swift.gyb ---------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import Swift
import SwiftShims

% from SwiftFloatingPointTypes import all_floating_point_types
% for self_type in all_floating_point_types():
%{
# Per-type template variables used by everything inside this loop.
Self = self_type.stdlib_name
bits = self_type.bits

def Availability(bits):
    # 16-bit types get an explicit `@available` attribute; every other width
    # gets an empty string, i.e. no attribute is emitted.
    return '@available(iOS 14.0, tvOS 14.0, watchOS 7.0, *)' if bits == 16 else ''
}%

% if bits == 80:
// The 80-bit type is only emitted for i386/x86_64, and never on Windows,
// Android, or Embedded builds that target neither Linux nor an Apple OS.
#if !(os(Windows) || os(Android) || ($Embedded && !os(Linux) && !(os(macOS) || os(iOS) || os(watchOS) || os(tvOS)))) && (arch(i386) || arch(x86_64))
% end
% if bits == 16:
// The 16-bit type is not emitted for macOS or Mac Catalyst.
#if !os(macOS) && !(os(iOS) && targetEnvironment(macCatalyst))
% end

//===----------------------------------------------------------------------===//
// Protocol conformances
//===----------------------------------------------------------------------===//

// `${Self}` serves as its own tangent space.
${Availability(bits)}
extension ${Self}: Differentiable {
  /// The tangent vector of `${Self}` is `${Self}` itself.
  ${Availability(bits)}
  public typealias TangentVector = ${Self}

  /// Moves `self` along `offset` by adding `offset` to it in place.
  ///
  /// - Parameter offset: The tangent vector to add to `self`.
  ${Availability(bits)}
  @inlinable
  public mutating func move(by offset: TangentVector) {
    self += offset
  }
}

//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//

/// Derivatives of constructors.
${Availability(bits)}
extension ${Self} {
% for other_type in all_floating_point_types():
%{
# The source type of the conversion whose derivative is being registered.
Other = other_type.stdlib_name
other_bits = other_type.bits
}%
% if other_bits == 80:
// This guard must stay identical to the one that gates the 80-bit type's
// `Differentiable` conformance earlier in this file. A stricter guard here
// would leave conversions from that type without a registered derivative in
// configurations where the type itself is differentiable.
#if !(os(Windows) || os(Android) || ($Embedded && !os(Linux) && !(os(macOS) || os(iOS) || os(watchOS) || os(tvOS)))) && (arch(i386) || arch(x86_64))
% end
% if other_bits == 16:
#if !os(macOS) && !(os(iOS) && targetEnvironment(macCatalyst))
% end
  /// Reverse-mode derivative of `${Self}.init(_: ${Other})`.
  ///
  /// - Parameter x: The value being converted.
  /// - Returns: The converted value, and a pullback that converts the incoming
  ///   `${Self}` cotangent back to `${Other}`.
  ${Availability(other_bits)}
  @inlinable
  @_transparent
  @derivative(of: init(_:))
  static func _vjpInit(x: ${Other}) -> (
    value: ${Self}, pullback: (${Self}) -> ${Other}
  ) {
    return (value: ${Self}(x), pullback: { v in ${Other}(v) })
  }

  /// Forward-mode derivative of `${Self}.init(_: ${Other})`.
  ///
  /// - Parameter x: The value being converted.
  /// - Returns: The converted value, and a differential that converts the
  ///   incoming `${Other}` tangent to `${Self}`.
  ${Availability(other_bits)}
  @inlinable
  @_transparent
  @derivative(of: init(_:))
  static func _jvpInit(x: ${Other}) -> (
    value: ${Self}, differential: (${Other}) -> ${Self}
  ) {
    return (value: ${Self}(x), differential: { dx in ${Self}(dx) })
  }
% if other_bits == 80 or other_bits == 16:
#endif
% end
% end
}

/// Derivatives of standard unary operators.
${Availability(bits)}
extension ${Self} {
  // Reverse-mode derivative of prefix `-`: the pullback negates the cotangent.
  @inlinable
  @_transparent
  @derivative(of: -)
  static func _vjpNegate(x: ${Self})
    -> (value: ${Self}, pullback: (${Self}) -> ${Self}) {
    return (-x, { v in -v })
  }

  // Forward-mode derivative of prefix `-`: the differential negates the
  // tangent.
  @inlinable
  @_transparent
  @derivative(of: -)
  static func _jvpNegate(x: ${Self})
    -> (value: ${Self}, differential: (${Self}) -> ${Self}) {
    return (-x, { dx in -dx })
  }
}

/// Derivatives of standard binary operators.
${Availability(bits)}
extension ${Self} {
  // MARK: Addition

  /// Reverse-mode derivative of `+`: the cotangent flows unchanged to both
  /// operands.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: +)
  static func _vjpAdd(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
    return (
      value: lhs + rhs,
      pullback: { cotangent in (cotangent, cotangent) }
    )
  }

  /// Forward-mode derivative of `+`: the differential is the sum of the
  /// operand tangents.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: +)
  static func _jvpAdd(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
    return (
      value: lhs + rhs,
      differential: { lhsTangent, rhsTangent in lhsTangent + rhsTangent }
    )
  }

  /// Reverse-mode derivative of `+=`: the `inout` cotangent of `lhs` is left
  /// as is, and a copy of it is returned as the cotangent of `rhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: +=)
  static func _vjpAddAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, pullback: (inout ${Self}) -> ${Self}
  ) {
    lhs += rhs
    return (value: (), pullback: { cotangent in cotangent })
  }

  /// Forward-mode derivative of `+=`: the tangent of `rhs` is added into the
  /// `inout` tangent of `lhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: +=)
  static func _jvpAddAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, differential: (inout ${Self}, ${Self}) -> Void
  ) {
    lhs += rhs
    return (
      value: (),
      differential: { lhsTangent, rhsTangent in lhsTangent += rhsTangent }
    )
  }

  // MARK: Subtraction

  /// Reverse-mode derivative of binary `-`: `lhs` receives the cotangent and
  /// `rhs` receives its negation.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: -)
  static func _vjpSubtract(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
    return (
      value: lhs - rhs,
      pullback: { cotangent in (cotangent, -cotangent) }
    )
  }

  /// Forward-mode derivative of binary `-`: the differential is the
  /// difference of the operand tangents.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: -)
  static func _jvpSubtract(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
    return (
      value: lhs - rhs,
      differential: { lhsTangent, rhsTangent in lhsTangent - rhsTangent }
    )
  }

  /// Reverse-mode derivative of `-=`: the `inout` cotangent of `lhs` is left
  /// as is, and its negation is returned as the cotangent of `rhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: -=)
  static func _vjpSubtractAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, pullback: (inout ${Self}) -> ${Self}
  ) {
    lhs -= rhs
    return (value: (), pullback: { cotangent in -cotangent })
  }

  /// Forward-mode derivative of `-=`: the tangent of `rhs` is subtracted from
  /// the `inout` tangent of `lhs`.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: -=)
  static func _jvpSubtractAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, differential: (inout ${Self}, ${Self}) -> Void
  ) {
    lhs -= rhs
    return (
      value: (),
      differential: { lhsTangent, rhsTangent in lhsTangent -= rhsTangent }
    )
  }

  // MARK: Multiplication

  /// Reverse-mode derivative of `*`: each operand receives the cotangent
  /// scaled by the other operand.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: *)
  static func _vjpMultiply(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
    return (
      value: lhs * rhs,
      pullback: { cotangent in (rhs * cotangent, lhs * cotangent) }
    )
  }

  /// Forward-mode derivative of `*` (product rule).
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: *)
  static func _jvpMultiply(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
    return (
      value: lhs * rhs,
      differential: { lhsTangent, rhsTangent in
        lhs * rhsTangent + rhs * lhsTangent
      }
    )
  }

  /// Reverse-mode derivative of `*=`.
  ///
  /// The pullback uses the value `lhs` held *before* the multiplication, so
  /// that value is saved before `lhs` is mutated.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: *=)
  static func _vjpMultiplyAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, pullback: (inout ${Self}) -> ${Self}
  ) {
    let originalLhs = lhs
    lhs *= rhs
    return (
      value: (),
      pullback: { cotangent in
        // Compute the cotangent of `rhs` before the `lhs` cotangent is
        // scaled in place.
        let rhsCotangent = originalLhs * cotangent
        cotangent *= rhs
        return rhsCotangent
      }
    )
  }

  /// Forward-mode derivative of `*=` (product rule applied in place, using
  /// the value `lhs` held before the multiplication).
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: *=)
  static func _jvpMultiplyAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, differential: (inout ${Self}, ${Self}) -> Void
  ) {
    let originalLhs = lhs
    lhs *= rhs
    return (
      value: (),
      differential: { lhsTangent, rhsTangent in
        lhsTangent = lhsTangent * rhs + originalLhs * rhsTangent
      }
    )
  }

  // MARK: Division

  /// Reverse-mode derivative of `/`: `lhs` receives `v / rhs` and `rhs`
  /// receives `-lhs / (rhs * rhs) * v`, where `v` is the cotangent.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: /)
  static func _vjpDivide(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
    return (
      value: lhs / rhs,
      pullback: { cotangent in
        (cotangent / rhs, -lhs / (rhs * rhs) * cotangent)
      }
    )
  }

  /// Forward-mode derivative of `/` (quotient rule).
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: /)
  static func _jvpDivide(
    lhs: ${Self}, rhs: ${Self}
  ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
    return (
      value: lhs / rhs,
      differential: { lhsTangent, rhsTangent in
        lhsTangent / rhs - lhs / (rhs * rhs) * rhsTangent
      }
    )
  }

  /// Reverse-mode derivative of `/=`.
  ///
  /// The pullback uses the value `lhs` held *before* the division, so that
  /// value is saved before `lhs` is mutated.
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: /=)
  static func _vjpDivideAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, pullback: (inout ${Self}) -> ${Self}
  ) {
    let originalLhs = lhs
    lhs /= rhs
    return (
      value: (),
      pullback: { cotangent in
        // Compute the cotangent of `rhs` before the `lhs` cotangent is
        // scaled in place.
        let rhsCotangent = -originalLhs / (rhs * rhs) * cotangent
        cotangent /= rhs
        return rhsCotangent
      }
    )
  }

  /// Forward-mode derivative of `/=` (quotient rule applied in place, using
  /// the value `lhs` held before the division).
  ${Availability(bits)}
  @inlinable // FIXME(sil-serialize-all)
  @_transparent
  @derivative(of: /=)
  static func _jvpDivideAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
    value: Void, differential: (inout ${Self}, ${Self}) -> Void
  ) {
    let originalLhs = lhs
    lhs /= rhs
    return (
      value: (),
      differential: { lhsTangent, rhsTangent in
        lhsTangent = (lhsTangent * rhs - originalLhs * rhsTangent) / (rhs * rhs)
      }
    )
  }
}

% if bits == 80 or bits == 16:
#endif
% end
% end

// Derivatives of generic `FloatingPoint` operations, for conforming types that
// are their own tangent vector.
extension FloatingPoint
where Self: Differentiable, Self == Self.TangentVector {
  /// Reverse-mode derivative of `addingProduct(_:_:)`.
  ///
  /// The pullback maps a cotangent `v` to `(v, v * rhs, v * lhs)`, i.e. the
  /// cotangents of `self`, `lhs` and `rhs` respectively.
  @inlinable // FIXME(sil-serialize-all)
  @derivative(of: addingProduct)
  func _vjpAddingProduct(
    _ lhs: Self, _ rhs: Self
  ) -> (value: Self, pullback: (Self) -> (Self, Self, Self)) {
    let result = addingProduct(lhs, rhs)
    return (
      value: result,
      pullback: { cotangent in
        (cotangent, cotangent * rhs, cotangent * lhs)
      }
    )
  }

  /// Reverse-mode derivative of `squareRoot()`.
  ///
  /// The pullback divides the cotangent by twice the computed square root.
  @inlinable // FIXME(sil-serialize-all)
  @derivative(of: squareRoot)
  func _vjpSquareRoot() -> (value: Self, pullback: (Self) -> Self) {
    let root = squareRoot()
    return (value: root, pullback: { cotangent in cotangent / (2 * root) })
  }

  /// Reverse-mode derivative of `minimum(_:_:)`.
  ///
  /// The whole cotangent is routed to whichever argument is returned as the
  /// value; the other argument receives `.zero`. `x` is chosen when
  /// `x <= y` or when `y` is NaN.
  @inlinable
  @derivative(of: minimum)
  static func _vjpMinimum(_ x: Self, _ y: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    let choosesX = x <= y || y.isNaN
    guard choosesX else {
      return (value: y, pullback: { cotangent in (.zero, cotangent) })
    }
    return (value: x, pullback: { cotangent in (cotangent, .zero) })
  }

  /// Reverse-mode derivative of `maximum(_:_:)`.
  ///
  /// The whole cotangent is routed to whichever argument is returned as the
  /// value; the other argument receives `.zero`. `x` is chosen when
  /// `x > y` or when `y` is NaN.
  @inlinable
  @derivative(of: maximum)
  static func _vjpMaximum(_ x: Self, _ y: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    let choosesX = x > y || y.isNaN
    guard choosesX else {
      return (value: y, pullback: { cotangent in (.zero, cotangent) })
    }
    return (value: x, pullback: { cotangent in (cotangent, .zero) })
  }
}