mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
179 lines
9.0 KiB
Swift
179 lines
9.0 KiB
Swift
// RUN: %target-swift-emit-silgen -verify -module-name main %s | %FileCheck %s
|
|
// RUN: %target-swift-emit-sil -verify -module-name main %s
|
|
|
|
// NOTE: SILParser crashes for SILGen round-trip
|
|
// (https://github.com/apple/swift/issues/54370).
|
|
|
|
// This file tests:
|
|
// - The "derivative generic signature" of `@differentiable` and `@derivative`
|
|
// attributes.
|
|
// - The generic signature of lowered SIL differentiability witnesses.
|
|
|
|
// Context:
|
|
// - For `@differentiable` attributes: the derivative generic signature is
|
|
// resolved from the original declaration's generic signature and additional
|
|
// `where` clause requirements.
|
|
// - For `@derivative` attributes: the derivative generic signature is the
|
|
// attributed declaration's generic signature.
|
|
|
|
import _Differentiation
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Same-type requirements
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Test original declaration with a generic signature and derivative generic
|
|
// signature where all generic parameters are concrete (i.e. bound to concrete
|
|
// types via same-type requirements).
|
|
|
|
struct AllConcrete<T>: Differentiable {}
|
|
|
|
extension AllConcrete {
|
|
// Original generic signature: `<T>`
|
|
// Derivative generic signature: `<T where T == Float>`
|
|
// Witness generic signature: `<T where T == Float>`
|
|
@_silgen_name("allconcrete_where_gensig_constrained")
|
|
@differentiable(reverse where T == Float)
|
|
func whereClauseGenericSignatureConstrained() -> AllConcrete {
|
|
return self
|
|
}
|
|
}
|
|
extension AllConcrete where T == Float {
|
|
@derivative(of: whereClauseGenericSignatureConstrained)
|
|
func jvpWhereClauseGenericSignatureConstrained() -> (
|
|
value: AllConcrete, differential: (TangentVector) -> TangentVector
|
|
) {
|
|
(whereClauseGenericSignatureConstrained(), { $0 })
|
|
}
|
|
}
|
|
|
|
// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained
|
|
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T where T == Float> @allconcrete_where_gensig_constrained : $@convention(method) <T> (AllConcrete<T>) -> AllConcrete<T> {
|
|
// CHECK-NEXT: jvp: @allconcrete_where_gensig_constrainedSfRszlTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
|
|
// CHECK-NEXT: }
|
|
|
|
// If a `@differentiable` or `@derivative` attribute satisfies two conditions:
|
|
// 1. The derivative generic signature is equal to the original generic signature.
|
|
// 2. The derivative generic signature has *all concrete* generic parameters.
|
|
//
|
|
// Then the attribute should be lowered to a SIL differentiability witness with
|
|
// *no* derivative generic signature.
|
|
|
|
extension AllConcrete where T == Float {
|
|
// Original generic signature: `<T where T == Float>`
|
|
// Derivative generic signature: `<T where T == Float>`
|
|
// Witness generic signature: none
|
|
@_silgen_name("allconcrete_original_gensig")
|
|
@differentiable(reverse)
|
|
func originalGenericSignature() -> AllConcrete {
|
|
return self
|
|
}
|
|
|
|
@derivative(of: originalGenericSignature)
|
|
func jvpOriginalGenericSignature() -> (
|
|
value: AllConcrete, differential: (TangentVector) -> TangentVector
|
|
) {
|
|
(originalGenericSignature(), { $0 })
|
|
}
|
|
|
|
// CHECK-LABEL: // differentiability witness for allconcrete_original_gensig
|
|
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
|
|
// CHECK-NEXT: jvp: @allconcrete_original_gensigTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
|
|
// CHECK-NEXT: }
|
|
|
|
// Original generic signature: `<T where T == Float>`
|
|
// Derivative generic signature: `<T where T == Float>` (explicit `where` clause)
|
|
// Witness generic signature: none
|
|
@_silgen_name("allconcrete_where_gensig")
|
|
@differentiable(reverse where T == Float)
|
|
func whereClauseGenericSignature() -> AllConcrete {
|
|
return self
|
|
}
|
|
|
|
@derivative(of: whereClauseGenericSignature)
|
|
func jvpWhereClauseGenericSignature() -> (
|
|
value: AllConcrete, differential: (TangentVector) -> TangentVector
|
|
) {
|
|
(whereClauseGenericSignature(), { $0 })
|
|
}
|
|
|
|
// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig
|
|
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
|
|
// CHECK-NEXT: jvp: @allconcrete_where_gensigTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
|
|
// CHECK-NEXT: }
|
|
}
|
|
|
|
// Test original declaration with a generic signature and derivative generic
|
|
// signature where *not* all generic parameters are concrete.
|
|
// types via same-type requirements).
|
|
|
|
struct NotAllConcrete<T, U>: Differentiable {}
|
|
|
|
extension NotAllConcrete {
|
|
// Original generic signature: `<T, U>`
|
|
// Derivative generic signature: `<T, U where T == Float>`
|
|
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
|
|
@_silgen_name("notallconcrete_where_gensig_constrained")
|
|
@differentiable(reverse where T == Float)
|
|
func whereClauseGenericSignatureConstrained() -> NotAllConcrete {
|
|
return self
|
|
}
|
|
}
|
|
extension NotAllConcrete where T == Float {
|
|
@derivative(of: whereClauseGenericSignatureConstrained)
|
|
func jvpWhereClauseGenericSignatureConstrained() -> (
|
|
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
|
|
) {
|
|
(whereClauseGenericSignatureConstrained(), { $0 })
|
|
}
|
|
}
|
|
|
|
// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained
|
|
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig_constrained : $@convention(method) <T, U> (NotAllConcrete<T, U>) -> NotAllConcrete<T, U> {
|
|
// CHECK-NEXT: jvp: @notallconcrete_where_gensig_constrainedSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
|
|
// CHECK-NEXT: }
|
|
|
|
extension NotAllConcrete where T == Float {
|
|
// Original generic signature: `<T, U where T == Float>`
|
|
// Derivative generic signature: `<T, U where T == Float>`
|
|
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
|
|
@_silgen_name("notallconcrete_original_gensig")
|
|
@differentiable(reverse)
|
|
func originalGenericSignature() -> NotAllConcrete {
|
|
return self
|
|
}
|
|
|
|
@derivative(of: originalGenericSignature)
|
|
func jvpOriginalGenericSignature() -> (
|
|
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
|
|
) {
|
|
(originalGenericSignature(), { $0 })
|
|
}
|
|
|
|
// CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig
|
|
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_original_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
|
|
// CHECK-NEXT: jvp: @notallconcrete_original_gensigSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
|
|
// CHECK-NEXT: }
|
|
|
|
// Original generic signature: `<T, U where T == Float>`
|
|
// Derivative generic signature: `<T, U where T == Float>` (explicit `where` clause)
|
|
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
|
|
@_silgen_name("notallconcrete_where_gensig")
|
|
@differentiable(reverse where T == Float)
|
|
func whereClauseGenericSignature() -> NotAllConcrete {
|
|
return self
|
|
}
|
|
|
|
@derivative(of: whereClauseGenericSignature)
|
|
func jvpWhereClauseGenericSignature() -> (
|
|
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
|
|
) {
|
|
(whereClauseGenericSignature(), { $0 })
|
|
}
|
|
|
|
// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig
|
|
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
|
|
// CHECK-NEXT: jvp: @notallconcrete_where_gensigSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
|
|
// CHECK-NEXT: }
|
|
}
|