//===--- SIMDDifferentiation.swift.gyb ------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import Swift

%{
storagescalarCounts = [2,4,8,16,32,64]
vectorscalarCounts = storagescalarCounts + [3]
}%

%for n in vectorscalarCounts:

//===----------------------------------------------------------------------===//
// Protocol conformances
//===----------------------------------------------------------------------===//

extension SIMD${n}: @retroactive AdditiveArithmetic
where Scalar: FloatingPoint {}

// A differentiable SIMD${n} is its own tangent space.
extension SIMD${n}: Differentiable
where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector: BinaryFloatingPoint
{
  public typealias TangentVector = SIMD${n}
}

//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//

extension SIMD${n}
where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector == Scalar
{
  // NOTE(TF-1094): serialized `@derivative` for `.swiftinterface` compilation.

  // Reverse-mode derivative of the lane getter: the pullback writes the
  // incoming scalar cotangent into lane `index` of an otherwise-zero vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _vjpSubscript(_ index: Int)
    -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector)
  {
    (
      value: self[index],
      pullback: { seed in
        var scattered = Self.zero
        scattered[index] = seed
        return scattered
      }
    )
  }

  // Forward-mode derivative of the lane getter: the differential reads lane
  // `index` of the incoming tangent vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _jvpSubscript(index: Int)
    -> (value: Scalar, differential: (TangentVector) -> Scalar.TangentVector)
  {
    (
      value: self[index],
      differential: { tangent in .init(tangent[index]) }
    )
  }

  // Reverse-mode derivative of the lane setter: performs the store, and the
  // pullback hands lane `index` of the cotangent to `newValue` while zeroing
  // that lane in place.
  @inlinable
  @derivative(of: subscript(_:).set)
  internal mutating func _vjpSubscriptSetter(_ newValue: Scalar, _ index: Int)
    -> (value: Void, pullback: (inout TangentVector) -> Scalar.TangentVector)
  {
    self[index] = newValue
    return (
      value: (),
      pullback: { cotangent in
        let laneCotangent = cotangent[index]
        cotangent[index] = 0
        return laneCotangent
      }
    )
  }
}

%end

// Vector-vector addition, subtraction and negation.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  TangentVector.Scalar: BinaryFloatingPoint
{
  // Pullback of `lhs + rhs`: both operands receive the cotangent unchanged.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    (value: lhs + rhs, pullback: { seed in (seed, seed) })
  }

  // Differential of `lhs + rhs`: the sum of the operand tangents.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Self) -> (
    value: Self, differential: (TangentVector, TangentVector) -> TangentVector
  ) {
    (value: lhs + rhs, differential: { dLhs, dRhs in dLhs + dRhs })
  }

  // Pullback of `lhs - rhs`: `lhs` receives the cotangent, `rhs` its negation.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    (value: lhs - rhs, pullback: { seed in (seed, -seed) })
  }

  // Differential of `lhs - rhs`: the difference of the operand tangents.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Self) -> (
    value: Self, differential: (TangentVector, TangentVector) -> TangentVector
  ) {
    (value: lhs - rhs, differential: { dLhs, dRhs in dLhs - dRhs })
  }

  // Pullback of unary `-rhs`: negates the cotangent.
  @inlinable
  @derivative(of: -)
  static func _vjpNegate(rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> TangentVector)
  {
    (value: -rhs, pullback: { seed in -seed })
  }

  // Differential of unary `-rhs`: negates the tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpNegate(rhs: Self)
    -> (value: Self, differential: (TangentVector) -> TangentVector)
  {
    (value: -rhs, differential: { tangent in -tangent })
  }
}

// Vector-vector multiplication and division.
extension SIMD
where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint,
  Self.TangentVector == Self
{
  // Pullback of `lhs * rhs`: each operand's cotangent is the seed times the
  // other operand.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    (value: lhs * rhs, pullback: { seed in (seed * rhs, seed * lhs) })
  }

  // Differential of `lhs * rhs`: product rule.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Self) -> (
    value: Self, differential: (TangentVector, TangentVector) -> TangentVector
  ) {
    (
      value: lhs * rhs,
      differential: { dLhs, dRhs in lhs * dRhs + dLhs * rhs }
    )
  }

  // Pullback of `lhs / rhs`: `seed / rhs` for the numerator and
  // `-lhs / rhs^2 * seed` for the denominator.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    (
      value: lhs / rhs,
      pullback: { seed in (seed / rhs, -lhs / (rhs * rhs) * seed) }
    )
  }

  // Differential of `lhs / rhs`: quotient rule.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Self) -> (
    value: Self, differential: (TangentVector, TangentVector) -> TangentVector
  ) {
    (
      value: lhs / rhs,
      differential: { dLhs, dRhs in (dLhs * rhs - lhs * dRhs) / (rhs * rhs) }
    )
  }
}

// Scalar-vector and vector-scalar addition and subtraction. In every pullback
// the scalar operand's cotangent is the `sum()` of the incoming vector
// cotangent.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector.Scalar == Scalar.TangentVector
{
  // Pullback of `scalar + vector`.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    (value: lhs + rhs, pullback: { seed in (seed.sum(), seed) })
  }

  // Differential of `scalar + vector`.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    (value: lhs + rhs, differential: { dLhs, dRhs in dLhs + dRhs })
  }

  // Pullback of `scalar - vector`.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    (value: lhs - rhs, pullback: { seed in (seed.sum(), -seed) })
  }

  // Differential of `scalar - vector`.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    (value: lhs - rhs, differential: { dLhs, dRhs in dLhs - dRhs })
  }

  // Pullback of `vector + scalar`.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    (value: lhs + rhs, pullback: { seed in (seed, seed.sum()) })
  }

  // Differential of `vector + scalar`.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    (value: lhs + rhs, differential: { dLhs, dRhs in dLhs + dRhs })
  }

  // Pullback of `vector - scalar`.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    (value: lhs - rhs, pullback: { seed in (seed, -seed.sum()) })
  }

  // Differential of `vector - scalar`.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    (value: lhs - rhs, differential: { dLhs, dRhs in dLhs - dRhs })
  }
}

// Scalar-vector and vector-scalar multiplication and division. In every
// pullback the scalar operand's cotangent is reduced with `sum()`.
extension SIMD
where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  // Pullback of `vector * scalar`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    (
      value: lhs * rhs,
      pullback: { seed in (seed * rhs, (seed * lhs).sum()) }
    )
  }

  // Differential of `vector * scalar`: product rule.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    (
      value: lhs * rhs,
      differential: { dLhs, dRhs in lhs * dRhs + dLhs * rhs }
    )
  }

  // Pullback of `vector / scalar`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    (
      value: lhs / rhs,
      pullback: { seed in
        (seed / rhs, (-lhs / (rhs * rhs) * seed).sum())
      }
    )
  }

  // Differential of `vector / scalar`: quotient rule.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    (
      value: lhs / rhs,
      differential: { dLhs, dRhs in (dLhs * rhs - lhs * dRhs) / (rhs * rhs) }
    )
  }

  // Pullback of `scalar * vector`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    (
      value: lhs * rhs,
      pullback: { seed in ((seed * rhs).sum(), seed * lhs) }
    )
  }

  // Differential of `scalar * vector`: product rule.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    (
      value: lhs * rhs,
      differential: { dLhs, dRhs in lhs * dRhs + dLhs * rhs }
    )
  }

  // Pullback of `scalar / vector`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    (
      value: lhs / rhs,
      pullback: { seed in
        ((seed / rhs).sum(), -lhs / (rhs * rhs) * seed)
      }
    )
  }

  // Differential of `scalar / vector`: quotient rule.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    (
      value: lhs / rhs,
      differential: { dLhs, dRhs in (dLhs * rhs - lhs * dRhs) / (rhs * rhs) }
    )
  }
}

// Horizontal reduction `sum()`.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector == Self
{
  // Pullback of `sum()`: repeats the scalar cotangent in every lane.
  @inlinable
  @_alwaysEmitIntoClient
  @derivative(of: sum)
  func _vjpSum() -> (
    value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
  ) {
    (value: sum(), pullback: { seed in Self(repeating: Scalar(seed)) })
  }

  // Differential of `sum()`: the sum of the tangent's lanes.
  @inlinable
  @_alwaysEmitIntoClient
  @derivative(of: sum)
  func _jvpSum() -> (
    value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
  ) {
    (
      value: sum(),
      differential: { tangent in Scalar.TangentVector(tangent.sum()) }
    )
  }
}

// Splat initializer `init(repeating:)`.
extension SIMD
where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  // Pullback of `init(repeating:)`: the sum of the cotangent's lanes.
  @inlinable
  @derivative(of: init(repeating:))
  static func _vjpInit(repeating value: Scalar)
    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector)
  {
    (value: Self(repeating: value), pullback: { seed in seed.sum() })
  }

  // Differential of `init(repeating:)`: repeats the scalar tangent.
  @inlinable
  @derivative(of: init(repeating:))
  static func _jvpInit(repeating value: Scalar)
    -> (value: Self, differential: (Scalar.TangentVector) -> TangentVector)
  {
    (
      value: Self(repeating: value),
      differential: { tangent in Self(repeating: tangent) }
    )
  }
}

%for (Scalar, bits) in [('Float',32), ('Double',64)]:
% for n in vectorscalarCounts:
// `init(repeating:)` derivatives for the concrete SIMD${n}<${Scalar}> type,
// marked `@_alwaysEmitIntoClient`.
extension SIMD${n} where Scalar == ${Scalar} {
  // Pullback: the sum of the cotangent's lanes.
  @inlinable
  @_alwaysEmitIntoClient
  @derivative(of: init(repeating:))
  static func _vjpInit(repeating value: Scalar)
    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector)
  {
    (value: Self(repeating: value), pullback: { seed in seed.sum() })
  }

  // Differential: repeats the scalar tangent in every lane.
  @inlinable
  @_alwaysEmitIntoClient
  @derivative(of: init(repeating:))
  static func _jvpInit(repeating value: Scalar)
    -> (value: Self, differential: (Scalar.TangentVector) -> TangentVector)
  {
    (
      value: Self(repeating: value),
      differential: { tangent in Self(repeating: tangent) }
    )
  }
}

% end
%end