mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Consider an `@_alwaysEmitIntoClient` function and a custom derivative defined for it. Previously, such a combination resulted in different errors under different circumstances. Sometimes there were linker errors due to a missing derivative function symbol: these occurred when we tried to find the derivative in a module, although it should have been emitted into the client's code (which did not happen). Sometimes there were SIL verification failures like this: ``` SIL verification failed: internal/private function cannot be serialized or serializable: !F->isAnySerialized() || embedded ``` Linkage and serialization options for the derivative were not handled properly: instead of PublicNonABI linkage, the derivative got Private linkage, which is unsupported for serialization — but `@_alwaysEmitIntoClient` functions must be serialized so that the client's code can see them. This patch resolves the issue and adds proper handling of custom derivatives of `@_alwaysEmitIntoClient` functions. Note that either both the function and its custom derivative, or neither of them, should have the `@_alwaysEmitIntoClient` attribute; a mismatch in this attribute is not supported. The following cases are handled (assume that in each case the client's code uses the derivative). 1. Both the function and its derivative are defined in a single file in one module. 2. Both the function and its derivative are defined in different files that are compiled into a single module. 3. The function is defined in one module, and its derivative is defined in another module. 4. The function and the derivative are defined as members of a protocol extension in two separate modules - one for the function and one for the derivative. A struct conforming to the protocol is defined in a third module. 5. The function and the derivative are defined as members of a struct extension in two separate modules - one for the function and one for the derivative. These changes make it possible to define derivatives for methods of `SIMD`.
Fixes #54445 <!-- If this pull request is targeting a release branch, please fill out the following form: https://github.com/swiftlang/.github/blob/main/PULL_REQUEST_TEMPLATE/release.md?plain=1 Otherwise, replace this comment with a description of your changes and rationale. Provide links to external references/discussions if appropriate. If this pull request resolves any GitHub issues, link them like so: Resolves <link to issue>, resolves <link to another issue>. For more information about linking a pull request to an issue, see: https://docs.github.com/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue --> <!-- Before merging this pull request, you must run the Swift continuous integration tests. For information about triggering CI builds via @swift-ci, see: https://github.com/apple/swift/blob/main/docs/ContinuousIntegration.md#swift-ci Thank you for your contribution to Swift! -->
458 lines
11 KiB
Swift
458 lines
11 KiB
Swift
//===--- SIMDDifferentiation.swift.gyb ------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
|
|
|
|
import Swift
|
|
|
|
%{
# Power-of-two lane counts: 2, 4, 8, 16, 32, 64 (2**1 through 2**6).
storagescalarCounts = [1 << k for k in range(1, 7)]
# Every lane count the `%for` loop below expands over: the power-of-two
# counts in ascending order, followed by 3.
vectorscalarCounts = storagescalarCounts + [3]
}%
|
|
|
|
%for n in vectorscalarCounts:
|
|
|
|
//===----------------------------------------------------------------------===//
// Protocol conformances
//===----------------------------------------------------------------------===//

// The conformance body is empty: the `AdditiveArithmetic` requirements are
// expected to be satisfied by members `SIMD${n}` already has when its scalar
// is floating-point. `@retroactive` acknowledges that neither the type nor the
// protocol is declared in this module.
extension SIMD${n}: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {}

// A `SIMD${n}` of differentiable floating-point scalars is differentiable,
// and it serves as its own tangent vector type.
extension SIMD${n}: Differentiable
  where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector: BinaryFloatingPoint
{
  public typealias TangentVector = SIMD${n}
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Derivatives
//===----------------------------------------------------------------------===//

extension SIMD${n}
  where
  Scalar: Differentiable & BinaryFloatingPoint,
  Scalar.TangentVector == Scalar
{
  // NOTE(TF-1094): serialized `@derivative` for `.swiftinterface` compilation.

  /// Reverse-mode derivative of the element getter `self[index]`.
  ///
  /// The pullback writes the scalar cotangent into lane `index` of an
  /// otherwise-zero vector; every other lane receives zero.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _vjpSubscript(_ index: Int)
    -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector)
  {
    func pb(_ seed: Scalar.TangentVector) -> TangentVector {
      var scattered = Self.zero
      scattered[index] = seed
      return scattered
    }
    return (self[index], pb)
  }

  /// Forward-mode derivative of the element getter `self[index]`.
  ///
  /// The differential reads lane `index` of the incoming tangent vector.
  @inlinable
  @derivative(of: subscript(_:))
  internal func _jvpSubscript(index: Int)
    -> (value: Scalar, differential: (TangentVector) -> Scalar.TangentVector)
  {
    func df(_ tangent: TangentVector) -> Scalar.TangentVector {
      Scalar.TangentVector(tangent[index])
    }
    return (self[index], df)
  }

  /// Reverse-mode derivative of the element setter `self[index] = newValue`.
  ///
  /// Performs the assignment, then returns a pullback that hands the cotangent
  /// of lane `index` to `newValue` and zeroes that lane of `dSelf`, because the
  /// previous element was overwritten and no longer affects the result.
  @inlinable
  @derivative(of: subscript(_:).set)
  internal mutating func _vjpSubscriptSetter(_ newValue: Scalar, _ index: Int)
    -> (value: Void, pullback: (inout TangentVector) -> Scalar.TangentVector)
  {
    self[index] = newValue
    func pb(_ dSelf: inout TangentVector) -> Scalar.TangentVector {
      // The lane is read for the return value before `defer` clears it.
      defer { dSelf[index] = 0 }
      return dSelf[index]
    }
    return ((), pb)
  }
}
|
|
|
|
%end
|
|
|
|
// Derivatives of element-wise `+`, binary `-` and prefix `-` between SIMD
// vectors whose tangent vectors are themselves SIMD types.
extension SIMD
  where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint,
  TangentVector.Scalar: BinaryFloatingPoint
{
  /// Reverse-mode derivative of `lhs + rhs`: both operands receive the
  /// incoming cotangent unchanged.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    func pb(_ seed: TangentVector) -> (TangentVector, TangentVector) {
      (seed, seed)
    }
    return (lhs + rhs, pb)
  }

  /// Forward-mode derivative of `lhs + rhs`: the operand tangents are summed.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    func df(_ dLhs: TangentVector, _ dRhs: TangentVector) -> TangentVector {
      dLhs + dRhs
    }
    return (lhs + rhs, df)
  }

  /// Reverse-mode derivative of `lhs - rhs`: `lhs` receives the cotangent and
  /// `rhs` receives its negation.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    func pb(_ seed: TangentVector) -> (TangentVector, TangentVector) {
      (seed, -seed)
    }
    return (lhs - rhs, pb)
  }

  /// Forward-mode derivative of `lhs - rhs`: the difference of the operand
  /// tangents.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    func df(_ dLhs: TangentVector, _ dRhs: TangentVector) -> TangentVector {
      dLhs - dRhs
    }
    return (lhs - rhs, df)
  }

  /// Reverse-mode derivative of prefix `-rhs`: the cotangent is negated.
  @inlinable
  @derivative(of: -)
  static func _vjpNegate(rhs: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector))
  {
    func pb(_ seed: TangentVector) -> TangentVector {
      -seed
    }
    return (-rhs, pb)
  }

  /// Forward-mode derivative of prefix `-rhs`: the tangent is negated.
  @inlinable
  @derivative(of: -)
  static func _jvpNegate(rhs: Self)
    -> (value: Self, differential: (TangentVector) -> (TangentVector))
  {
    func df(_ tangent: TangentVector) -> TangentVector {
      -tangent
    }
    return (-rhs, df)
  }
}
|
|
|
|
// Derivatives of element-wise `*` and `/` between two SIMD vectors that are
// their own tangent vectors.
extension SIMD
  where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint,
  Self.TangentVector == Self
{
  /// Reverse-mode derivative of `lhs * rhs` (product rule): `lhs` receives
  /// `v * rhs` and `rhs` receives `v * lhs`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    func pb(_ v: TangentVector) -> (TangentVector, TangentVector) {
      (v * rhs, v * lhs)
    }
    return (lhs * rhs, pb)
  }

  /// Forward-mode derivative of `lhs * rhs` (product rule).
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    func df(_ dLhs: TangentVector, _ dRhs: TangentVector) -> TangentVector {
      lhs * dRhs + dLhs * rhs
    }
    return (lhs * rhs, df)
  }

  /// Reverse-mode derivative of `lhs / rhs` (quotient rule): `lhs` receives
  /// `v / rhs` and `rhs` receives `-lhs / rhs² * v`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Self)
    -> (
      value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
    )
  {
    func pb(_ v: TangentVector) -> (TangentVector, TangentVector) {
      (v / rhs, -lhs / (rhs * rhs) * v)
    }
    return (lhs / rhs, pb)
  }

  /// Forward-mode derivative of `lhs / rhs` (quotient rule).
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Self)
    -> (
      value: Self, differential: (TangentVector, TangentVector) -> TangentVector
    )
  {
    func df(_ dLhs: TangentVector, _ dRhs: TangentVector) -> TangentVector {
      (dLhs * rhs - lhs * dRhs) / (rhs * rhs)
    }
    return (lhs / rhs, df)
  }
}
|
|
|
|
// Derivatives of `+` and `-` where one operand is a scalar that is broadcast
// across every lane of the other (vector) operand.
extension SIMD
  where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector.Scalar == Scalar.TangentVector
{
  /// Reverse-mode derivative of `scalar + vector`: the broadcast scalar
  /// accumulates the sum of all lane cotangents.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    func pb(
      _ seed: TangentVector
    ) -> (Scalar.TangentVector, TangentVector) {
      (seed.sum(), seed)
    }
    return (lhs + rhs, pb)
  }

  /// Forward-mode derivative of `scalar + vector`.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    func df(
      _ dLhs: Scalar.TangentVector, _ dRhs: TangentVector
    ) -> TangentVector {
      dLhs + dRhs
    }
    return (lhs + rhs, df)
  }

  /// Reverse-mode derivative of `scalar - vector`: the scalar receives the
  /// lane-sum of the cotangent; the vector receives its negation.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    func pb(
      _ seed: TangentVector
    ) -> (Scalar.TangentVector, TangentVector) {
      (seed.sum(), -seed)
    }
    return (lhs - rhs, pb)
  }

  /// Forward-mode derivative of `scalar - vector`.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    func df(
      _ dLhs: Scalar.TangentVector, _ dRhs: TangentVector
    ) -> TangentVector {
      dLhs - dRhs
    }
    return (lhs - rhs, df)
  }

  /// Reverse-mode derivative of `vector + scalar`: the broadcast scalar
  /// accumulates the sum of all lane cotangents.
  @inlinable
  @derivative(of: +)
  static func _vjpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    func pb(
      _ seed: TangentVector
    ) -> (TangentVector, Scalar.TangentVector) {
      (seed, seed.sum())
    }
    return (lhs + rhs, pb)
  }

  /// Forward-mode derivative of `vector + scalar`.
  @inlinable
  @derivative(of: +)
  static func _jvpAdd(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    func df(
      _ dLhs: TangentVector, _ dRhs: Scalar.TangentVector
    ) -> TangentVector {
      dLhs + dRhs
    }
    return (lhs + rhs, df)
  }

  /// Reverse-mode derivative of `vector - scalar`: the vector receives the
  /// cotangent; the scalar receives the negated lane-sum.
  @inlinable
  @derivative(of: -)
  static func _vjpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    func pb(
      _ seed: TangentVector
    ) -> (TangentVector, Scalar.TangentVector) {
      (seed, -seed.sum())
    }
    return (lhs - rhs, pb)
  }

  /// Forward-mode derivative of `vector - scalar`.
  @inlinable
  @derivative(of: -)
  static func _jvpSubtract(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    func df(
      _ dLhs: TangentVector, _ dRhs: Scalar.TangentVector
    ) -> TangentVector {
      dLhs - dRhs
    }
    return (lhs - rhs, df)
  }
}
|
|
|
|
// Derivatives of `*` and `/` where one operand is a scalar broadcast across
// every lane of the other (vector) operand. Both the vector type and the
// scalar type are their own tangent vectors here.
extension SIMD
  where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// Reverse-mode derivative of `vector * scalar`: the scalar's cotangent is
  /// the lane-sum of `v * lhs`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    func pb(
      _ v: TangentVector
    ) -> (TangentVector, Scalar.TangentVector) {
      (v * rhs, (v * lhs).sum())
    }
    return (lhs * rhs, pb)
  }

  /// Forward-mode derivative of `vector * scalar` (product rule).
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    func df(
      _ dLhs: TangentVector, _ dRhs: Scalar.TangentVector
    ) -> TangentVector {
      lhs * dRhs + dLhs * rhs
    }
    return (lhs * rhs, df)
  }

  /// Reverse-mode derivative of `vector / scalar` (quotient rule): the
  /// scalar's cotangent is the lane-sum of `-lhs / rhs² * v`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)
  ) {
    func pb(
      _ v: TangentVector
    ) -> (TangentVector, Scalar.TangentVector) {
      (v / rhs, (-lhs / (rhs * rhs) * v).sum())
    }
    return (lhs / rhs, pb)
  }

  /// Forward-mode derivative of `vector / scalar` (quotient rule).
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Self, rhs: Scalar) -> (
    value: Self,
    differential: (TangentVector, Scalar.TangentVector) -> TangentVector
  ) {
    func df(
      _ dLhs: TangentVector, _ dRhs: Scalar.TangentVector
    ) -> TangentVector {
      (dLhs * rhs - lhs * dRhs) / (rhs * rhs)
    }
    return (lhs / rhs, df)
  }

  /// Reverse-mode derivative of `scalar * vector`: the scalar's cotangent is
  /// the lane-sum of `v * rhs`.
  @inlinable
  @derivative(of: *)
  static func _vjpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    func pb(
      _ v: TangentVector
    ) -> (Scalar.TangentVector, TangentVector) {
      ((v * rhs).sum(), v * lhs)
    }
    return (lhs * rhs, pb)
  }

  /// Forward-mode derivative of `scalar * vector` (product rule).
  @inlinable
  @derivative(of: *)
  static func _jvpMultiply(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    func df(
      _ dLhs: Scalar.TangentVector, _ dRhs: TangentVector
    ) -> TangentVector {
      lhs * dRhs + dLhs * rhs
    }
    return (lhs * rhs, df)
  }

  /// Reverse-mode derivative of `scalar / vector` (quotient rule): the
  /// scalar's cotangent is the lane-sum of `v / rhs`.
  @inlinable
  @derivative(of: /)
  static func _vjpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)
  ) {
    func pb(
      _ v: TangentVector
    ) -> (Scalar.TangentVector, TangentVector) {
      ((v / rhs).sum(), -lhs / (rhs * rhs) * v)
    }
    return (lhs / rhs, pb)
  }

  /// Forward-mode derivative of `scalar / vector` (quotient rule).
  @inlinable
  @derivative(of: /)
  static func _jvpDivide(lhs: Scalar, rhs: Self) -> (
    value: Self,
    differential: (Scalar.TangentVector, TangentVector) -> TangentVector
  ) {
    func df(
      _ dLhs: Scalar.TangentVector, _ dRhs: TangentVector
    ) -> TangentVector {
      (dLhs * rhs - lhs * dRhs) / (rhs * rhs)
    }
    return (lhs / rhs, df)
  }
}
|
|
|
|
// Derivatives of `sum()`. Like the function they differentiate, these are
// `@_alwaysEmitIntoClient`, so the derivative is emitted into client code
// rather than looked up in this module.
extension SIMD
  where
  Self: Differentiable,
  TangentVector: SIMD,
  Scalar: BinaryFloatingPoint & Differentiable,
  Scalar.TangentVector: BinaryFloatingPoint,
  TangentVector == Self
{
  /// Reverse-mode derivative of `sum()`: the scalar cotangent is converted
  /// to `Scalar` and broadcast to every lane.
  @inlinable
  @_alwaysEmitIntoClient
  @derivative(of: sum)
  func _vjpSum() -> (
    value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector
  ) {
    let total = sum()
    let pb: (Scalar.TangentVector) -> TangentVector = { seed in
      Self(repeating: Scalar(seed))
    }
    return (total, pb)
  }

  /// Forward-mode derivative of `sum()`: the lane-sum of the tangent vector,
  /// converted to `Scalar.TangentVector`.
  @inlinable
  @_alwaysEmitIntoClient
  @derivative(of: sum)
  func _jvpSum() -> (
    value: Scalar, differential: (TangentVector) -> Scalar.TangentVector
  ) {
    let total = sum()
    let df: (TangentVector) -> Scalar.TangentVector = { tangent in
      Scalar.TangentVector(tangent.sum())
    }
    return (total, df)
  }
}
|
|
|
|
// Derivatives of the broadcasting initializer `init(repeating:)`.
extension SIMD
  where
  Self: Differentiable,
  Scalar: BinaryFloatingPoint & Differentiable,
  Self.TangentVector == Self,
  Scalar.TangentVector == Scalar
{
  /// Reverse-mode derivative of `init(repeating:)`: the repeated scalar
  /// accumulates the sum of all lane cotangents.
  @inlinable
  @derivative(of: init(repeating:))
  static func _vjpInit(repeating value: Scalar)
    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector)
  {
    func pb(_ seed: TangentVector) -> Scalar.TangentVector {
      seed.sum()
    }
    return (Self(repeating: value), pb)
  }

  /// Forward-mode derivative of `init(repeating:)`: the scalar tangent is
  /// broadcast to every lane.
  @inlinable
  @derivative(of: init(repeating:))
  static func _jvpInit(repeating value: Scalar)
    -> (value: Self, differential: (Scalar.TangentVector) -> TangentVector)
  {
    func df(_ tangent: Scalar.TangentVector) -> TangentVector {
      Self(repeating: tangent)
    }
    return (Self(repeating: value), df)
  }
}
|