//===--- AnyDerivative.swift ----------------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file defines type-erased wrappers for `Differentiable`-conforming types
// and `Differentiable.TangentVector` associated type implementations.
//
//===----------------------------------------------------------------------===//

import Swift

//===----------------------------------------------------------------------===//
// `AnyDifferentiable`
//===----------------------------------------------------------------------===//

/// The type-erased storage interface used by `AnyDifferentiable`.
internal protocol _AnyDifferentiableBox {
  // `Differentiable` requirements.

  /// Moves the boxed value by the given type-erased offset.
  mutating func _move(by offset: AnyDerivative)

  /// The underlying base value, type-erased to `Any`.
  var _typeErasedBase: Any { get }

  /// Returns the underlying value unboxed to the given type, if possible.
  func _unboxed<U: Differentiable>(to type: U.Type) -> U?
}

/// A box storing a concrete `Differentiable` value of type `T`.
internal struct _ConcreteDifferentiableBox<T: Differentiable>:
  _AnyDifferentiableBox
{
  /// The underlying base value.
  var _base: T

  init(_ base: T) {
    self._base = base
  }

  /// The underlying base value, type-erased to `Any`.
  var _typeErasedBase: Any {
    return _base
  }

  /// Returns `_base` if this box stores exactly a `U`, otherwise `nil`.
  func _unboxed<U: Differentiable>(to type: U.Type) -> U? {
    return (self as? _ConcreteDifferentiableBox<U>)?._base
  }

  /// Moves `_base` by `offset`.
  ///
  /// An opaque-zero offset (such as `AnyDerivative.zero`) is the additive
  /// identity and leaves `_base` unchanged. Any other offset must wrap a
  /// `T.TangentVector`; otherwise this traps via `_derivativeTypeMismatch`.
  mutating func _move(by offset: AnyDerivative) {
    // `AnyDerivative.zero` wraps `AnyDerivative.OpaqueZero`, not
    // `T.TangentVector`, so it must be recognized before the cast below.
    // This mirrors `_ConcreteDerivativeBox._move(by:)`.
    if offset._box._isOpaqueZero() {
      return
    }
    guard let offsetBase = offset.base as? T.TangentVector else {
      _derivativeTypeMismatch(T.self, type(of: offset.base))
    }
    _base.move(by: offsetBase)
  }
}

/// A type-erased differentiable value.
///
/// `AnyDifferentiable` forwards `move(by:)` to an arbitrary underlying base
/// value conforming to `Differentiable`. Its tangent vector type is
/// `AnyDerivative`.
public struct AnyDifferentiable: Differentiable {
  internal var _box: _AnyDifferentiableBox

  internal init(_box: _AnyDifferentiableBox) {
    self._box = _box
  }

  /// The underlying base value.
  public var base: Any {
    return _box._typeErasedBase
  }

  /// Creates a type-erased differentiable value wrapping the given value.
  @differentiable(reverse)
  public init<T: Differentiable>(_ base: T) {
    self._box = _ConcreteDifferentiableBox<T>(base)
  }

  /// Reverse-mode derivative of `init(_:)`.
  ///
  /// The pullback unwraps an `AnyDerivative` cotangent to `T.TangentVector`.
  /// An opaque-zero cotangent (e.g. `AnyDerivative.zero`) maps to
  /// `T.TangentVector.zero`; any other non-`T.TangentVector` cotangent traps.
  @inlinable
  @derivative(of: init)
  internal static func _vjpInit<T: Differentiable>(
    _ base: T
  ) -> (value: AnyDifferentiable, pullback: (AnyDerivative) -> T.TangentVector)
  {
    let pullback: (AnyDerivative) -> T.TangentVector = { v in
      // The opaque zero's base is not a `T.TangentVector`; a force-cast would
      // trap on it.
      if v._box._isOpaqueZero() {
        return T.TangentVector.zero
      }
      guard let vBase = v.base as? T.TangentVector else {
        _derivativeTypeMismatch(T.TangentVector.self, type(of: v.base))
      }
      return vBase
    }
    return (AnyDifferentiable(base), pullback)
  }

  /// Forward-mode derivative of `init(_:)`: the differential wraps the
  /// incoming `T.TangentVector` in an `AnyDerivative`.
  @inlinable
  @derivative(of: init)
  internal static func _jvpInit<T: Differentiable>(
    _ base: T
  ) -> (
    value: AnyDifferentiable, differential: (T.TangentVector) -> AnyDerivative
  ) {
    return (AnyDifferentiable(base), { dbase in AnyDerivative(dbase) })
  }

  public typealias TangentVector = AnyDerivative

  /// Moves the underlying value by `offset`.
  ///
  /// `offset` must be an opaque zero (a no-op) or wrap the underlying value's
  /// `TangentVector`; otherwise this traps with a type-mismatch message.
  public mutating func move(by offset: TangentVector) {
    _box._move(by: offset)
  }
}

extension AnyDifferentiable: CustomReflectable {
  public var customMirror: Mirror {
    Mirror(reflecting: base)
  }
}

//===----------------------------------------------------------------------===//
// `AnyDerivative`
//===----------------------------------------------------------------------===//

/// The type-erased storage interface used by `AnyDerivative`.
@usableFromInline
internal protocol _AnyDerivativeBox {
  // `Equatable` requirements (implied by `AdditiveArithmetic`).
  func _isEqual(to other: _AnyDerivativeBox) -> Bool
  func _isNotEqual(to other: _AnyDerivativeBox) -> Bool

  // `AdditiveArithmetic` requirements.
  static var _zero: _AnyDerivativeBox { get }
  func _adding(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox
  func _subtracting(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox

  // `Differentiable` requirements.
  mutating func _move(by offset: _AnyDerivativeBox)

  /// The underlying base value, type-erased to `Any`.
  var _typeErasedBase: Any { get }

  /// Returns the underlying value unboxed to the given type, if possible.
  func _unboxed<U>(to type: U.Type) -> U?
  where U: Differentiable, U.TangentVector == U
}

extension _AnyDerivativeBox {
  /// Returns true if the underlying value has type `AnyDerivative.OpaqueZero`.
  @inlinable
  func _isOpaqueZero() -> Bool {
    return _unboxed(to: AnyDerivative.OpaqueZero.self) != nil
  }
}

/// A box storing a concrete derivative value of type `T`, where `T` is its own
/// tangent vector type.
@frozen
@usableFromInline
internal struct _ConcreteDerivativeBox<T>: _AnyDerivativeBox
where T: Differentiable, T.TangentVector == T {
  /// The underlying base value.
  @usableFromInline
  var _base: T

  @inlinable
  internal init(_ base: T) {
    self._base = base
  }

  /// The underlying base value, type-erased to `Any`.
  @inlinable
  var _typeErasedBase: Any {
    return _base
  }

  /// Returns `_base` if this box stores exactly a `U`, otherwise `nil`.
  @inlinable
  func _unboxed<U>(to type: U.Type) -> U?
  where U: Differentiable, U.TangentVector == U {
    return (self as? _ConcreteDerivativeBox<U>)?._base
  }

  // `Equatable` requirements (implied by `AdditiveArithmetic`).
  // NOTE: boxes of different dynamic types compare unequal, so an opaque zero
  // is not equal to a concretely typed zero (e.g. `Float(0)`).

  @inlinable
  func _isEqual(to other: _AnyDerivativeBox) -> Bool {
    return _base == other._unboxed(to: T.self)
  }

  @inlinable
  func _isNotEqual(to other: _AnyDerivativeBox) -> Bool {
    return _base != other._unboxed(to: T.self)
  }

  // `AdditiveArithmetic` requirements.

  @inlinable
  static var _zero: _AnyDerivativeBox {
    return _ConcreteDerivativeBox(T.zero)
  }

  /// Adds `x` to `_base`, treating an opaque zero on either side as the
  /// additive identity. Traps if `x` stores a type other than `T`.
  @inlinable
  func _adding(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox {
    // 0 + x = x
    if _isOpaqueZero() {
      return x
    }
    // y + 0 = y
    if x._isOpaqueZero() {
      return self
    }
    guard let xBase = x._unboxed(to: T.self) else {
      _derivativeTypeMismatch(T.self, type(of: x._typeErasedBase))
    }
    return _ConcreteDerivativeBox(_base + xBase)
  }

  /// Subtracts `x` from `_base`, treating an opaque zero on either side as the
  /// additive identity. Traps if `x` stores a type other than `T`.
  @inlinable
  func _subtracting(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox {
    // y - 0 = y
    if x._isOpaqueZero() {
      return self
    }
    // 0 - x = -x, computed as a concrete zero of x's dynamic type minus x.
    if _isOpaqueZero() {
      return type(of: x)._zero._subtracting(x)
    }
    guard let xBase = x._unboxed(to: T.self) else {
      _derivativeTypeMismatch(T.self, type(of: x._typeErasedBase))
    }
    return _ConcreteDerivativeBox(_base - xBase)
  }

  // `Differentiable` requirements.

  /// Moves `_base` by `offset`; an opaque-zero offset is a no-op.
  @inlinable
  mutating func _move(by offset: _AnyDerivativeBox) {
    if offset._isOpaqueZero() {
      return
    }
    // The case where `self._isOpaqueZero()` returns true is handled in
    // `AnyDerivative.move(by:)`.
    guard let offsetBase = offset._unboxed(to: T.TangentVector.self) else {
      _derivativeTypeMismatch(T.self, type(of: offset._typeErasedBase))
    }
    _base.move(by: offsetBase)
  }
}

/// A type-erased derivative value.
///
/// The `AnyDerivative` type forwards its operations to an arbitrary underlying
/// base derivative value conforming to `Differentiable` and
/// `AdditiveArithmetic`, hiding the specifics of the underlying value.
@frozen
public struct AnyDerivative: Differentiable & AdditiveArithmetic {
  @usableFromInline
  internal var _box: _AnyDerivativeBox

  @inlinable
  internal init(_box: _AnyDerivativeBox) {
    self._box = _box
  }

  /// The underlying base value.
  @inlinable
  public var base: Any {
    return _box._typeErasedBase
  }

  /// Creates a type-erased derivative from the given derivative.
  @inlinable
  @differentiable(reverse)
  public init<T>(_ base: T) where T: Differentiable, T.TangentVector == T {
    self._box = _ConcreteDerivativeBox<T>(base)
  }

  /// Reverse-mode derivative of `init(_:)`.
  ///
  /// The pullback unwraps an `AnyDerivative` cotangent to `T`. An opaque-zero
  /// cotangent maps to `T.TangentVector.zero`; any other non-`T` cotangent
  /// traps.
  @inlinable
  @derivative(of: init)
  internal static func _vjpInit<T>(
    _ base: T
  ) -> (value: AnyDerivative, pullback: (AnyDerivative) -> T.TangentVector)
  where T: Differentiable, T.TangentVector == T {
    let pullback: (AnyDerivative) -> T.TangentVector = { v in
      // The opaque zero's base is not a `T`; a force-cast would trap on it.
      if v._box._isOpaqueZero() {
        return T.TangentVector.zero
      }
      guard let vBase = v.base as? T.TangentVector else {
        _derivativeTypeMismatch(T.TangentVector.self, type(of: v.base))
      }
      return vBase
    }
    return (AnyDerivative(base), pullback)
  }

  /// Forward-mode derivative of `init(_:)`: the differential re-wraps the
  /// incoming tangent in an `AnyDerivative`.
  @inlinable
  @derivative(of: init)
  internal static func _jvpInit<T>(
    _ base: T
  ) -> (value: AnyDerivative, differential: (T.TangentVector) -> AnyDerivative)
  where T: Differentiable, T.TangentVector == T {
    return (AnyDerivative(base), { dbase in AnyDerivative(dbase) })
  }

  public typealias TangentVector = AnyDerivative

  // `Equatable` requirements (implied by `AdditiveArithmetic`).

  @inlinable
  public static func == (lhs: AnyDerivative, rhs: AnyDerivative) -> Bool {
    return lhs._box._isEqual(to: rhs._box)
  }

  @inlinable
  public static func != (lhs: AnyDerivative, rhs: AnyDerivative) -> Bool {
    return lhs._box._isNotEqual(to: rhs._box)
  }

  // `AdditiveArithmetic` requirements.

  /// Internal struct representing an opaque zero value.
  ///
  /// `AnyDerivative.zero` wraps this type. The box arithmetic and `move(by:)`
  /// implementations treat it as the additive identity for any derivative
  /// type.
  @frozen
  @usableFromInline
  internal struct OpaqueZero: Differentiable & AdditiveArithmetic {}

  @inlinable
  public static var zero: AnyDerivative {
    return AnyDerivative(
      _box: _ConcreteDerivativeBox<OpaqueZero>(OpaqueZero.zero))
  }

  @inlinable
  public static func + (
    lhs: AnyDerivative, rhs: AnyDerivative
  ) -> AnyDerivative {
    return AnyDerivative(_box: lhs._box._adding(rhs._box))
  }

  @derivative(of: +)
  @inlinable
  internal static func _vjpAdd(
    lhs: AnyDerivative, rhs: AnyDerivative
  ) -> (
    value: AnyDerivative,
    pullback: (AnyDerivative) -> (AnyDerivative, AnyDerivative)
  ) {
    return (lhs + rhs, { v in (v, v) })
  }

  @derivative(of: +)
  @inlinable
  internal static func _jvpAdd(
    lhs: AnyDerivative, rhs: AnyDerivative
  ) -> (
    value: AnyDerivative,
    differential: (AnyDerivative, AnyDerivative) -> (AnyDerivative)
  ) {
    return (lhs + rhs, { (dlhs, drhs) in dlhs + drhs })
  }

  @inlinable
  public static func - (
    lhs: AnyDerivative, rhs: AnyDerivative
  ) -> AnyDerivative {
    return AnyDerivative(_box: lhs._box._subtracting(rhs._box))
  }

  @derivative(of: -)
  @inlinable
  internal static func _vjpSubtract(
    lhs: AnyDerivative, rhs: AnyDerivative
  ) -> (
    value: AnyDerivative,
    pullback: (AnyDerivative) -> (AnyDerivative, AnyDerivative)
  ) {
    return (lhs - rhs, { v in (v, .zero - v) })
  }

  @derivative(of: -)
  @inlinable
  internal static func _jvpSubtract(
    lhs: AnyDerivative, rhs: AnyDerivative
  ) -> (
    value: AnyDerivative,
    differential: (AnyDerivative, AnyDerivative) -> AnyDerivative
  ) {
    return (lhs - rhs, { (dlhs, drhs) in dlhs - drhs })
  }

  // `Differentiable` requirements.

  /// Moves this derivative by `offset`.
  ///
  /// If `self` is the opaque zero, it simply adopts `offset`'s storage;
  /// otherwise the move is forwarded to the underlying box.
  @inlinable
  public mutating func move(by offset: TangentVector) {
    if _box._isOpaqueZero() {
      _box = offset._box
      return
    }
    _box._move(by: offset._box)
  }
}

extension AnyDerivative: CustomReflectable {
  public var customMirror: Mirror {
    Mirror(reflecting: base)
  }
}

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

/// Traps with a message naming the two mismatched derivative types.
///
/// - Parameters:
///   - x: The expected type.
///   - y: The type actually encountered.
@inline(never)
@usableFromInline
internal func _derivativeTypeMismatch(
  _ x: Any.Type, _ y: Any.Type, file: StaticString = #file, line: UInt = #line
) -> Never {
  preconditionFailure(
    """
    Derivative type mismatch: \
    \(String(reflecting: x)) and \(String(reflecting: y))
    """,
    file: file, line: line)
}