mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[ASTGen] Fix '@differentiable' attribute
* Typo: '_liner' -> '_linear' * Accept '@differentiable(_linear)' type attribute
This commit is contained in:
@@ -603,7 +603,7 @@ extension ASTGenVisitor {
|
||||
switch text {
|
||||
case "reverse": return .reverse
|
||||
case "wrt", "withRespectTo": return .normal
|
||||
case "_liner": return .linear
|
||||
case "_linear": return .linear
|
||||
case "_forward": return .forward
|
||||
default: return .nonDifferentiable
|
||||
}
|
||||
|
||||
@@ -206,10 +206,17 @@ extension ASTGenVisitor {
|
||||
differentiabilityLoc = nil
|
||||
}
|
||||
|
||||
// Only 'reverse' is supported today.
|
||||
guard differentiability == .reverse else {
|
||||
// Only 'reverse' is formally supported today. '_linear' works for testing
|
||||
// purposes. '_forward' is rejected.
|
||||
switch differentiability {
|
||||
case .normal, .nonDifferentiable:
|
||||
// TODO: Diagnose
|
||||
fatalError("Only @differentiable(reverse) is supported")
|
||||
case .forward:
|
||||
// TODO: Diagnose
|
||||
fatalError("Only @differentiable(reverse) is supported")
|
||||
case .reverse, .linear:
|
||||
break
|
||||
}
|
||||
|
||||
return .createParsed(
|
||||
|
||||
@@ -17,6 +17,10 @@ func testDifferentiableTypeAttr(_ fn: @escaping @differentiable(reverse) (Float)
|
||||
-> @differentiable(reverse) (Float) -> Float {
|
||||
return fn
|
||||
}
|
||||
func testDifferentiableTypeAttrLinear(_ fn: @escaping @differentiable(_linear) (Float) -> Float)
|
||||
-> @differentiable(_linear) (Float) -> Float {
|
||||
return fn
|
||||
}
|
||||
|
||||
@differentiable(reverse)
|
||||
func testDifferentiableSimple(_ x: Float) -> Float { return x * x }
|
||||
|
||||
Reference in New Issue
Block a user