From 6a7aacd86cf0224093b2158cf700d107d5596e5f Mon Sep 17 00:00:00 2001 From: Rintaro Ishizaki Date: Sat, 15 Mar 2025 15:37:57 -0700 Subject: [PATCH] [ASTGen] Fix '@differentiable' attribute * Typo: '_liner' -> '_linear' * Accept '@differentiable(_linear)' type attribute --- lib/ASTGen/Sources/ASTGen/DeclAttrs.swift | 2 +- lib/ASTGen/Sources/ASTGen/TypeAttrs.swift | 11 +++++++++-- test/ASTGen/autodiff.swift | 4 ++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/lib/ASTGen/Sources/ASTGen/DeclAttrs.swift b/lib/ASTGen/Sources/ASTGen/DeclAttrs.swift index 87469618149..2197c1fcbfd 100644 --- a/lib/ASTGen/Sources/ASTGen/DeclAttrs.swift +++ b/lib/ASTGen/Sources/ASTGen/DeclAttrs.swift @@ -603,7 +603,7 @@ extension ASTGenVisitor { switch text { case "reverse": return .reverse case "wrt", "withRespectTo": return .normal - case "_liner": return .linear + case "_linear": return .linear case "_forward": return .forward default: return .nonDifferentiable } diff --git a/lib/ASTGen/Sources/ASTGen/TypeAttrs.swift b/lib/ASTGen/Sources/ASTGen/TypeAttrs.swift index beb82d46e09..8a4e355f11e 100644 --- a/lib/ASTGen/Sources/ASTGen/TypeAttrs.swift +++ b/lib/ASTGen/Sources/ASTGen/TypeAttrs.swift @@ -206,10 +206,17 @@ extension ASTGenVisitor { differentiabilityLoc = nil } - // Only 'reverse' is supported today. - guard differentiability == .reverse else { + // Only 'reverse' is formally supported today. '_linear' works for testing + // purposes. '_forward' is rejected. + switch differentiability { + case .normal, .nonDifferentiable: // TODO: Diagnose fatalError("Only @differentiable(reverse) is supported") + case .forward: + // TODO: Diagnose + fatalError("Only @differentiable(reverse) is supported") + case .reverse, .linear: + break } return .createParsed( diff --git a/test/ASTGen/autodiff.swift b/test/ASTGen/autodiff.swift index 8a010d5d375..36b4f0cf6cf 100644 --- a/test/ASTGen/autodiff.swift +++ b/test/ASTGen/autodiff.swift @@ -17,6 +17,10 @@ func testDifferentiableTypeAttr(_ fn: @escaping @differentiable(reverse) (Float) -> @differentiable(reverse) (Float) -> Float { return fn } +func testDifferentiableTypeAttrLinear(_ fn: @escaping @differentiable(_linear) (Float) -> Float) + -> @differentiable(_linear) (Float) -> Float { + return fn +} @differentiable(reverse) func testDifferentiableSimple(_ x: Float) -> Float { return x * x }