mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Specifically, I made it so that, assuming our instruction has already been
inserted into a block, we:
1. Return a constraint of {OwnershipKind::Any, UseLifetimeConstraint::NonLifetimeEnding}.
2. Return OwnershipKind::None for all values.
Notice that above I said we do this only if the instruction is already inserted
into a block. The reason is that if this is called before an instruction
is inserted into a block, we can't get access to the SILFunction that has the
information on whether or not we are in OSSA form. The only time this can happen
is if one is using these APIs from within SILBuilder since SILBuilder is the
only place where we allow this to happen. SILBuilder already knows whether or
not our function is in OSSA and already does different things as appropriate
(namely, in non-OSSA it does not call getOwnershipKind()). So we know
that if these APIs are called in such a situation, we will only be calling it if
we are in OSSA already. Given that, we just assume we are in OSSA if we do not
have a function.
To make sure that no mistakes are made as a result of that assumption, I put in
a verifier check that, when ownership is disabled, all values return
OwnershipKind::None from getOwnershipKind().
The main upside is that we can write code for both OSSA and non-OSSA, and
write code for non-None ownership, without needing to check if ownership is
enabled.
142 lines
7.5 KiB
Plaintext
142 lines
7.5 KiB
Plaintext
// RUN: %target-sil-opt -differentiation %s | %FileCheck %s
|
|
|
|
sil_stage raw
|
|
|
|
import _Differentiation
|
|
import Swift
|
|
|
|
// Test `function_ref` instructions.
|
|
|
|
// Identity function on Float: returns its single argument unchanged.
// It serves as the original function that @test_function_ref references via
// `function_ref`.
sil @basic : $@convention(thin) (Float) -> Float {
|
|
bb0(%0 : $Float):
|
|
return %0 : $Float
|
|
}
|
|
|
|
// Test input for the `function_ref` case.
// Takes a reference to @basic, converts it from thin to thick, and wraps it
// in a `differentiable_function` that names parameter 0 and result 0 but
// carries no derivative operands. The FileCheck expectations that follow this
// function look for the same instruction with a `with_derivative` clause
// whose JVP/VJP operands come from `differentiability_witness_function`.
sil @test_function_ref : $@convention(thin) () -> () {
|
|
bb0:
|
|
%1 = function_ref @basic : $@convention(thin) (Float) -> Float
|
|
%2 = thin_to_thick_function %1 : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
|
|
// No `with_derivative` clause here; the result %3 is not used afterwards.
%3 = differentiable_function [parameters 0] [results 0] %2 : $@callee_guaranteed (Float) -> Float
|
|
%void = tuple ()
|
|
return %void : $()
|
|
}
|
|
|
|
// CHECK-LABEL: sil @test_function_ref
|
|
// CHECK: [[ORIG_FN_REF:%.*]] = function_ref @basic : $@convention(thin) (Float) -> Float
|
|
// CHECK: [[ORIG_FN:%.*]] = thin_to_thick_function [[ORIG_FN_REF]]
|
|
// CHECK: [[JVP_FN_REF:%.*]] = differentiability_witness_function [jvp] [parameters 0] [results 0] @basic
|
|
// CHECK: [[JVP_FN:%.*]] = thin_to_thick_function [[JVP_FN_REF]]
|
|
// CHECK: [[VJP_FN_REF:%.*]] = differentiability_witness_function [vjp] [parameters 0] [results 0] @basic
|
|
// CHECK: [[VJP_FN:%.*]] = thin_to_thick_function [[VJP_FN_REF]]
|
|
// CHECK: differentiable_function [parameters 0] [results 0] [[ORIG_FN]] : $@callee_guaranteed (Float) -> Float with_derivative {[[JVP_FN]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[VJP_FN]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
|
|
// CHECK: }
|
|
|
|
// Test `differentiable_function_extract` instructions.
|
|
|
|
// Test input for the `differentiable_function_extract` case.
// Receives an already-@differentiable function value, extracts its
// [original] component, and re-wraps that component in a
// `differentiable_function` without derivative operands. The FileCheck
// expectations that follow look for [jvp] and [vjp] extracts from the same
// argument being supplied in a `with_derivative` clause.
sil @test_differentiable_function_extract : $@convention(thin) (@differentiable @callee_guaranteed (Float) -> Float) -> () {
|
|
bb0(%0 : $@differentiable @callee_guaranteed (Float) -> Float):
|
|
%1 = differentiable_function_extract [original] %0 : $@differentiable @callee_guaranteed (Float) -> Float
|
|
%2 = differentiable_function [parameters 0] [results 0] %1 : $@callee_guaranteed (Float) -> Float
|
|
%void = tuple ()
|
|
return %void : $()
|
|
}
|
|
|
|
// CHECK-LABEL: sil @test_differentiable_function_extract
|
|
// CHECK: bb0([[ARG:%.*]] : $@differentiable @callee_guaranteed (Float) -> Float):
|
|
// CHECK: [[ORIG_FN:%.*]] = differentiable_function_extract [original] [[ARG]]
|
|
// CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[ARG]]
|
|
// CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[ARG]]
|
|
// CHECK: differentiable_function [parameters 0] [results 0] [[ORIG_FN]] : {{.*}} with_derivative {[[JVP_FN]] : {{.*}}, [[VJP_FN]] : {{.*}}}
|
|
// CHECK: }
|
|
|
|
// Test `witness_method` instructions.
|
|
|
|
// Protocol with a single requirement, `method`, declared differentiable with
// respect to `x`. The `witness_method` tests in this file look up
// `#Protocol.method` on a generic `T : Protocol`.
protocol Protocol {
|
|
@differentiable(wrt: x)
|
|
func method(_ x: Float) -> Float
|
|
}
|
|
|
|
// Test input for the `witness_method` case (no partial application).
// Looks up `#Protocol.method` on `T` and passes the resulting witness-method
// function value directly to `differentiable_function` without derivative
// operands. The FileCheck expectations that follow look for `witness_method`
// lookups of `#Protocol.method!jvp.SU` and `#Protocol.method!vjp.SU` used in
// a `with_derivative` clause.
sil @test_witness_method : $@convention(thin) <T where T : Protocol> (@in_guaranteed T) -> () {
|
|
bb0(%0 : $*T):
|
|
%1 = witness_method $T, #Protocol.method : <Self where Self : Protocol> (Self) -> (Float) -> Float : $@convention(witness_method: Protocol) <τ_0_0 where τ_0_0 : Protocol> (Float, @in_guaranteed τ_0_0) -> Float
|
|
%2 = differentiable_function [parameters 0] [results 0] %1 : $@convention(witness_method: Protocol) <τ_0_0 where τ_0_0 : Protocol> (Float, @in_guaranteed τ_0_0) -> Float
|
|
%void = tuple ()
|
|
return %void : $()
|
|
}
|
|
|
|
// CHECK-LABEL: sil @test_witness_method
|
|
// CHECK: [[ORIG_FN:%.*]] = witness_method $T, #Protocol.method
|
|
// CHECK: [[JVP_FN:%.*]] = witness_method $T, #Protocol.method!jvp.SU
|
|
// CHECK: [[VJP_FN:%.*]] = witness_method $T, #Protocol.method!vjp.SU
|
|
// CHECK: differentiable_function [parameters 0] [results 0] [[ORIG_FN]] : {{.*}} with_derivative {[[JVP_FN]] : {{.*}}, [[VJP_FN]] : {{.*}}}
|
|
// CHECK: }
|
|
|
|
// Test input for the `witness_method` + `partial_apply` case.
// Looks up `#Protocol.method` on `T`, partially applies it to the address
// argument %0 to form a `(Float) -> Float` closure, and wraps that closure in
// a `differentiable_function` without derivative operands. The FileCheck
// expectations that follow look for two `alloc_stack`/`copy_addr` copies of
// the argument, which are the operands of the JVP and VJP partial
// applications.
sil @test_witness_method_partial_apply : $@convention(thin) <T where T : Protocol> (@in_guaranteed T) -> () {
|
|
bb0(%0 : $*T):
|
|
%1 = witness_method $T, #Protocol.method : <Self where Self : Protocol> (Self) -> (Float) -> Float : $@convention(witness_method: Protocol) <τ_0_0 where τ_0_0 : Protocol> (Float, @in_guaranteed τ_0_0) -> Float
|
|
// %0 is the single operand applied here (the @in_guaranteed τ_0_0 parameter).
%2 = partial_apply [callee_guaranteed] %1<T>(%0) : $@convention(witness_method: Protocol) <τ_0_0 where τ_0_0 : Protocol> (Float, @in_guaranteed τ_0_0) -> Float
|
|
%3 = differentiable_function [parameters 0] [results 0] %2 : $@callee_guaranteed (Float) -> Float
|
|
%void = tuple ()
|
|
return %void : $()
|
|
}
|
|
|
|
// CHECK-LABEL: sil @test_witness_method_partial_apply
|
|
// CHECK: bb0([[ARG:%.*]] : $*T):
|
|
// CHECK: [[ORIG_FN:%.*]] = witness_method $T, #Protocol.method
|
|
// CHECK: [[ARGCOPY1:%.*]] = alloc_stack $T
|
|
// CHECK: copy_addr [[ARG]] to [initialization] [[ARGCOPY1]] : $*T
|
|
// CHECK: [[ARGCOPY2:%.*]] = alloc_stack $T
|
|
// CHECK: copy_addr [[ARG]] to [initialization] [[ARGCOPY2]] : $*T
|
|
// CHECK: [[ORIG_FN_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[ORIG_FN]]<T>([[ARG]])
|
|
// CHECK: [[JVP_FN:%.*]] = witness_method $T, #Protocol.method!jvp.SU
|
|
// CHECK: [[JVP_FN_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[JVP_FN]]<T>([[ARGCOPY1]])
|
|
// CHECK: [[VJP_FN:%.*]] = witness_method $T, #Protocol.method!vjp.SU
|
|
// CHECK: [[VJP_FN_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[VJP_FN]]<T>([[ARGCOPY2]])
|
|
// CHECK: dealloc_stack [[ARGCOPY2]]
|
|
// CHECK: dealloc_stack [[ARGCOPY1]]
|
|
// CHECK: differentiable_function [parameters 0] [results 0] [[ORIG_FN_PARTIALLY_APPLIED]] : {{.*}} with_derivative {[[JVP_FN_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_FN_PARTIALLY_APPLIED]] : {{.*}}}
|
|
// CHECK: }
|
|
|
|
// Test `class_method` instructions.
|
|
|
|
// Class with a single method, `method`, declared differentiable with respect
// to `x`. The `class_method` tests in this file look up `#Class.method` on a
// `$Class` value.
class Class {
|
|
@differentiable(wrt: x)
|
|
func method(_ x: Float) -> Float
|
|
}
|
|
|
|
// Test input for the `class_method` case (no partial application).
// Looks up `#Class.method` on the `$Class` argument and passes the resulting
// method function value directly to `differentiable_function` without
// derivative operands. The FileCheck expectations that follow look for
// `class_method` lookups of `#Class.method!jvp.SU` and `#Class.method!vjp.SU`
// on the same argument.
// NOTE(review): the signature declares `<T where T : Class>`, but `T` is not
// referenced in the parameter list or the body -- confirm this is intentional.
sil @test_class_method : $@convention(thin) <T where T : Class> (@guaranteed Class) -> () {
|
|
bb0(%0 : $Class):
|
|
%1 = class_method %0 : $Class, #Class.method : (Class) -> (Float) -> Float, $@convention(method) (Float, @guaranteed Class) -> Float
|
|
%2 = differentiable_function [parameters 0] [results 0] %1 : $@convention(method) (Float, @guaranteed Class) -> Float
|
|
%void = tuple ()
|
|
return %void : $()
|
|
}
|
|
|
|
// CHECK-LABEL: sil @test_class_method
|
|
// CHECK: bb0([[ARG:%.*]] : $Class):
|
|
// CHECK: [[ORIG_FN:%.*]] = class_method [[ARG]] : $Class, #Class.method
|
|
// CHECK: [[JVP_FN:%.*]] = class_method [[ARG]] : $Class, #Class.method!jvp.SU
|
|
// CHECK: [[VJP_FN:%.*]] = class_method [[ARG]] : $Class, #Class.method!vjp.SU
|
|
// CHECK: differentiable_function [parameters 0] [results 0] [[ORIG_FN]] : {{.*}} with_derivative {[[JVP_FN]] : {{.*}}, [[VJP_FN]] : {{.*}}}
|
|
// CHECK: }
|
|
|
|
// Test input for the `class_method` + `partial_apply` case.
// Looks up `#Class.method` on the `$Class` argument, partially applies it to
// that same argument to form a `(Float) -> Float` closure, and wraps the
// closure in a `differentiable_function` without derivative operands. The
// FileCheck expectations that follow look for JVP and VJP `class_method`
// lookups that are each partially applied to the original argument.
// NOTE(review): the signature declares `<T where T : Class>`, but `T` is not
// referenced in the parameter list or the body -- confirm this is intentional.
sil @test_class_method_partial_apply : $@convention(thin) <T where T : Class> (@guaranteed Class) -> () {
|
|
bb0(%0 : $Class):
|
|
%1 = class_method %0 : $Class, #Class.method : (Class) -> (Float) -> Float, $@convention(method) (Float, @guaranteed Class) -> Float
|
|
// %0 is the single operand applied here (the @guaranteed Class parameter).
%2 = partial_apply [callee_guaranteed] %1(%0) : $@convention(method) (Float, @guaranteed Class) -> Float
|
|
%3 = differentiable_function [parameters 0] [results 0] %2 : $@callee_guaranteed (Float) -> Float
|
|
%void = tuple ()
|
|
return %void : $()
|
|
}
|
|
|
|
// CHECK-LABEL: sil @test_class_method_partial_apply
|
|
// CHECK: bb0([[ARG:%.*]] : $Class):
|
|
// CHECK: [[ORIG_FN:%.*]] = class_method [[ARG]] : $Class, #Class.method
|
|
// CHECK: [[ORIG_FN_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[ORIG_FN]]([[ARG]])
|
|
// CHECK: [[JVP_FN:%.*]] = class_method [[ARG]] : $Class, #Class.method!jvp.SU
|
|
// CHECK: [[JVP_FN_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[JVP_FN]]([[ARG]])
|
|
// CHECK: [[VJP_FN:%.*]] = class_method [[ARG]] : $Class, #Class.method!vjp.SU
|
|
// CHECK: [[VJP_FN_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[VJP_FN]]([[ARG]])
|
|
// CHECK: differentiable_function [parameters 0] [results 0] [[ORIG_FN_PARTIALLY_APPLIED]] : {{.*}} with_derivative {[[JVP_FN_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_FN_PARTIALLY_APPLIED]] : {{.*}}}
|
|
// CHECK: }
|