[ASTGen] Fix '@differentiable' attribute

* Typo: '_liner' -> '_linear'
* Accept '@differentiable(_linear)' type attribute
This commit is contained in:
Rintaro Ishizaki
2025-03-15 15:37:57 -07:00
parent d9f5001311
commit 6a7aacd86c
3 changed files with 14 additions and 3 deletions

View File

@@ -603,7 +603,7 @@ extension ASTGenVisitor {
switch text {
case "reverse": return .reverse
case "wrt", "withRespectTo": return .normal
case "_liner": return .linear
case "_linear": return .linear
case "_forward": return .forward
default: return .nonDifferentiable
}

View File

@@ -206,10 +206,17 @@ extension ASTGenVisitor {
differentiabilityLoc = nil
}
// Only 'reverse' is supported today.
guard differentiability == .reverse else {
// Only 'reverse' is formally supported today. '_linear' works for testing
// purposes. '_forward' is rejected.
switch differentiability {
case .normal, .nonDifferentiable:
// TODO: Diagnose
fatalError("Only @differentiable(reverse) is supported")
case .forward:
// TODO: Diagnose
fatalError("Only @differentiable(reverse) is supported")
case .reverse, .linear:
break
}
return .createParsed(

View File

@@ -17,6 +17,10 @@ func testDifferentiableTypeAttr(_ fn: @escaping @differentiable(reverse) (Float)
-> @differentiable(reverse) (Float) -> Float {
return fn
}
func testDifferentiableTypeAttrLinear(_ fn: @escaping @differentiable(_linear) (Float) -> Float)
-> @differentiable(_linear) (Float) -> Float {
return fn
}
@differentiable(reverse)
func testDifferentiableSimple(_ x: Float) -> Float { return x * x }