mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Compiler: - Add `Forward` and `Reverse` to `DifferentiabilityKind`. - Expand `DifferentiabilityMask` in `ExtInfo` to 3 bits so that it now holds all 4 cases of `DifferentiabilityKind`. - Parse `@differentiable(reverse)` and `@differentiable(_forward)` declaration attributes and type attributes. - Emit a warning for `@differentiable` without `reverse`. - Emit an error for `@differentiable(_forward)`. - Rename `@differentiable(linear)` to `@differentiable(_linear)`. - Make `@differentiable(reverse)` type lowering go through today's `@differentiable` code path. We will specialize it to reverse-mode in a follow-up patch. ABI: - Add `Forward` and `Reverse` to `FunctionMetadataDifferentiabilityKind`. - Extend `TargetFunctionTypeFlags` by 1 bit to store the highest bit of differentiability kind (linear). Note that there is a 2-bit gap in `DifferentiabilityMask` which is reserved for `AsyncMask` and `ConcurrentMask`; `AsyncMask` is ABI-stable so we cannot change that. _Differentiation module: - Replace all occurrences of `@differentiable` with `@differentiable(reverse)`. - Delete `_transpose(of:)`. Resolves rdar://69980056.
25 lines
1.9 KiB
Swift
25 lines
1.9 KiB
Swift
// RUN: %target-build-swift -Osize %s
|
|
|
|
// SR-12732: Fix `partial_apply` optimization.
|
|
|
|
// Do not rewrite `partial_apply` to `thin_to_thick_function` if the specialized
|
|
// callee is not `@convention(thin)`.
|
|
|
|
// FIXME(SR-13021): Disabled due to flakiness on Linux, likely related to TF-1197.
|
|
// REQUIRES: SR13021
|
|
|
|
import DifferentiationUnittest
|
|
|
|
// Intentionally empty callback taking a `Tracked<Float>.TangentVector` by
// `inout`. `caller` passes it to `withDerivative` by name.
func callback(_ x: inout Tracked<Float>.TangentVector) {}
|
|
|
|
// Reverse-mode differentiable function that returns `x.withDerivative(callback)`,
// passing the top-level `callback` function by name.
// The comments at the end of this file record the SIL verifier failure
// (`thin_to_thick_function` applied to a non-thin specialized
// `_vjpWithDerivative(_:)` reference) associated with this test.
// NOTE(review): presumably the exact shape of this call is what reproduces the
// issue - avoid refactoring it.
@differentiable(reverse)
|
|
func caller(_ x: Tracked<Float>) -> Tracked<Float> {
|
|
return x.withDerivative(callback)
|
|
}
|
|
|
|
// SIL verification failed: operand of thin_to_thick_function must be thin: opFTy->getRepresentation() == SILFunctionType::Representation::Thin
|
|
// Verifying instruction:
|
|
// // function_ref specialized Differentiable._vjpWithDerivative(_:)
|
|
// %10 = function_ref @$s16_Differentiation14DifferentiablePAAE18_vjpWithDerivativeyx5value_13TangentVectorQzAGc8pullbacktyAGzcF0A8Unittest7TrackedVySfG_Tg5 : $@convention(method) (@guaranteed @callee_guaranteed @substituted <τ_0_0> (@inout τ_0_0) -> () for <Tracked<Float>>, @in_guaranteed Tracked<Float>) -> (@out Tracked<Float>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Tracked<Float>, Tracked<Float>>) // user: %11
|
|
// -> %11 = thin_to_thick_function %10 : $@convention(method) (@guaranteed @callee_guaranteed @substituted <τ_0_0> (@inout τ_0_0) -> () for <Tracked<Float>>, @in_guaranteed Tracked<Float>) -> (@out Tracked<Float>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Tracked<Float>, Tracked<Float>>) to $@callee_guaranteed (@guaranteed @callee_guaranteed @substituted <τ_0_0> (@inout τ_0_0) -> () for <Tracked<Float>>, @in_guaranteed Tracked<Float>) -> (@out Tracked<Float>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Tracked<Float>, Tracked<Float>>) // user: %12
|