Files
swift-mirror/stdlib/public/Differentiation/OptionalDifferentiation.swift
Andrew Savonichev 2f7b42ceaf [AutoDiff] Handle init_enum_data_addr and inject_enum_addr for Optional (#68300)
Optional's `init_enum_data_addr` and `inject_enum_addr` instructions are generated in the presence of non-loadable Optional values. The compiler used to treat these instructions as inactive, which resulted in the silent run-time issues described in #64223.

The patch marks `init_enum_data_addr` as "active" if its Optional operand is also active, and in PullbackCloner we differentiate through it and the related `inject_enum_addr`.

However, we only determine this relation in simple cases where both instructions are in the same block. There is no def-use relation between them (both take the same Optional operand), so if there is more than one set of such instructions operating on the same Optional, or if there is any control flow between them, we currently bail out.

In PullbackCloner, we walk over instructions in reverse order, starting from `inject_enum_addr` and its `Optional<Wrapped>.TangentVector` operand. Assuming that it is already initialized, we emit an `unchecked_take_enum_data_addr` and set it as the adjoint buffer of `init_enum_data_addr`. This invalidates the Optional value, so we have to destroy the enum data address later, when we reach `init_enum_data_addr`.
2023-09-22 01:07:16 -07:00

67 lines
1.8 KiB
Swift

//===--- OptionalDifferentiation.swift ------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
import Swift
extension Optional: Differentiable where Wrapped: Differentiable {
  /// The tangent vector of an `Optional` whose wrapped type is differentiable.
  ///
  /// A `nil` `value` represents an absent tangent. Addition and subtraction
  /// treat an absent operand as contributing nothing: combining `nil` with a
  /// present tangent yields that tangent (negated when it is the right-hand
  /// side of `-`), and combining two `nil`s yields `nil`.
  @frozen
  public struct TangentVector: Differentiable, AdditiveArithmetic {
    public typealias TangentVector = Self

    /// The wrapped tangent, or `nil` when there is no tangent.
    public var value: Wrapped.TangentVector?

    /// Creates a tangent vector wrapping the given optional tangent.
    ///
    /// - Parameter value: The wrapped tangent, or `nil` for an absent tangent.
    public init(_ value: Wrapped.TangentVector?) {
      self.value = value
    }

    /// The zero tangent vector.
    ///
    /// This wraps `Wrapped.TangentVector.zero` rather than being `nil`, so
    /// `zero + Self(nil)` evaluates to `zero`, not to `Self(nil)`.
    // NOTE(review): compiler-generated pullbacks for `Optional` may assume
    // that `value` is non-nil here; confirm before changing this to `nil`.
    public static var zero: Self {
      return Self(.zero)
    }

    /// Adds two tangent vectors, treating a `nil` operand as absent.
    ///
    /// - Returns: `nil` if both operands are `nil`; the present operand if
    ///   exactly one is present; otherwise the sum of the wrapped tangents.
    public static func + (lhs: Self, rhs: Self) -> Self {
      switch (lhs.value, rhs.value) {
      case (nil, nil): return Self(nil)
      case let (x?, nil): return Self(x)
      case let (nil, y?): return Self(y)
      case let (x?, y?): return Self(x + y)
      }
    }

    /// Subtracts `rhs` from `lhs`, treating a `nil` operand as absent.
    ///
    /// - Returns: `nil` if both operands are `nil`; `lhs` if only `rhs` is
    ///   `nil`; the negation of `rhs` if only `lhs` is `nil`; otherwise the
    ///   difference of the wrapped tangents.
    public static func - (lhs: Self, rhs: Self) -> Self {
      switch (lhs.value, rhs.value) {
      case (nil, nil): return Self(nil)
      case let (x?, nil): return Self(x)
      case let (nil, y?): return Self(.zero - y)
      case let (x?, y?): return Self(x - y)
      }
    }

    /// Moves `self` by `offset`.
    ///
    /// - A `nil` offset leaves `self` unchanged.
    /// - A `nil` receiver adopts the offset's tangent, matching the
    ///   `(nil, y?)` case of `+`. Previously the offset was silently dropped
    ///   here, so moving disagreed with adding even though
    ///   `TangentVector == Self`.
    /// - When both are present, this delegates to
    ///   `Wrapped.TangentVector.move(by:)`.
    public mutating func move(by offset: TangentVector) {
      guard let delta = offset.value else { return }
      if value == nil {
        value = delta
      } else {
        value?.move(by: delta)
      }
    }
  }

  /// Moves the wrapped value by `offset`.
  ///
  /// This does nothing when `self` is `nil` (there is no value to move) or
  /// when `offset.value` is `nil` (there is no tangent to apply).
  public mutating func move(by offset: TangentVector) {
    if let value = offset.value {
      self?.move(by: value)
    }
  }
}
extension Optional.TangentVector: CustomReflectable {
  /// Reflects this tangent vector through the mirror of its wrapped
  /// optional tangent, so it appears in reflection exactly as `value` would.
  public var customMirror: Mirror {
    let wrapped = value
    return wrapped.customMirror
  }
}