mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
PR #81766 introduced some concrete SIMD operations, but
`init(repeating:)` was temporarily disabled because it broke the
differentiation tests. See commit:
7a00619065
This PR contains two changes:
1. Define custom derivatives for the concrete `init(repeating:)` overloads
so that differentiating them no longer triggers a non-differentiability error.
2. Fix the SIL linker so that differentiability witness lookup is performed
when the original function has both `@_alwaysEmitIntoClient` and
`@_transparent`. Similar changes were introduced previously in #78908,
but they only handled `@_alwaysEmitIntoClient` without `@_transparent`.
<!--
If this pull request is targeting a release branch, please fill out the
following form:
https://github.com/swiftlang/.github/blob/main/PULL_REQUEST_TEMPLATE/release.md?plain=1
Otherwise, replace this comment with a description of your changes and
rationale. Provide links to external references/discussions if
appropriate.
If this pull request resolves any GitHub issues, link them like so:
Resolves <link to issue>, resolves <link to another issue>.
For more information about linking a pull request to an issue, see:
https://docs.github.com/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue
-->
<!--
Before merging this pull request, you must run the Swift continuous
integration tests.
For information about triggering CI builds via @swift-ci, see:
https://github.com/apple/swift/blob/main/docs/ContinuousIntegration.md#swift-ci
Thank you for your contribution to Swift!
-->
482 lines
11 KiB
Swift
482 lines
11 KiB
Swift
//===--- SIMDDifferentiation.swift.gyb ------------------------*- swift -*-===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
import Swift
|
|
|
|
%{
# Lane counts 2, 4, 8, 16, 32, 64, produced as successive powers of two.
storagescalarCounts = [1 << exponent for exponent in range(1, 7)]
# The vector lane counts are the storage counts plus the 3-lane case.
vectorscalarCounts = storagescalarCounts + [3]
}%
|
|
|
|
%for n in vectorscalarCounts:
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Protocol conformances
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Retroactively declares `AdditiveArithmetic` for SIMD${n} whose scalar is
// floating-point. The extension body is empty, so the protocol requirements
// are presumably satisfied by members SIMD already provides (TODO confirm).
extension SIMD${n}: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {}
|
|
|
|
// `Differentiable` conformance for SIMD${n} with differentiable binary
// floating-point scalars whose tangent is also binary floating-point.
extension SIMD${n}: Differentiable
where Scalar: Differentiable & BinaryFloatingPoint, Scalar.TangentVector: BinaryFloatingPoint {
  /// The tangent vector of a `SIMD${n}` is declared to be `SIMD${n}` itself.
  public typealias TangentVector = SIMD${n}
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Derivatives
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Derivatives of the element subscript (getter and setter) for SIMD${n}.
extension SIMD${n}
where Scalar: Differentiable & BinaryFloatingPoint, Scalar.TangentVector == Scalar {
  // NOTE(TF-1094): serialized `@derivative` for `.swiftinterface` compilation.
  /// VJP of the element getter: the pullback writes the incoming scalar
  /// cotangent into lane `index` of an otherwise zero vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _vjpSubscript(_ index: Int)
    -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector)
  {
    return (
      value: self[index],
      pullback: { seed in
        var cotangent = Self.zero
        cotangent[index] = seed
        return cotangent
      }
    )
  }

  /// JVP of the element getter: the differential reads lane `index` of the
  /// incoming tangent vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _jvpSubscript(index: Int)
    -> (value: Scalar, differential: (TangentVector) -> Scalar.TangentVector)
  {
    return (
      value: self[index],
      differential: { tangent in .init(tangent[index]) }
    )
  }

  /// VJP of the element setter: performs the store, then returns a pullback
  /// that extracts lane `index` of the cotangent as the derivative with respect
  /// to `newValue` and clears that lane in place.
  @inlinable
  @derivative(of: subscript(_:).set)
  internal mutating func _vjpSubscriptSetter(_ newValue: Scalar, _ index: Int)
    -> (value: Void, pullback: (inout TangentVector) -> Scalar.TangentVector)
  {
    self[index] = newValue
    return (
      value: (),
      pullback: { selfCotangent in
        let newValueCotangent = selfCotangent[index]
        // The overwritten lane no longer depends on the previous `self`.
        selfCotangent[index] = 0
        return newValueCotangent
      }
    )
  }
}
|
|
|
|
%end
|
|
|
|
// Derivatives of vector-vector `+`, binary `-`, and unary `-`.
extension SIMD
where
  Self: Differentiable, TangentVector: SIMD,
  Scalar: BinaryFloatingPoint, TangentVector.Scalar: BinaryFloatingPoint
{
  /// VJP of `lhs + rhs`: the cotangent is passed unchanged to both operands.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector))
  {
    return (value: lhs + rhs, pullback: { cotangent in (cotangent, cotangent) })
  }

  /// JVP of `lhs + rhs`: the differential adds the two operand tangents.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Self)
    -> (value: Self, differential: (TangentVector, TangentVector) -> TangentVector)
  {
    return (
      value: lhs + rhs,
      differential: { lhsTangent, rhsTangent in lhsTangent + rhsTangent }
    )
  }

  /// VJP of `lhs - rhs`: `lhs` receives the cotangent, `rhs` its negation.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector))
  {
    return (value: lhs - rhs, pullback: { cotangent in (cotangent, -cotangent) })
  }

  /// JVP of `lhs - rhs`: the differential subtracts the `rhs` tangent from the
  /// `lhs` tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Self)
    -> (value: Self, differential: (TangentVector, TangentVector) -> TangentVector)
  {
    return (
      value: lhs - rhs,
      differential: { lhsTangent, rhsTangent in lhsTangent - rhsTangent }
    )
  }

  /// VJP of unary `-rhs`: the pullback negates the cotangent.
  @inlinable
  @derivative(of: -)
  static func _vjpNegate(rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector))
  {
    return (value: -rhs, pullback: { cotangent in -cotangent })
  }

  /// JVP of unary `-rhs`: the differential negates the tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpNegate(rhs: Self)
    -> (value: Self, differential: (TangentVector) -> (TangentVector))
  {
    return (value: -rhs, differential: { tangent in -tangent })
  }
}
|
|
|
|
// Derivatives of vector-vector `*` and `/` for SIMD types that are their own
// tangent vector.
extension SIMD
where Self: Differentiable, Scalar: BinaryFloatingPoint, Self.TangentVector == Self {
  /// VJP of `lhs * rhs`: each operand's cotangent is the incoming cotangent
  /// multiplied by the other operand.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector))
  {
    return (
      value: lhs * rhs,
      pullback: { cotangent in (cotangent * rhs, cotangent * lhs) }
    )
  }

  /// JVP of `lhs * rhs`: product rule, `lhs * d(rhs) + d(lhs) * rhs`.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Self)
    -> (value: Self, differential: (TangentVector, TangentVector) -> TangentVector)
  {
    return (
      value: lhs * rhs,
      differential: { lhsTangent, rhsTangent in lhs * rhsTangent + lhsTangent * rhs }
    )
  }

  /// VJP of `lhs / rhs`: `lhs` receives `v / rhs`, and `rhs` receives
  /// `-lhs / (rhs * rhs) * v`, where `v` is the incoming cotangent.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector))
  {
    return (
      value: lhs / rhs,
      pullback: { cotangent in
        return (cotangent / rhs, -lhs / (rhs * rhs) * cotangent)
      }
    )
  }

  /// JVP of `lhs / rhs`: quotient rule,
  /// `(d(lhs) * rhs - lhs * d(rhs)) / (rhs * rhs)`.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Self)
    -> (value: Self, differential: (TangentVector, TangentVector) -> TangentVector)
  {
    return (
      value: lhs / rhs,
      differential: { lhsTangent, rhsTangent in
        return (lhsTangent * rhs - lhs * rhsTangent) / (rhs * rhs)
      }
    )
  }
}
|
|
|
|
// Derivatives of mixed scalar/vector `+` and `-`. A scalar operand takes part
// in every lane, so its cotangent is obtained by summing the vector cotangent.
extension SIMD
where
  Self: Differentiable, TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector.Scalar == Scalar.TangentVector
{
  /// VJP of `scalar + vector`: the scalar receives the lane sum of the
  /// cotangent; the vector receives the cotangent unchanged.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Scalar, rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector))
  {
    return (value: lhs + rhs, pullback: { cotangent in (cotangent.sum(), cotangent) })
  }

  /// JVP of `scalar + vector`: the differential adds the scalar tangent to the
  /// vector tangent.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Scalar, rhs: Self)
    -> (value: Self, differential: (Scalar.TangentVector, TangentVector) -> TangentVector)
  {
    return (
      value: lhs + rhs,
      differential: { scalarTangent, vectorTangent in scalarTangent + vectorTangent }
    )
  }

  /// VJP of `scalar - vector`: the scalar receives the lane sum of the
  /// cotangent; the vector receives the negated cotangent.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Scalar, rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector))
  {
    return (value: lhs - rhs, pullback: { cotangent in (cotangent.sum(), -cotangent) })
  }

  /// JVP of `scalar - vector`: the differential subtracts the vector tangent
  /// from the scalar tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Scalar, rhs: Self)
    -> (value: Self, differential: (Scalar.TangentVector, TangentVector) -> TangentVector)
  {
    return (
      value: lhs - rhs,
      differential: { scalarTangent, vectorTangent in scalarTangent - vectorTangent }
    )
  }

  /// VJP of `vector + scalar`: the vector receives the cotangent unchanged;
  /// the scalar receives the lane sum of the cotangent.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Scalar)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector))
  {
    return (value: lhs + rhs, pullback: { cotangent in (cotangent, cotangent.sum()) })
  }

  /// JVP of `vector + scalar`: the differential adds the scalar tangent to the
  /// vector tangent.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Scalar)
    -> (value: Self, differential: (TangentVector, Scalar.TangentVector) -> TangentVector)
  {
    return (
      value: lhs + rhs,
      differential: { vectorTangent, scalarTangent in vectorTangent + scalarTangent }
    )
  }

  /// VJP of `vector - scalar`: the vector receives the cotangent unchanged;
  /// the scalar receives the negated lane sum of the cotangent.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Scalar)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector))
  {
    return (value: lhs - rhs, pullback: { cotangent in (cotangent, -cotangent.sum()) })
  }

  /// JVP of `vector - scalar`: the differential subtracts the scalar tangent
  /// from the vector tangent.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Scalar)
    -> (value: Self, differential: (TangentVector, Scalar.TangentVector) -> TangentVector)
  {
    return (
      value: lhs - rhs,
      differential: { vectorTangent, scalarTangent in vectorTangent - scalarTangent }
    )
  }
}
|
|
|
|
// Derivatives of mixed scalar/vector `*` and `/` for SIMD types whose vector
// and scalar are each their own tangent. Scalar cotangents are lane sums.
extension SIMD
where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self, Scalar.TangentVector == Scalar
{
  /// VJP of `vector * scalar`: the vector receives `v * rhs`; the scalar
  /// receives the lane sum of `v * lhs`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Scalar)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector))
  {
    return (
      value: lhs * rhs,
      pullback: { cotangent in (cotangent * rhs, (cotangent * lhs).sum()) }
    )
  }

  /// JVP of `vector * scalar`: product rule, `lhs * d(rhs) + d(lhs) * rhs`.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Scalar)
    -> (value: Self, differential: (TangentVector, Scalar.TangentVector) -> TangentVector)
  {
    return (
      value: lhs * rhs,
      differential: { lhsTangent, rhsTangent in lhs * rhsTangent + lhsTangent * rhs }
    )
  }

  /// VJP of `vector / scalar`: the vector receives `v / rhs`; the scalar
  /// receives the lane sum of `-lhs / (rhs * rhs) * v`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Scalar)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector))
  {
    return (
      value: lhs / rhs,
      pullback: { cotangent in
        return (cotangent / rhs, (-lhs / (rhs * rhs) * cotangent).sum())
      }
    )
  }

  /// JVP of `vector / scalar`: quotient rule,
  /// `(d(lhs) * rhs - lhs * d(rhs)) / (rhs * rhs)`.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Scalar)
    -> (value: Self, differential: (TangentVector, Scalar.TangentVector) -> TangentVector)
  {
    return (
      value: lhs / rhs,
      differential: { lhsTangent, rhsTangent in
        return (lhsTangent * rhs - lhs * rhsTangent) / (rhs * rhs)
      }
    )
  }

  /// VJP of `scalar * vector`: the scalar receives the lane sum of `v * rhs`;
  /// the vector receives `v * lhs`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Scalar, rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector))
  {
    return (
      value: lhs * rhs,
      pullback: { cotangent in ((cotangent * rhs).sum(), cotangent * lhs) }
    )
  }

  /// JVP of `scalar * vector`: product rule, `lhs * d(rhs) + d(lhs) * rhs`.
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Scalar, rhs: Self)
    -> (value: Self, differential: (Scalar.TangentVector, TangentVector) -> TangentVector)
  {
    return (
      value: lhs * rhs,
      differential: { lhsTangent, rhsTangent in lhs * rhsTangent + lhsTangent * rhs }
    )
  }

  /// VJP of `scalar / vector`: the scalar receives the lane sum of `v / rhs`;
  /// the vector receives `-lhs / (rhs * rhs) * v`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Scalar, rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector))
  {
    return (
      value: lhs / rhs,
      pullback: { cotangent in
        return ((cotangent / rhs).sum(), -lhs / (rhs * rhs) * cotangent)
      }
    )
  }

  /// JVP of `scalar / vector`: quotient rule,
  /// `(d(lhs) * rhs - lhs * d(rhs)) / (rhs * rhs)`.
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Scalar, rhs: Self)
    -> (value: Self, differential: (Scalar.TangentVector, TangentVector) -> TangentVector)
  {
    return (
      value: lhs / rhs,
      differential: { lhsTangent, rhsTangent in
        return (lhsTangent * rhs - lhs * rhsTangent) / (rhs * rhs)
      }
    )
  }
}
|
|
|
|
// Derivatives of the lane reduction `sum()`.
extension SIMD
where
  Self: Differentiable, TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector == Self
{
  /// VJP of `sum()`: the pullback converts the scalar cotangent to `Scalar`
  /// and repeats it across every lane.
  @inlinable
  @_alwaysEmitIntoClient
  @derivative(of: sum)
  func _vjpSum()
    -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector)
  {
    return (
      value: sum(),
      pullback: { cotangent in Self(repeating: Scalar(cotangent)) }
    )
  }

  /// JVP of `sum()`: the differential sums the lanes of the tangent vector and
  /// converts the result to `Scalar.TangentVector`.
  @inlinable
  @_alwaysEmitIntoClient
  @derivative(of: sum)
  func _jvpSum()
    -> (value: Scalar, differential: (TangentVector) -> Scalar.TangentVector)
  {
    return (
      value: sum(),
      differential: { tangent in Scalar.TangentVector(tangent.sum()) }
    )
  }
}
|
|
|
|
// Derivatives of the generic splat initializer `init(repeating:)`.
extension SIMD
where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self, Scalar.TangentVector == Scalar
{
  /// VJP of `init(repeating:)`: the repeated scalar feeds every lane, so the
  /// pullback returns the lane sum of the cotangent.
  @inlinable
  @derivative(of: init(repeating:))
  static func _vjpInit(repeating value: Scalar)
    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector)
  {
    return (
      value: Self(repeating: value),
      pullback: { cotangent in cotangent.sum() }
    )
  }

  /// JVP of `init(repeating:)`: the differential repeats the scalar tangent
  /// across every lane.
  @inlinable
  @derivative(of: init(repeating:))
  static func _jvpInit(repeating value: Scalar)
    -> (value: Self, differential: (Scalar.TangentVector) -> TangentVector)
  {
    return (
      value: Self(repeating: value),
      differential: { tangent in Self(repeating: tangent) }
    )
  }
}
|
|
|
|
%for (Scalar, bits) in [('Float',32), ('Double',64)]:
|
|
% for n in vectorscalarCounts:
|
|
// Derivatives of `init(repeating:)` registered on SIMD${n} with the scalar
// fixed to ${Scalar}.
extension SIMD${n} where Scalar == ${Scalar} {
  /// VJP of `init(repeating:)`: the pullback returns the lane sum of the
  /// cotangent.
  @inlinable
  @_alwaysEmitIntoClient
  @derivative(of: init(repeating:))
  static func _vjpInit(repeating value: Scalar)
    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector)
  {
    return (
      value: Self(repeating: value),
      pullback: { cotangent in cotangent.sum() }
    )
  }

  /// JVP of `init(repeating:)`: the differential repeats the scalar tangent
  /// across every lane.
  @inlinable
  @_alwaysEmitIntoClient
  @derivative(of: init(repeating:))
  static func _jvpInit(repeating value: Scalar)
    -> (value: Self, differential: (Scalar.TangentVector) -> TangentVector)
  {
    return (
      value: Self(repeating: value),
      differential: { tangent in Self(repeating: tangent) }
    )
  }
}
|
|
% end
|
|
%end
|