mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
544 lines
15 KiB
Swift
544 lines
15 KiB
Swift
// RUN: %target-run-simple-swift
|
|
|
|
// REQUIRES: executable_test
|
|
|
|
// Test differentiation of `Optional` values and operations.
|
|
|
|
import DifferentiationUnittest
|
|
import StdlibUnittest
|
|
|
|
var OptionalTests = TestSuite("OptionalDifferentiation")
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Basic tests.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TODO(TF-433): operator `??` lowers to an active `try_apply`.
|
|
/*
|
|
@differentiable(reverse)
|
|
func optional_nil_coalescing(_ maybeX: Float?) -> Float {
|
|
return maybeX ?? 10
|
|
}
|
|
*/
|
|
|
|
OptionalTests.test("Active") {
|
|
@differentiable(reverse)
|
|
func id(y: Float) -> Float? {
|
|
return y
|
|
}
|
|
|
|
@differentiable(reverse)
|
|
func id2(y: Float?) -> Float {
|
|
return y!
|
|
}
|
|
|
|
@differentiable(reverse)
|
|
func square(y: Float) -> Float? {
|
|
return y * y
|
|
}
|
|
|
|
@differentiable(reverse)
|
|
func square2(y: Float?) -> Float {
|
|
return y! * y!
|
|
}
|
|
|
|
expectEqual(gradient(at: 10, of: {y in id(y:y)!}), .init(1.0))
|
|
expectEqual(gradient(at: 10, of: {y in id2(y:y)}), .init(1.0))
|
|
expectEqual(gradient(at: 10, of: {y in square(y:y)!}), .init(20.0))
|
|
expectEqual(gradient(at: 10, of: {y in square2(y:y)}), .init(20.0))
|
|
}
|
|
|
|
OptionalTests.test("Let") {
|
|
@differentiable(reverse)
|
|
func optional_let(_ maybeX: Float?) -> Float {
|
|
if let x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_let), .init(20.0))
|
|
expectEqual(gradient(at: nil, of: optional_let), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_let_tracked(_ maybeX: Tracked<Float>?) -> Tracked<Float> {
|
|
if let x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_let_tracked), .init(20.0))
|
|
expectEqual(gradient(at: nil, of: optional_let_tracked), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_let_nonresilient_tracked(_ maybeX: NonresilientTracked<Float>?)
|
|
-> NonresilientTracked<Float>
|
|
{
|
|
if let x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_let_nonresilient_tracked), .init(20.0))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_let_nonresilient_tracked), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_let_nested(_ nestedMaybeX: Float??) -> Float {
|
|
if let maybeX = nestedMaybeX {
|
|
if let x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_let_nested), .init(.init(20.0)))
|
|
expectEqual(gradient(at: nil, of: optional_let_nested), .init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_let_nested_tracked(_ nestedMaybeX: Tracked<Float>??) -> Tracked<
|
|
Float
|
|
> {
|
|
if let maybeX = nestedMaybeX {
|
|
if let x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_let_nested_tracked), .init(.init(20.0)))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_let_nested_tracked), .init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_let_nested_nonresilient_tracked(
|
|
_ nestedMaybeX: NonresilientTracked<Float>??
|
|
) -> NonresilientTracked<Float> {
|
|
if let maybeX = nestedMaybeX {
|
|
if let x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_let_nested_nonresilient_tracked),
|
|
.init(.init(20.0)))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_let_nested_nonresilient_tracked),
|
|
.init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_let_generic<T: Differentiable>(_ maybeX: T?, _ defaultValue: T)
|
|
-> T
|
|
{
|
|
if let x = maybeX {
|
|
return x
|
|
}
|
|
return defaultValue
|
|
}
|
|
expectEqual(gradient(at: 10, 20, of: optional_let_generic), (.init(1.0), 0.0))
|
|
expectEqual(
|
|
gradient(at: nil, 20, of: optional_let_generic), (.init(0.0), 1.0))
|
|
|
|
expectEqual(
|
|
gradient(
|
|
at: Tracked<Float>.init(10), Tracked<Float>.init(20),
|
|
of: optional_let_generic), (.init(1.0), 0.0))
|
|
expectEqual(
|
|
gradient(at: nil, Tracked<Float>.init(20), of: optional_let_generic),
|
|
(.init(0.0), 1.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_let_nested_generic<T: Differentiable>(
|
|
_ nestedMaybeX: T??, _ defaultValue: T
|
|
) -> T {
|
|
if let maybeX = nestedMaybeX {
|
|
if let x = maybeX {
|
|
return x
|
|
}
|
|
return defaultValue
|
|
}
|
|
return defaultValue
|
|
}
|
|
|
|
expectEqual(
|
|
gradient(at: 10.0, 20.0, of: optional_let_nested_generic),
|
|
(.init(.init(1.0)), 0.0))
|
|
expectEqual(
|
|
gradient(at: nil, 20, of: optional_let_nested_generic),
|
|
(.init(.init(0.0)), 1.0))
|
|
}
|
|
|
|
OptionalTests.test("Switch") {
|
|
@differentiable(reverse)
|
|
func optional_switch(_ maybeX: Float?) -> Float {
|
|
switch maybeX {
|
|
case nil: return 10
|
|
case let .some(x): return x * x
|
|
}
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_switch), .init(20.0))
|
|
expectEqual(gradient(at: nil, of: optional_switch), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_switch_tracked(_ maybeX: Tracked<Float>?) -> Tracked<Float> {
|
|
switch maybeX {
|
|
case nil: return 10
|
|
case let .some(x): return x * x
|
|
}
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_switch_tracked), .init(20.0))
|
|
expectEqual(gradient(at: nil, of: optional_switch_tracked), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_switch_nonresilient_tracked(
|
|
_ maybeX: NonresilientTracked<Float>?
|
|
) -> NonresilientTracked<Float> {
|
|
switch maybeX {
|
|
case nil: return 10
|
|
case let .some(x): return x * x
|
|
}
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_switch_nonresilient_tracked), .init(20.0))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_switch_nonresilient_tracked), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_switch_nested(_ nestedMaybeX: Float??) -> Float {
|
|
switch nestedMaybeX {
|
|
case nil: return 10
|
|
case let .some(maybeX):
|
|
switch maybeX {
|
|
case nil: return 10
|
|
case let .some(x): return x * x
|
|
}
|
|
}
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_switch_nested), .init(.init(20.0)))
|
|
expectEqual(gradient(at: nil, of: optional_switch_nested), .init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_switch_nested_tracked(_ nestedMaybeX: Tracked<Float>??)
|
|
-> Tracked<Float>
|
|
{
|
|
switch nestedMaybeX {
|
|
case nil: return 10
|
|
case let .some(maybeX):
|
|
switch maybeX {
|
|
case nil: return 10
|
|
case let .some(x): return x * x
|
|
}
|
|
}
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_switch_nested_tracked), .init(.init(20.0)))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_switch_nested_tracked), .init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_switch_nested_nonresilient_tracked(
|
|
_ nestedMaybeX: NonresilientTracked<Float>??
|
|
) -> NonresilientTracked<Float> {
|
|
switch nestedMaybeX {
|
|
case nil: return 10
|
|
case let .some(maybeX):
|
|
switch maybeX {
|
|
case nil: return 10
|
|
case let .some(x): return x * x
|
|
}
|
|
}
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_switch_nested_nonresilient_tracked),
|
|
.init(.init(20.0)))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_switch_nested_nonresilient_tracked),
|
|
.init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_switch_generic<T: Differentiable>(
|
|
_ maybeX: T?, _ defaultValue: T
|
|
) -> T {
|
|
switch maybeX {
|
|
case nil: return defaultValue
|
|
case let .some(x): return x
|
|
}
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, 20, of: optional_switch_generic), (.init(1.0), 0.0))
|
|
expectEqual(
|
|
gradient(at: nil, 20, of: optional_switch_generic), (.init(0.0), 1.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_switch_nested_generic<T: Differentiable>(
|
|
_ nestedMaybeX: T??, _ defaultValue: T
|
|
) -> T {
|
|
switch nestedMaybeX {
|
|
case nil: return defaultValue
|
|
case let .some(maybeX):
|
|
switch maybeX {
|
|
case nil: return defaultValue
|
|
case let .some(x): return x
|
|
}
|
|
}
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, 20, of: optional_switch_nested_generic),
|
|
(.init(.init(1.0)), 0.0))
|
|
expectEqual(
|
|
gradient(at: nil, 20, of: optional_switch_nested_generic),
|
|
(.init(.init(0.0)), 1.0))
|
|
}
|
|
|
|
OptionalTests.test("Optional binding: if let") {
|
|
@differentiable(reverse)
|
|
func optional_var1(_ maybeX: Float?) -> Float {
|
|
var maybeX = maybeX
|
|
if let x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_var1), .init(20.0))
|
|
expectEqual(gradient(at: nil, of: optional_var1), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var1_tracked(_ maybeX: Tracked<Float>?) -> Tracked<Float> {
|
|
var maybeX = maybeX
|
|
if let x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_var1_tracked), .init(20.0))
|
|
expectEqual(gradient(at: nil, of: optional_var1_tracked), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var1_nonresilient_tracked(_ maybeX: NonresilientTracked<Float>?)
|
|
-> NonresilientTracked<Float>
|
|
{
|
|
var maybeX = maybeX
|
|
if let x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_var1_nonresilient_tracked), .init(20.0))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_var1_nonresilient_tracked), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var1_nested(_ nestedMaybeX: Float??) -> Float {
|
|
var nestedMaybeX = nestedMaybeX
|
|
if let maybeX = nestedMaybeX {
|
|
if var x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_var1_nested), .init(.init(20.0)))
|
|
expectEqual(gradient(at: nil, of: optional_var1_nested), .init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var1_nested_tracked(_ nestedMaybeX: Tracked<Float>??)
|
|
-> Tracked<Float>
|
|
{
|
|
var nestedMaybeX = nestedMaybeX
|
|
if let maybeX = nestedMaybeX {
|
|
if var x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_var1_nested_tracked), .init(.init(20.0)))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_var1_nested_tracked), .init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var1_nested_nonresilient_tracked(
|
|
_ nestedMaybeX: NonresilientTracked<Float>??
|
|
) -> NonresilientTracked<Float> {
|
|
var nestedMaybeX = nestedMaybeX
|
|
if let maybeX = nestedMaybeX {
|
|
if var x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_var1_nested_nonresilient_tracked),
|
|
.init(.init(20.0)))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_var1_nested_nonresilient_tracked),
|
|
.init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var1_generic<T: Differentiable>(_ maybeX: T?, _ defaultValue: T)
|
|
-> T
|
|
{
|
|
var maybeX = maybeX
|
|
if let x = maybeX {
|
|
return x
|
|
}
|
|
return defaultValue
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, 20, of: optional_var1_generic), (.init(1.0), 0.0))
|
|
expectEqual(
|
|
gradient(at: nil, 20, of: optional_var1_generic), (.init(0.0), 1.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var1_nested_generic<T: Differentiable>(
|
|
_ nestedMaybeX: T??, _ defaultValue: T
|
|
) -> T {
|
|
var nestedMaybeX = nestedMaybeX
|
|
if let maybeX = nestedMaybeX {
|
|
if var x = maybeX {
|
|
return x
|
|
}
|
|
return defaultValue
|
|
}
|
|
return defaultValue
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, 20, of: optional_var1_nested_generic),
|
|
(.init(.init(1.0)), 0.0))
|
|
expectEqual(
|
|
gradient(at: nil, 20, of: optional_var1_nested_generic),
|
|
(.init(.init(0.0)), 1.0))
|
|
}
|
|
|
|
OptionalTests.test("Optional binding: if var") {
|
|
@differentiable(reverse)
|
|
func optional_var2(_ maybeX: Float?) -> Float {
|
|
if var x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_var2), .init(20.0))
|
|
expectEqual(gradient(at: nil, of: optional_var2), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var2_tracked(_ maybeX: Tracked<Float>?) -> Tracked<Float> {
|
|
if var x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_var2_tracked), .init(20.0))
|
|
expectEqual(gradient(at: nil, of: optional_var2_tracked), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var2_nonresilient_tracked(_ maybeX: NonresilientTracked<Float>?)
|
|
-> NonresilientTracked<Float>
|
|
{
|
|
if var x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_var2_nonresilient_tracked), .init(20.0))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_var2_nonresilient_tracked), .init(0.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var2_nested(_ nestedMaybeX: Float??) -> Float {
|
|
if var maybeX = nestedMaybeX {
|
|
if var x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(gradient(at: 10, of: optional_var2_nested), .init(.init(20.0)))
|
|
expectEqual(gradient(at: nil, of: optional_var2_nested), .init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var2_nested_tracked(_ nestedMaybeX: Tracked<Float>??)
|
|
-> Tracked<Float>
|
|
{
|
|
if var maybeX = nestedMaybeX {
|
|
if var x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_var2_nested_tracked), .init(.init(20.0)))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_var2_nested_tracked), .init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var2_nested_nonresilient_tracked(
|
|
_ nestedMaybeX: NonresilientTracked<Float>??
|
|
) -> NonresilientTracked<Float> {
|
|
if var maybeX = nestedMaybeX {
|
|
if var x = maybeX {
|
|
return x * x
|
|
}
|
|
return 10
|
|
}
|
|
return 10
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, of: optional_var2_nested_nonresilient_tracked),
|
|
.init(.init(20.0)))
|
|
expectEqual(
|
|
gradient(at: nil, of: optional_var2_nested_nonresilient_tracked),
|
|
.init(.init(0.0)))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var2_generic<T: Differentiable>(_ maybeX: T?, _ defaultValue: T)
|
|
-> T
|
|
{
|
|
if var x = maybeX {
|
|
return x
|
|
}
|
|
return defaultValue
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, 20, of: optional_var2_generic), (.init(1.0), 0.0))
|
|
expectEqual(
|
|
gradient(at: nil, 20, of: optional_var2_generic), (.init(0.0), 1.0))
|
|
|
|
@differentiable(reverse)
|
|
func optional_var2_nested_generic<T: Differentiable>(
|
|
_ nestedMaybeX: T??, _ defaultValue: T
|
|
) -> T {
|
|
if var maybeX = nestedMaybeX {
|
|
if var x = maybeX {
|
|
return x
|
|
}
|
|
return defaultValue
|
|
}
|
|
return defaultValue
|
|
}
|
|
expectEqual(
|
|
gradient(at: 10, 20, of: optional_var2_nested_generic),
|
|
(.init(.init(1.0)), 0.0))
|
|
expectEqual(
|
|
gradient(at: nil, 20, of: optional_var2_nested_generic),
|
|
(.init(.init(0.0)), 1.0))
|
|
}
|
|
|
|
runAllTests()
|