mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[ASTGen] Fix '@differentiable' attribute
* Typo: '_liner' -> '_linear' * Accept '@differentiable(_linear)' type attribute
This commit is contained in:
@@ -603,7 +603,7 @@ extension ASTGenVisitor {
|
|||||||
switch text {
|
switch text {
|
||||||
case "reverse": return .reverse
|
case "reverse": return .reverse
|
||||||
case "wrt", "withRespectTo": return .normal
|
case "wrt", "withRespectTo": return .normal
|
||||||
case "_liner": return .linear
|
case "_linear": return .linear
|
||||||
case "_forward": return .forward
|
case "_forward": return .forward
|
||||||
default: return .nonDifferentiable
|
default: return .nonDifferentiable
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -206,10 +206,17 @@ extension ASTGenVisitor {
|
|||||||
differentiabilityLoc = nil
|
differentiabilityLoc = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only 'reverse' is supported today.
|
// Only 'reverse' is formally supported today. '_linear' works for testing
|
||||||
guard differentiability == .reverse else {
|
// purposes. '_forward' is rejected.
|
||||||
|
switch differentiability {
|
||||||
|
case .normal, .nonDifferentiable:
|
||||||
// TODO: Diagnose
|
// TODO: Diagnose
|
||||||
fatalError("Only @differentiable(reverse) is supported")
|
fatalError("Only @differentiable(reverse) is supported")
|
||||||
|
case .forward:
|
||||||
|
// TODO: Diagnose
|
||||||
|
fatalError("Only @differentiable(reverse) is supported")
|
||||||
|
case .reverse, .linear:
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
return .createParsed(
|
return .createParsed(
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ func testDifferentiableTypeAttr(_ fn: @escaping @differentiable(reverse) (Float)
|
|||||||
-> @differentiable(reverse) (Float) -> Float {
|
-> @differentiable(reverse) (Float) -> Float {
|
||||||
return fn
|
return fn
|
||||||
}
|
}
|
||||||
|
func testDifferentiableTypeAttrLinear(_ fn: @escaping @differentiable(_linear) (Float) -> Float)
|
||||||
|
-> @differentiable(_linear) (Float) -> Float {
|
||||||
|
return fn
|
||||||
|
}
|
||||||
|
|
||||||
@differentiable(reverse)
|
@differentiable(reverse)
|
||||||
func testDifferentiableSimple(_ x: Float) -> Float { return x * x }
|
func testDifferentiableSimple(_ x: Float) -> Float { return x * x }
|
||||||
|
|||||||
Reference in New Issue
Block a user