Files
swift-mirror/test/AutoDiff/SILGen/differentiability_witness_generic_signature.swift
Richard Wei e494df2ee6 [AutoDiff] Add differentiability kind to differentiability witnesses and mangle them.
Differentiability witnesses are now keyed by the original function name, the differentiability kind, and the autodiff config.

Updated SIL syntax:
```
differentiability-kind ::= 'forward' | 'reverse' | 'normal' | 'linear'
sil-differentiability-witness ::=
    'sil_differentiability_witness'
    sil-linkage?
    '[' differentiability-kind ']'
    '[' 'parameters' sil-differentiability-witness-function-index-list ']'
    '[' 'results' sil-differentiability-witness-function-index-list ']'
    generic-parameter-clause?
    sil-function-name ':' sil-type
    sil-differentiability-witness-body?
sil-instruction ::=
    'differentiability_witness_function'
    '[' sil-differentiability-witness-function-kind ']'
    '[' differentiability-kind ']'
    '[' 'parameters' sil-differentiability-witness-function-index-list ']'
    '[' 'results' sil-differentiability-witness-function-index-list ']'
    generic-parameter-clause?
    sil-function-name ':' sil-type
```
```console
sil_differentiability_witness [reverse] [parameters 0 1] [results 0] <T where T: Differentiable> @foo : <T> $(T) -> T
differentiability_witness_function [vjp] [reverse] [parameters 0] [results 0] <T where T: Differentiable> @foo : $(T) -> T
```

New mangling:
```swift
  global ::= global generic-signature? 'WJ' DIFFERENTIABILITY-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' // differentiability witness
```
```console
$s13test_mangling3fooyS2f_S2ftFWJrSpSr ---> reverse differentiability witness for test_mangling.foo(Swift.Float, Swift.Float, Swift.Float) -> Swift.Float with respect to parameters {0} and results {0}
```

Resolves rdar://74380324.
2021-02-17 18:27:42 -05:00

178 lines
9.0 KiB
Swift

// RUN: %target-swift-emit-silgen -verify -module-name main %s | %FileCheck %s
// RUN: %target-swift-emit-sil -verify -module-name main %s
// NOTE(SR-11950): SILParser crashes for SILGen round-trip.
// This file tests:
// - The "derivative generic signature" of `@differentiable` and `@derivative`
// attributes.
// - The generic signature of lowered SIL differentiability witnesses.
// Context:
// - For `@differentiable` attributes: the derivative generic signature is
// resolved from the original declaration's generic signature and additional
// `where` clause requirements.
// - For `@derivative` attributes: the derivative generic signature is the
// attributed declaration's generic signature.
import _Differentiation
//===----------------------------------------------------------------------===//
// Same-type requirements
//===----------------------------------------------------------------------===//
// Test original declaration with a generic signature and derivative generic
// signature where all generic parameters are concrete (i.e. bound to concrete
// types via same-type requirements).
struct AllConcrete<T>: Differentiable {}
extension AllConcrete {
// Original generic signature: `<T>`
// Derivative generic signature: `<T where T == Float>`
// Witness generic signature: `<T where T == Float>`
@_silgen_name("allconcrete_where_gensig_constrained")
@differentiable(reverse where T == Float)
func whereClauseGenericSignatureConstrained() -> AllConcrete {
return self
}
}
extension AllConcrete where T == Float {
@derivative(of: whereClauseGenericSignatureConstrained)
func jvpWhereClauseGenericSignatureConstrained() -> (
value: AllConcrete, differential: (TangentVector) -> TangentVector
) {
(whereClauseGenericSignatureConstrained(), { $0 })
}
}
// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig_constrained
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T where T == Float> @allconcrete_where_gensig_constrained : $@convention(method) <T> (AllConcrete<T>) -> AllConcrete<T> {
// CHECK-NEXT: jvp: @allconcrete_where_gensig_constrainedSfRszlTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }
// If a `@differentiable` or `@derivative` attribute satisfies two conditions:
// 1. The derivative generic signature is equal to the original generic signature.
// 2. The derivative generic signature has *all concrete* generic parameters.
//
// Then the attribute should be lowered to a SIL differentiability witness with
// *no* derivative generic signature.
extension AllConcrete where T == Float {
// Original generic signature: `<T where T == Float>`
// Derivative generic signature: `<T where T == Float>`
// Witness generic signature: none
@_silgen_name("allconcrete_original_gensig")
@differentiable(reverse)
func originalGenericSignature() -> AllConcrete {
return self
}
@derivative(of: originalGenericSignature)
func jvpOriginalGenericSignature() -> (
value: AllConcrete, differential: (TangentVector) -> TangentVector
) {
(originalGenericSignature(), { $0 })
}
// CHECK-LABEL: // differentiability witness for allconcrete_original_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_original_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
// CHECK-NEXT: jvp: @allconcrete_original_gensigTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }
// Original generic signature: `<T where T == Float>`
// Derivative generic signature: `<T where T == Float>` (explicit `where` clause)
// Witness generic signature: none
@_silgen_name("allconcrete_where_gensig")
@differentiable(reverse where T == Float)
func whereClauseGenericSignature() -> AllConcrete {
return self
}
@derivative(of: whereClauseGenericSignature)
func jvpWhereClauseGenericSignature() -> (
value: AllConcrete, differential: (TangentVector) -> TangentVector
) {
(whereClauseGenericSignature(), { $0 })
}
// CHECK-LABEL: // differentiability witness for allconcrete_where_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @allconcrete_where_gensig : $@convention(method) (AllConcrete<Float>) -> AllConcrete<Float> {
// CHECK-NEXT: jvp: @allconcrete_where_gensigTJfSpSr : $@convention(method) (AllConcrete<Float>) -> (AllConcrete<Float>, @owned @callee_guaranteed (AllConcrete<Float>.TangentVector) -> AllConcrete<Float>.TangentVector)
// CHECK-NEXT: }
}
// Test original declaration with a generic signature and derivative generic
// signature where *not* all generic parameters are concrete.
// types via same-type requirements).
struct NotAllConcrete<T, U>: Differentiable {}
extension NotAllConcrete {
// Original generic signature: `<T, U>`
// Derivative generic signature: `<T, U where T == Float>`
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
@_silgen_name("notallconcrete_where_gensig_constrained")
@differentiable(reverse where T == Float)
func whereClauseGenericSignatureConstrained() -> NotAllConcrete {
return self
}
}
extension NotAllConcrete where T == Float {
@derivative(of: whereClauseGenericSignatureConstrained)
func jvpWhereClauseGenericSignatureConstrained() -> (
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
) {
(whereClauseGenericSignatureConstrained(), { $0 })
}
}
// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig_constrained
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig_constrained : $@convention(method) <T, U> (NotAllConcrete<T, U>) -> NotAllConcrete<T, U> {
// CHECK-NEXT: jvp: @notallconcrete_where_gensig_constrainedSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }
extension NotAllConcrete where T == Float {
// Original generic signature: `<T, U where T == Float>`
// Derivative generic signature: `<T, U where T == Float>`
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
@_silgen_name("notallconcrete_original_gensig")
@differentiable(reverse)
func originalGenericSignature() -> NotAllConcrete {
return self
}
@derivative(of: originalGenericSignature)
func jvpOriginalGenericSignature() -> (
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
) {
(originalGenericSignature(), { $0 })
}
// CHECK-LABEL: // differentiability witness for notallconcrete_original_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_original_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
// CHECK-NEXT: jvp: @notallconcrete_original_gensigSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }
// Original generic signature: `<T, U where T == Float>`
// Derivative generic signature: `<T, U where T == Float>` (explicit `where` clause)
// Witness generic signature: `<T, U where T == Float>` (not all concrete)
@_silgen_name("notallconcrete_where_gensig")
@differentiable(reverse where T == Float)
func whereClauseGenericSignature() -> NotAllConcrete {
return self
}
@derivative(of: whereClauseGenericSignature)
func jvpWhereClauseGenericSignature() -> (
value: NotAllConcrete, differential: (TangentVector) -> TangentVector
) {
(whereClauseGenericSignature(), { $0 })
}
// CHECK-LABEL: // differentiability witness for notallconcrete_where_gensig
// CHECK-NEXT: sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] <T, U where T == Float> @notallconcrete_where_gensig : $@convention(method) <T, U where T == Float> (NotAllConcrete<Float, U>) -> NotAllConcrete<Float, U> {
// CHECK-NEXT: jvp: @notallconcrete_where_gensigSfRszr0_lTJfSpSr : $@convention(method) <τ_0_0, τ_0_1 where τ_0_0 == Float> (NotAllConcrete<Float, τ_0_1>) -> (NotAllConcrete<Float, τ_0_1>, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (τ_0_0) -> τ_0_1 for <NotAllConcrete<Float, τ_0_1>.TangentVector, NotAllConcrete<Float, τ_0_1>.TangentVector>)
// CHECK-NEXT: }
}