mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
460 lines
11 KiB
Swift
460 lines
11 KiB
Swift
//===--- SIMDDifferentiation.swift.gyb ------------------------*- swift -*-===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
import Swift
|
|
|
|
%{
# Lane counts used by this template. `storagescalarCounts` holds the powers of
# two from 2 through 64; `vectorscalarCounts` is that list with 3 appended and
# is what the `%for` loop in this file iterates over.
storagescalarCounts = [2 ** exponent for exponent in range(1, 7)]
vectorscalarCounts = storagescalarCounts + [3]
}%
|
|
|
|
%for n in vectorscalarCounts:
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Protocol conformances
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Declares `SIMD${n}` as `AdditiveArithmetic` whenever its scalar is a
// floating-point type. The extension body is empty: no members are added here.
extension SIMD${n}: @retroactive AdditiveArithmetic
where Scalar: FloatingPoint {}
|
|
|
|
// Declares `SIMD${n}` as `Differentiable` when its scalar is a differentiable
// binary floating-point type whose own tangent vector is binary floating-point.
extension SIMD${n}: Differentiable
where
  Scalar: BinaryFloatingPoint,
  Scalar: Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint
{
  /// The tangent vector type is `SIMD${n}` itself.
  public typealias TangentVector = SIMD${n}
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Derivatives
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Derivatives of the lane subscript (getter and setter) for `SIMD${n}` whose
// scalar is its own tangent vector.
extension SIMD${n}
where
  Scalar: BinaryFloatingPoint,
  Scalar: Differentiable,
  Scalar.TangentVector == Scalar
{
  // NOTE(TF-1094): serialized `@derivative` for `.swiftinterface` compilation.
  /// VJP of the subscript getter.
  ///
  /// Returns the lane at `index` together with a pullback that writes the
  /// incoming scalar cotangent into lane `index` of an otherwise-zero vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _vjpSubscript(_ index: Int)
    -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector)
  {
    let lane = self[index]
    return (lane, { cotangent in
      var scattered = Self.zero
      scattered[index] = cotangent
      return scattered
    })
  }

  /// JVP of the subscript getter.
  ///
  /// Returns the lane at `index` together with a differential that reads lane
  /// `index` out of the incoming tangent vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _jvpSubscript(index: Int)
    -> (value: Scalar, differential: (TangentVector) -> Scalar.TangentVector)
  {
    let lane = self[index]
    return (lane, { tangent in .init(tangent[index]) })
  }

  /// VJP of the subscript setter.
  ///
  /// Stores `newValue` into lane `index`. The pullback returns lane `index` of
  /// the cotangent of `self` as the cotangent of `newValue`, and zeroes that
  /// lane in place.
  @inlinable
  @derivative(of: subscript(_:).set)
  internal mutating func _vjpSubscriptSetter(_ newValue: Scalar, _ index: Int)
    -> (value: Void, pullback: (inout TangentVector) -> Scalar.TangentVector)
  {
    self[index] = newValue
    return ((), { selfCotangent in
      let newValueCotangent = selfCotangent[index]
      selfCotangent[index] = 0
      return newValueCotangent
    })
  }
}
|
|
|
|
%end
|
|
|
|
// Derivatives of element-wise `+`, binary `-`, and unary `-` where both
// operands (or the single operand) are vectors.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  TangentVector.Scalar: BinaryFloatingPoint
{
  /// VJP of `lhs + rhs`: the pullback passes the cotangent to both operands.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    let sum = lhs + rhs
    return (sum, { cotangent in (cotangent, cotangent) })
  }

  /// JVP of `lhs + rhs`: the differential adds the two operand tangents.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    let sum = lhs + rhs
    return (sum, { lhsTangent, rhsTangent in lhsTangent + rhsTangent })
  }

  /// VJP of `lhs - rhs`: the pullback passes the cotangent to `lhs` and its
  /// negation to `rhs`.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    let difference = lhs - rhs
    return (difference, { cotangent in (cotangent, -cotangent) })
  }

  /// JVP of `lhs - rhs`: the differential subtracts the `rhs` tangent from the
  /// `lhs` tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    let difference = lhs - rhs
    return (difference, { lhsTangent, rhsTangent in lhsTangent - rhsTangent })
  }

  /// VJP of `-rhs`: the pullback negates the cotangent.
  @inlinable
  @derivative(of: -)
  static func _vjpNegate(rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector))
  {
    let negated = -rhs
    return (negated, { cotangent in -cotangent })
  }

  /// JVP of `-rhs`: the differential negates the tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpNegate(rhs: Self)
    -> (value: Self, differential: (TangentVector) -> (TangentVector))
  {
    let negated = -rhs
    return (negated, { tangent in -tangent })
  }
}
|
|
|
|
// Derivatives of element-wise `*` and `/` where both operands are vectors and
// the vector type is its own tangent vector.
extension SIMD
where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint,
  Self.TangentVector == Self
{
  /// VJP of `lhs * rhs`: the pullback maps a cotangent `v` to
  /// `(v * rhs, v * lhs)`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    let product = lhs * rhs
    return (product, { cotangent in (cotangent * rhs, cotangent * lhs) })
  }

  /// JVP of `lhs * rhs`: the differential is `lhs * rtan + ltan * rhs`.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    let product = lhs * rhs
    return (product, { lhsTangent, rhsTangent in
      lhs * rhsTangent + lhsTangent * rhs
    })
  }

  /// VJP of `lhs / rhs`: the pullback maps a cotangent `v` to
  /// `(v / rhs, -lhs / (rhs * rhs) * v)`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    let quotient = lhs / rhs
    return (quotient, { cotangent in
      let rhsSquared = rhs * rhs
      return (cotangent / rhs, -lhs / rhsSquared * cotangent)
    })
  }

  /// JVP of `lhs / rhs`: the differential is
  /// `(ltan * rhs - lhs * rtan) / (rhs * rhs)`.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    let quotient = lhs / rhs
    return (quotient, { lhsTangent, rhsTangent in
      let numerator = lhsTangent * rhs - lhs * rhsTangent
      return numerator / (rhs * rhs)
    })
  }
}
|
|
|
|
// Derivatives of `+` and `-` where one operand is a scalar and the other is a
// vector. Pullbacks produce the scalar operand's cotangent by calling `sum()`
// on the vector cotangent.
extension SIMD
where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector.Scalar == Scalar.TangentVector
{
  /// VJP of `scalar + vector`: the pullback maps `v` to `(v.sum(), v)`.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    let sum = lhs + rhs
    return (sum, { cotangent in (cotangent.sum(), cotangent) })
  }

  /// JVP of `scalar + vector`: the differential adds the scalar tangent to the
  /// vector tangent.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    let sum = lhs + rhs
    return (sum, { scalarTangent, vectorTangent in
      scalarTangent + vectorTangent
    })
  }

  /// VJP of `scalar - vector`: the pullback maps `v` to `(v.sum(), -v)`.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    let difference = lhs - rhs
    return (difference, { cotangent in (cotangent.sum(), -cotangent) })
  }

  /// JVP of `scalar - vector`: the differential subtracts the vector tangent
  /// from the scalar tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    let difference = lhs - rhs
    return (difference, { scalarTangent, vectorTangent in
      scalarTangent - vectorTangent
    })
  }

  /// VJP of `vector + scalar`: the pullback maps `v` to `(v, v.sum())`.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    let sum = lhs + rhs
    return (sum, { cotangent in (cotangent, cotangent.sum()) })
  }

  /// JVP of `vector + scalar`: the differential adds the scalar tangent to the
  /// vector tangent.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    let sum = lhs + rhs
    return (sum, { vectorTangent, scalarTangent in
      vectorTangent + scalarTangent
    })
  }

  /// VJP of `vector - scalar`: the pullback maps `v` to `(v, -v.sum())`.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    let difference = lhs - rhs
    return (difference, { cotangent in (cotangent, -cotangent.sum()) })
  }

  /// JVP of `vector - scalar`: the differential subtracts the scalar tangent
  /// from the vector tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    let difference = lhs - rhs
    return (difference, { vectorTangent, scalarTangent in
      vectorTangent - scalarTangent
    })
  }
}
|
|
|
|
// Derivatives of `*` and `/` where one operand is a scalar and the other is a
// vector, for vector and scalar types that are their own tangent vectors.
// Pullbacks produce the scalar operand's cotangent by calling `sum()` on a
// vector expression.
extension SIMD
where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// VJP of `vector * scalar`: the pullback maps `v` to
  /// `(v * rhs, (v * lhs).sum())`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    let product = lhs * rhs
    return (product, { cotangent in
      (cotangent * rhs, (cotangent * lhs).sum())
    })
  }

  /// JVP of `vector * scalar`: the differential is `lhs * rtan + ltan * rhs`.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    let product = lhs * rhs
    return (product, { vectorTangent, scalarTangent in
      lhs * scalarTangent + vectorTangent * rhs
    })
  }

  /// VJP of `vector / scalar`: the pullback maps `v` to
  /// `(v / rhs, (-lhs / (rhs * rhs) * v).sum())`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    let quotient = lhs / rhs
    return (quotient, { cotangent in
      let rhsSquared = rhs * rhs
      return (cotangent / rhs, (-lhs / rhsSquared * cotangent).sum())
    })
  }

  /// JVP of `vector / scalar`: the differential is
  /// `(ltan * rhs - lhs * rtan) / (rhs * rhs)`.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    let quotient = lhs / rhs
    return (quotient, { vectorTangent, scalarTangent in
      let numerator = vectorTangent * rhs - lhs * scalarTangent
      return numerator / (rhs * rhs)
    })
  }

  /// VJP of `scalar * vector`: the pullback maps `v` to
  /// `((v * rhs).sum(), v * lhs)`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    let product = lhs * rhs
    return (product, { cotangent in
      ((cotangent * rhs).sum(), cotangent * lhs)
    })
  }

  /// JVP of `scalar * vector`: the differential is `lhs * rtan + ltan * rhs`.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    let product = lhs * rhs
    return (product, { scalarTangent, vectorTangent in
      lhs * vectorTangent + scalarTangent * rhs
    })
  }

  /// VJP of `scalar / vector`: the pullback maps `v` to
  /// `((v / rhs).sum(), -lhs / (rhs * rhs) * v)`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    let quotient = lhs / rhs
    return (quotient, { cotangent in
      let rhsSquared = rhs * rhs
      return ((cotangent / rhs).sum(), -lhs / rhsSquared * cotangent)
    })
  }

  /// JVP of `scalar / vector`: the differential is
  /// `(ltan * rhs - lhs * rtan) / (rhs * rhs)`.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    let quotient = lhs / rhs
    return (quotient, { scalarTangent, vectorTangent in
      let numerator = scalarTangent * rhs - lhs * vectorTangent
      return numerator / (rhs * rhs)
    })
  }
}
|
|
|
|
// FIXME(TF-1103): Derivative registration does not yet support
|
|
// `@_alwaysEmitIntoClient` original functions like `SIMD.sum()`.
|
|
/*
|
|
extension SIMD
|
|
where
|
|
Self: Differentiable,
|
|
TangentVector: SIMD,
|
|
Scalar: BinaryFloatingPoint & Differentiable,
|
|
Scalar.TangentVector: BinaryFloatingPoint,
|
|
TangentVector == Self
|
|
{
|
|
@inlinable
|
|
@derivative(of: sum)
|
|
func _vjpSum() -> (
|
|
value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
|
|
) {
|
|
return (sum(), { v in Self(repeating: Scalar(v)) })
|
|
}
|
|
|
|
@inlinable
|
|
@derivative(of: sum)
|
|
func _jvpSum() -> (
|
|
value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
|
|
) {
|
|
return (sum(), { v in Scalar.TangentVector(v.sum()) })
|
|
}
|
|
}
|
|
*/
|
|
|
|
// Derivatives of `init(repeating:)` for vector and scalar types that are their
// own tangent vectors.
extension SIMD
where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// VJP of `init(repeating:)`: the pullback returns `sum()` of the vector
  /// cotangent as the cotangent of the repeated scalar.
  @inlinable
  @derivative(of: init(repeating:))
  static func _vjpInit(repeating value: Scalar)
    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector)
  {
    let splat = Self(repeating: value)
    return (splat, { cotangent in cotangent.sum() })
  }

  /// JVP of `init(repeating:)`: the differential repeats the scalar tangent
  /// into every lane.
  @inlinable
  @derivative(of: init(repeating:))
  static func _jvpInit(repeating value: Scalar)
    -> (value: Self, differential: (Scalar.TangentVector) -> TangentVector)
  {
    let splat = Self(repeating: value)
    return (splat, { tangent in Self(repeating: tangent) })
  }
}
|