Files
swift-mirror/test/AutoDiff/Parse/derivative_attr_parse.swift
ematejska fbec91a1b5 [Autodiff] Derivative Registration for the Get and Set Accessors (#32614)
* initial changes

* Add tests, undo unnecessary changes.

* Fixing up computed property accessors and adding tests for getters.

* Adding nested type testcase

* Fixing the error message for when an accessor is referenced but not actually found.

* Cleanup.

- Improve diagnostic message.
- Clean up code and tests.
- Delete unrelated nested type `@derivative` attribute tests.

* Temporarily disable class subscript setter derivative registration test.

Blocked by SR-13096.

* Adding libsyntax integration and fixing up an error message.

* Added a helper function for checking if the next token is an accessor label.

* Update utils/gyb_syntax_support/AttributeNodes.py

Co-authored-by: Dan Zheng <danielzheng@google.com>

* Update lib/Parse/ParseDecl.cpp

Co-authored-by: Dan Zheng <danielzheng@google.com>

* Add end-to-end derivative registration tests.

* NFC: run `git clang-format`.

* NFC: clean up formatting.

Re-apply `git clang-format`.

* Clarify parsing ambiguity FIXME comments.

* Adding a couple more test cases and fixing up the error message for when an accessor is not found on the resolved functions.

* Update lib/Sema/TypeCheckAttr.cpp

Co-authored-by: Dan Zheng <danielzheng@google.com>

Co-authored-by: Dan Zheng <danielzheng@google.com>
2020-07-01 20:14:58 -07:00

135 lines
3.8 KiB
Swift

// RUN: %target-swift-frontend -parse -verify %s
/// Good
// Accepted: `of:` names the original function and `wrt:` names a single
// parameter. Only parsing is exercised here (the lit run line passes
// `-parse`), so `sin` does not need to resolve to a declaration.
@derivative(of: sin, wrt: x) // ok
func vjpSin(x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}
// Accepted: `wrt:` may take a parenthesized, comma-separated list of
// parameter names.
@derivative(of: add, wrt: (x, y)) // ok
func vjpAdd(x: Float, y: Float)
-> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}
// Accepted: `of:` may name an operator (`+`). The attribute is also accepted
// on a static method declared inside a constrained protocol extension.
extension AdditiveArithmetic where Self: Differentiable {
@derivative(of: +) // ok
static func vjpAdd(x: Self, y: Self)
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (x + y, { v in (v, v) })
}
}
// Accepted: the `wrt:` clause is optional; `of:` alone is enough.
@derivative(of: foo) // ok
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Accepted: `of:` may reference a property or subscript accessor by
// appending `.get` or `.set`. A subscript may be written bare
// (`subscript`) or with its argument labels (`subscript(_:label:)`).
// These functions have no bodies; that is fine because this file is only
// parsed, not type-checked.
@derivative(of: property.get) // ok
func dPropertyGetter() -> ()
@derivative(of: subscript.get) // ok
func dSubscriptGetter() -> ()
@derivative(of: subscript(_:label:).get) // ok
func dLabeledSubscriptGetter() -> ()
@derivative(of: property.set) // ok
func dPropertySetter() -> ()
@derivative(of: subscript.set) // ok
func dSubscriptSetter() -> ()
@derivative(of: subscript(_:label:).set) // ok
func dLabeledSubscriptSetter() -> ()
// Accepted: `of:` may be a qualified (dotted) name such as
// `nestedType.name`.
@derivative(of: nestedType.name) // ok
func dNestedTypeFunc() -> ()
/// Bad
// Rejected: the `of:` argument is an integer literal, not a declaration
// name. Two errors are reported on the attribute line.
// expected-error @+2 {{expected an original function name}}
// expected-error @+1 {{expected declaration}}
@derivative(of: 3)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Rejected: `wrt` appears right after `of:`, where a function name belongs.
// The diagnostic shows the trailing `foo` is then treated as an unlabeled
// argument where only a `wrt:` clause is allowed.
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: wrt, foo)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Rejected: the `wrt` label is written without its colon.
// expected-error @+1 {{expected a colon ':' after 'wrt'}}
@derivative(of: foo, wrt)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Rejected: an unlabeled argument (`blah`) sits between the original function
// name and the `wrt:` clause.
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: foo, blah, wrt: x)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Rejected: an extra argument follows the `wrt:` clause. The parser expects
// the attribute to close after the `wrt:` parameter.
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
// expected-error @+1 {{expected declaration}}
@derivative(of: foo, wrt: x, blah)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Rejected: a trailing comma after the original function name with nothing
// following it.
// expected-error @+1 {{unexpected ',' separator}}
@derivative(of: foo,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
  return (x, { $0 })
}
// Rejected: a trailing comma after the `wrt:` clause. The parser expects
// the attribute to close there.
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
// expected-error @+1 {{expected declaration}}
@derivative(of: foo, wrt: x,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
  return (x, { $0 })
}
// Rejected: an unlabeled second argument (itself followed by a trailing
// comma) in the position where only a `wrt:` clause may appear.
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: foo, foo,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
  return (x, { $0 })
}
// TF-1168: missing comma before `wrt:`.
// The parser requires a ',' separator between the original function name
// and the `wrt:` clause.
// expected-error @+2 {{expected ',' separator}}
// expected-error @+1 {{expected declaration}}
@derivative(of: foo wrt: x)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// Rejected: `@derivative` on a function declared inside another function's
// body. Per the diagnostic, the attribute is only allowed at non-local scope.
func testLocalDerivativeRegistration() {
  // expected-error @+1 {{attribute '@derivative' can only be used in a non-local scope}}
  @derivative(of: sin)
  func dsin()
}
// Rejected (current behavior): an accessor suffix after a qualified name.
// A plain dotted name (`nestedType.name`) and an accessor on an unqualified
// name (`property.set`) are both accepted, but `nestedType.name.set` is not:
// the parser stops after `nestedType.name` and asks for a separator.
// NOTE(review): this appears to pin a known dotted-name vs. accessor-suffix
// parsing ambiguity rather than the intended final behavior. Confirm before
// relying on it.
// expected-error @+2 {{expected ',' separator}}
// expected-error @+1 {{expected declaration}}
@derivative(of: nestedType.name.set)
func dNestedTypePropertySetter() -> ()