//===--- SIMDDifferentiation.swift.gyb ------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import Swift

%{
storagescalarCounts = [2,4,8,16,32,64]
vectorscalarCounts = storagescalarCounts + [3]
}%

%for n in vectorscalarCounts:

//===----------------------------------------------------------------------===//
// Protocol conformances
//===----------------------------------------------------------------------===//

extension SIMD${n}: AdditiveArithmetic where Scalar: FloatingPoint {}

extension SIMD${n}: Differentiable
where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector: BinaryFloatingPoint
{
  /// The tangent vector of a `SIMD${n}` is itself a `SIMD${n}`.
  public typealias TangentVector = SIMD${n}
}

//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//

extension SIMD${n}
where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector: BinaryFloatingPoint
{
  // NOTE(TF-1094): serialized `@derivative` for `.swiftinterface` compilation.
  /// Derivative registered for `subscript(_:)`.
  ///
  /// - Parameter index: The lane being read.
  /// - Returns: The element at `index`, together with a pullback that builds a
  ///   zero vector and stores the incoming scalar tangent into lane `index`.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _vjpSubscript(index: Int)
    -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector)
  {
    (self[index], { tangent in
      var result = Self.zero
      result[index] = Scalar(tangent)
      return result
    })
  }
}

%end

// Vector (+, -) vector, and unary negation.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  TangentVector.Scalar: BinaryFloatingPoint
{
  /// Derivative registered for `+` on two vectors.
  ///
  /// The pullback hands the incoming tangent unchanged to both operands.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    (lhs + rhs, { tangent in (tangent, tangent) })
  }

  /// Derivative registered for binary `-` on two vectors.
  ///
  /// The pullback yields the incoming tangent for `lhs` and its negation for
  /// `rhs`.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    (lhs - rhs, { tangent in (tangent, -tangent) })
  }

  /// Derivative registered for unary `-`.
  ///
  /// The pullback negates the incoming tangent.
  @inlinable
  @derivative(of: -)
  static func _vjpNegate(rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector))
  {
    (-rhs, { tangent in -tangent })
  }
}

// Vector (*, /) vector.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  Self.TangentVector == Self
{
  /// Derivative registered for `*` on two vectors.
  ///
  /// The pullback returns `tangent * rhs` for `lhs` and `tangent * lhs` for
  /// `rhs`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    (lhs * rhs, { tangent in (tangent * rhs, tangent * lhs) })
  }

  /// Derivative registered for `/` on two vectors.
  ///
  /// The pullback returns `tangent / rhs` for `lhs` and
  /// `-lhs / (rhs * rhs) * tangent` for `rhs`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Self) -> (
    value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
  ) {
    (lhs / rhs, { tangent in
      (tangent / rhs, -lhs / (rhs * rhs) * tangent)
    })
  }
}

// Scalar (+, -) vector and vector (+, -) scalar.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector.Scalar == Scalar.TangentVector
{
  /// Derivative registered for `+` with a scalar on the left.
  ///
  /// The scalar operand receives `sum()` of the incoming tangent; the vector
  /// operand receives the tangent unchanged.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    (lhs + rhs, { tangent in (tangent.sum(), tangent) })
  }

  /// Derivative registered for binary `-` with a scalar on the left.
  ///
  /// The scalar operand receives `sum()` of the incoming tangent; the vector
  /// operand receives the negated tangent.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    (lhs - rhs, { tangent in (tangent.sum(), -tangent) })
  }

  /// Derivative registered for `+` with a scalar on the right.
  ///
  /// The vector operand receives the incoming tangent unchanged; the scalar
  /// operand receives its `sum()`.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    (lhs + rhs, { tangent in (tangent, tangent.sum()) })
  }

  /// Derivative registered for binary `-` with a scalar on the right.
  ///
  /// The vector operand receives the incoming tangent unchanged; the scalar
  /// operand receives the negated `sum()` of the tangent.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    (lhs - rhs, { tangent in (tangent, -tangent.sum()) })
  }
}

// Scalar (*, /) vector and vector (*, /) scalar.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// Derivative registered for `*` with a scalar on the right.
  ///
  /// The vector operand receives `tangent * rhs`; the scalar operand receives
  /// `(tangent * lhs).sum()`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    (lhs * rhs, { tangent in (tangent * rhs, (tangent * lhs).sum()) })
  }

  /// Derivative registered for `/` with a scalar on the right.
  ///
  /// The vector operand receives `tangent / rhs`; the scalar operand receives
  /// `(-lhs / (rhs * rhs) * tangent).sum()`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    (lhs / rhs, { tangent in
      (tangent / rhs, (-lhs / (rhs * rhs) * tangent).sum())
    })
  }

  /// Derivative registered for `*` with a scalar on the left.
  ///
  /// The scalar operand receives `(tangent * rhs).sum()`; the vector operand
  /// receives `tangent * lhs`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    (lhs * rhs, { tangent in ((tangent * rhs).sum(), tangent * lhs) })
  }

  /// Derivative registered for `/` with a scalar on the left.
  ///
  /// The scalar operand receives `(tangent / rhs).sum()`; the vector operand
  /// receives `-lhs / (rhs * rhs) * tangent`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    (lhs / rhs, { tangent in
      ((tangent / rhs).sum(), -lhs / (rhs * rhs) * tangent)
    })
  }
}

// FIXME(TF-1103): Derivative registration does not yet support
// `@_alwaysEmitIntoClient` original functions.
/*
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector == Self
{
  @inlinable
  @derivative(of: sum)
  func _vjpSum() -> (
    value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
  ) {
    return (sum(), { v in Self(repeating: Scalar(v)) })
  }
}
*/

extension SIMD
where
  Self: Differentiable,
  Self.TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// Derivative registered for `init(repeating:)`.
  ///
  /// - Parameter value: The scalar passed to `Self(repeating:)`.
  /// - Returns: The constructed vector, together with a pullback that reduces
  ///   the incoming vector tangent to a scalar with `sum()`.
  @inlinable
  @derivative(of: init(repeating:))
  static func _vjpInit(repeating value: Scalar)
    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector)
  {
    (Self(repeating: value), { $0.sum() })
  }
}