[ASTGen] Generate AutoDiff attributes

`@differentiable`, `@derivative` and `@transpose`
This commit is contained in:
Rintaro Ishizaki
2025-02-24 18:12:42 -08:00
parent cc145482de
commit 017c0d98ec
8 changed files with 585 additions and 12 deletions

View File

@@ -33,6 +33,7 @@ template<typename T> class ArrayRef;
}
namespace swift {
enum class AccessorKind;
class AvailabilityDomain;
class Argument;
class ASTContext;
@@ -44,6 +45,7 @@ class DeclNameLoc;
class DeclNameRef;
class DiagnosticArgument;
class DiagnosticEngine;
enum class DifferentiabilityKind : uint8_t;
class Fingerprint;
class Identifier;
class IfConfigClauseRangeInfo;
@@ -55,6 +57,7 @@ enum class MacroRole : uint32_t;
class MacroIntroducedDeclName;
enum class MacroIntroducedDeclNameKind;
enum class ParamSpecifier : uint8_t;
class ParsedAutoDiffParameter;
enum class PlatformKind : uint8_t;
class ProtocolConformanceRef;
class RegexLiteralPatternFeature;
@@ -502,6 +505,13 @@ struct BridgedPatternBindingEntry {
BridgedNullablePatternBindingInitializer initContext;
};
enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedAccessorKind {
#define ACCESSOR(ID) BridgedAccessorKind##ID,
#include "swift/AST/AccessorKinds.def"
};
swift::AccessorKind unbridged(BridgedAccessorKind kind);
//===----------------------------------------------------------------------===//
// MARK: Diagnostic Engine
//===----------------------------------------------------------------------===//
@@ -745,6 +755,59 @@ struct BridgedAvailabilityDomain {
bool isNull() const { return opaque == nullptr; };
};
//===----------------------------------------------------------------------===//
// MARK: AutoDiff
//===----------------------------------------------------------------------===//
enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedDifferentiabilityKind {
BridgedDifferentiabilityKindNonDifferentiable = 0,
BridgedDifferentiabilityKindForward = 1,
BridgedDifferentiabilityKindReverse = 2,
BridgedDifferentiabilityKindNormal = 3,
BridgedDifferentiabilityKindLinear = 4,
};
swift::DifferentiabilityKind unbridged(BridgedDifferentiabilityKind cKind);
class BridgedParsedAutoDiffParameter {
private:
BridgedSourceLoc loc;
enum Kind {
Named,
Ordered,
Self,
} kind;
union Value {
BridgedIdentifier name;
unsigned index;
Value(BridgedIdentifier name) : name(name) {}
Value(unsigned index) : index(index) {}
Value() : name() {}
} value;
BridgedParsedAutoDiffParameter(BridgedSourceLoc loc, Kind kind, Value value)
: loc(loc), kind(kind), value(value) {}
public:
SWIFT_NAME("forNamed(_:loc:)")
static BridgedParsedAutoDiffParameter forNamed(BridgedIdentifier name,
BridgedSourceLoc loc) {
return BridgedParsedAutoDiffParameter(loc, Kind::Named, name);
}
SWIFT_NAME("forOrdered(_:loc:)")
static BridgedParsedAutoDiffParameter forOrdered(size_t index,
BridgedSourceLoc loc) {
return BridgedParsedAutoDiffParameter(loc, Kind::Ordered, index);
}
SWIFT_NAME("forSelf(loc:)")
static BridgedParsedAutoDiffParameter forSelf(BridgedSourceLoc loc) {
return BridgedParsedAutoDiffParameter(loc, Kind::Self, {});
}
swift::ParsedAutoDiffParameter unbridged() const;
};
//===----------------------------------------------------------------------===//
// MARK: DeclAttributes
//===----------------------------------------------------------------------===//
@@ -879,6 +942,30 @@ BridgedCustomAttr BridgedCustomAttr_createParsed(
BridgedNullableCustomAttributeInitializer cInitContext,
BridgedNullableArgumentList cArgumentList);
SWIFT_NAME("BridgedDerivativeAttr.createParsed(_:atLoc:range:baseType:"
"originalName:originalNameLoc:accessorKind:params:)")
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedAccessorKind cAccessorKind, BridgedArrayRef cParams);
SWIFT_NAME("BridgedDerivativeAttr.createParsed(_:atLoc:range:baseType:"
"originalName:originalNameLoc:params:)")
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedArrayRef cParams);
SWIFT_NAME("BridgedDifferentiableAttr.createParsed(_:atLoc:range:kind:params:"
"genericWhereClause:)")
BridgedDifferentiableAttr BridgedDifferentiableAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedDifferentiabilityKind cKind,
BridgedArrayRef cParams,
BridgedNullableTrailingWhereClause cGenericWhereClause);
SWIFT_NAME("BridgedDocumentationAttr.createParsed(_:atLoc:range:metadata:"
"accessLevel:)")
BridgedDocumentationAttr BridgedDocumentationAttr_createParsed(
@@ -1260,6 +1347,15 @@ BridgedSILGenNameAttr BridgedSILGenNameAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedStringRef cName, bool isRaw);
SWIFT_NAME(
"BridgedTransposeAttr.createParsed(_:atLoc:range:baseType:originalName:"
"originalNameLoc:params:)")
BridgedTransposeAttr BridgedTransposeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedArrayRef cParams);
SWIFT_NAME(
"BridgedUnavailableFromAsyncAttr.createParsed(_:atLoc:range:message:)")
BridgedUnavailableFromAsyncAttr BridgedUnavailableFromAsyncAttr_createParsed(
@@ -1285,11 +1381,6 @@ enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedStaticSpelling {
BridgedStaticSpellingClass
};
enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedAccessorKind {
#define ACCESSOR(ID) BridgedAccessorKind##ID,
#include "swift/AST/AccessorKinds.def"
};
struct BridgedAccessorRecord {
BridgedSourceLoc lBraceLoc;
BridgedArrayRef accessors;
@@ -2438,6 +2529,13 @@ enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedExecutionTypeAttrExecutionKind {
BridgedExecutionTypeAttrExecutionKind_Caller
};
SWIFT_NAME("BridgedDifferentiableTypeAttr.createParsed(_:atLoc:nameLoc:"
"parensRange:kind:kindLoc:)")
BridgedDifferentiableTypeAttr BridgedDifferentiableTypeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,
BridgedDifferentiabilityKind cKind, BridgedSourceLoc cKindLoc);
SWIFT_NAME("BridgedExecutionTypeAttr.createParsed(_:atLoc:nameLoc:parensRange:"
"behavior:behaviorLoc:)")
BridgedExecutionTypeAttr BridgedExecutionTypeAttr_createParsed(

View File

@@ -14,12 +14,47 @@
#include "swift/AST/ASTContext.h"
#include "swift/AST/Attr.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/Expr.h"
#include "swift/AST/Identifier.h"
#include "swift/Basic/Assertions.h"
using namespace swift;
//===----------------------------------------------------------------------===//
// MARK: AutoDiff
//===----------------------------------------------------------------------===//
DifferentiabilityKind unbridged(BridgedDifferentiabilityKind cKind) {
switch (cKind) {
case BridgedDifferentiabilityKindNonDifferentiable:
return DifferentiabilityKind::NonDifferentiable;
case BridgedDifferentiabilityKindForward:
return DifferentiabilityKind::Forward;
case BridgedDifferentiabilityKindReverse:
return DifferentiabilityKind::Reverse;
case BridgedDifferentiabilityKindNormal:
return DifferentiabilityKind::Normal;
case BridgedDifferentiabilityKindLinear:
return DifferentiabilityKind::Linear;
}
llvm_unreachable("unhandled enum value");
}
ParsedAutoDiffParameter BridgedParsedAutoDiffParameter::unbridged() const {
switch (kind) {
case Kind::Named:
return ParsedAutoDiffParameter::getNamedParameter(loc.unbridged(),
value.name.unbridged());
case Kind::Ordered:
return ParsedAutoDiffParameter::getOrderedParameter(loc.unbridged(),
value.index);
case Kind::Self:
return ParsedAutoDiffParameter::getSelfParameter(loc.unbridged());
}
llvm_unreachable("unhandled enum value");
}
//===----------------------------------------------------------------------===//
// MARK: DeclAttributes
//===----------------------------------------------------------------------===//
@@ -221,6 +256,62 @@ BridgedCustomAttr BridgedCustomAttr_createParsed(
cInitContext.unbridged(), cArgumentList.unbridged());
}
BridgedDerivativeAttr BridgedDerivativeAttr_createParsedImpl(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
std::optional<BridgedAccessorKind> cAccessorKind, BridgedArrayRef cParams) {
std::optional<AccessorKind> accessorKind;
if (cAccessorKind)
accessorKind = unbridged(*cAccessorKind);
SmallVector<ParsedAutoDiffParameter, 2> params;
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
params.push_back(elem.unbridged());
return DerivativeAttr::create(cContext.unbridged(),
/*implicit=*/false, cAtLoc.unbridged(),
cRange.unbridged(), cBaseType.unbridged(),
DeclNameRefWithLoc{cOriginalName.unbridged(),
cOriginalNameLoc.unbridged(),
accessorKind},
params);
}
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedAccessorKind cAccessorKind, BridgedArrayRef cParams) {
return BridgedDerivativeAttr_createParsedImpl(
cContext, cAtLoc, cRange, cBaseType, cOriginalName, cOriginalNameLoc,
cAccessorKind, cParams);
}
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedArrayRef cParams) {
return BridgedDerivativeAttr_createParsedImpl(
cContext, cAtLoc, cRange, cBaseType, cOriginalName, cOriginalNameLoc,
/*cAccessorKind=*/std::nullopt, cParams);
}
BridgedDifferentiableAttr BridgedDifferentiableAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedDifferentiabilityKind cKind,
BridgedArrayRef cParams,
BridgedNullableTrailingWhereClause cGenericWhereClause) {
SmallVector<ParsedAutoDiffParameter, 2> params;
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
params.push_back(elem.unbridged());
return DifferentiableAttr::create(cContext.unbridged(), /*implicit=*/false,
cAtLoc.unbridged(), cRange.unbridged(),
unbridged(cKind), params,
cGenericWhereClause.unbridged());
}
BridgedDynamicReplacementAttr BridgedDynamicReplacementAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceLoc cAttrNameLoc, BridgedSourceLoc cLParenLoc,
@@ -752,6 +843,24 @@ BridgedSILGenNameAttr BridgedSILGenNameAttr_createParsed(
cRange.unbridged(), /*Implicit=*/false);
}
BridgedTransposeAttr BridgedTransposeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
BridgedArrayRef cParams) {
SmallVector<ParsedAutoDiffParameter, 2> params;
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
params.push_back(elem.unbridged());
return TransposeAttr::create(
cContext.unbridged(),
/*implicit=*/false, cAtLoc.unbridged(), cRange.unbridged(),
cBaseType.unbridged(),
DeclNameRefWithLoc{cOriginalName.unbridged(), cOriginalNameLoc.unbridged(),
/*AccessorKind=*/std::nullopt},
params);
}
BridgedUnavailableFromAsyncAttr BridgedUnavailableFromAsyncAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceRange cRange, BridgedStringRef cMessage) {

View File

@@ -111,7 +111,7 @@ static StaticSpellingKind unbridged(BridgedStaticSpelling kind) {
return static_cast<StaticSpellingKind>(kind);
}
static AccessorKind unbridged(BridgedAccessorKind kind) {
AccessorKind unbridged(BridgedAccessorKind kind) {
return static_cast<AccessorKind>(kind);
}

View File

@@ -78,6 +78,15 @@ BridgedConventionTypeAttr BridgedConventionTypeAttr_createParsed(
{cClangType.unbridged(), cClangTypeLoc.unbridged()});
}
BridgedDifferentiableTypeAttr BridgedDifferentiableTypeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,
BridgedDifferentiabilityKind cKind, BridgedSourceLoc cKindLoc) {
return new (cContext.unbridged()) DifferentiableTypeAttr(
cAtLoc.unbridged(), cNameLoc.unbridged(), cParensRange.unbridged(),
{unbridged(cKind), cKindLoc.unbridged()});
}
BridgedExecutionTypeAttr BridgedExecutionTypeAttr_createParsed(
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,

View File

@@ -126,9 +126,9 @@ extension ASTGenVisitor {
case .cDecl:
return handle(self.generateCDeclAttr(attribute: node)?.asDeclAttribute)
case .derivative:
fatalError("unimplemented")
return handle(self.generateDerivativeAttr(attribute: node)?.asDeclAttribute)
case .differentiable:
fatalError("unimplemented")
return handle(self.generateDifferentiableAttr(attribute: node)?.asDeclAttribute)
case .dynamicReplacement:
return handle(self.generateDynamicReplacementAttr(attribute: node)?.asDeclAttribute)
case .documentation:
@@ -182,7 +182,7 @@ extension ASTGenVisitor {
case .swiftNativeObjCRuntimeBase:
return handle(self.generateSwiftNativeObjCRuntimeBaseAttr(attribute: node)?.asDeclAttribute)
case .transpose:
fatalError("unimplemented")
return handle(self.generateTransposeAttr(attribute: node)?.asDeclAttribute)
case .typeEraser:
fatalError("unimplemented")
case .unavailableFromAsync:
@@ -553,6 +553,176 @@ extension ASTGenVisitor {
)
}
struct GeneratedDerivativeOriginalDecl {
var baseType: BridgedTypeRepr?
var declName: BridgedDeclNameRef
var declNameLoc: BridgedDeclNameLoc
}
func generateDerivativeOriginalDecl(expr: ExprSyntax) -> GeneratedDerivativeOriginalDecl? {
var baseType: BridgedTypeRepr?
var declName: BridgedDeclNameRef
var declNameLoc: BridgedDeclNameLoc
if let declrefExpr = expr.as(DeclReferenceExprSyntax.self) {
baseType = nil
(declName, declNameLoc) = self.generateDeclNameRef(declReferenceExpr: declrefExpr)
} else if let memberExpr = expr.as(MemberAccessExprSyntax.self),
let baseExpr = memberExpr.base {
guard let _baseType = self.generateTypeRepr(expr: baseExpr) else {
// TODO: Diagnose.
fatalError("invalid type expression for @derivative qualified decl name")
}
baseType = _baseType
(declName, declNameLoc) = self.generateDeclNameRef(declReferenceExpr: memberExpr.declName)
} else {
// TODO: Diagnosse.
fatalError("invalid expression for @derivative original decl name")
}
return GeneratedDerivativeOriginalDecl(
baseType: baseType,
declName: declName,
declNameLoc: declNameLoc
)
}
func generateDifferentiabilityKind(text: SyntaxText) -> BridgedDifferentiabilityKind {
switch text {
case "reverse": return .reverse
case "wrt", "withRespectTo": return .normal
case "_liner": return .linear
case "_forward": return .forward
default: return .nonDifferentiable
}
}
func generate(differentiabilityArgument node: DifferentiabilityArgumentSyntax) -> BridgedParsedAutoDiffParameter {
let loc = self.generateSourceLoc(node)
switch node.argument.rawTokenKind {
case .identifier:
return .forNamed(self.generateIdentifier(node.argument), loc: loc)
case .integerLiteral:
guard let index = Int(node.argument.text) else {
// TODO: Diagnose
fatalError("(compiler bug) invalid integer literal token text")
}
return .forOrdered(index, loc: loc)
case .keyword where node.argument.rawText == "self":
return .forSelf(loc: loc)
default:
// TODO: Diagnose
fatalError("(compiler bug) invalid token for 'wrt:' argument")
}
}
func generate(differentiabilityWithRespectToArgument node: DifferentiabilityWithRespectToArgumentSyntax?) -> BridgedArrayRef {
guard let node else {
return BridgedArrayRef()
}
switch node.arguments {
case .argument(let node): // Single argument e.g. 'wrt: foo'
return CollectionOfOne(self.generate(differentiabilityArgument: node)).bridgedArray(in: self)
case .argumentList(let node): // Multiple arguments e.g. 'wrt: (self, 2)'
return node.arguments.lazy.map(self.generate(differentiabilityArgument:)).bridgedArray(in: self)
}
}
/// E.g.
/// ```
/// @derivative(of: foo(arg:), wrt: self)
/// ```
func generateDerivativeAttr(attribute node: AttributeSyntax) -> BridgedDerivativeAttr? {
guard let args = node.arguments?.as(DerivativeAttributeArgumentsSyntax.self) else {
fatalError("(compiler bug) invalid arguments for @derivative attribute")
}
guard let originalDecl = self.generateDerivativeOriginalDecl(expr: args.originalDeclName) else {
return nil
}
let accessorKind: BridgedAccessorKind?
if let accessorToken = args.accessorSpecifier {
accessorKind = self.generate(accessorSpecifier: accessorToken)
} else {
accessorKind = nil
}
let parameters = self.generate(differentiabilityWithRespectToArgument: args.arguments)
if let accessorKind {
return .createParsed(
self.ctx,
atLoc: self.generateSourceLoc(node.atSign),
range: self.generateAttrSourceRange(node),
baseType: originalDecl.baseType.asNullable,
originalName: originalDecl.declName,
originalNameLoc: originalDecl.declNameLoc,
accessorKind: accessorKind,
params: parameters
)
} else {
return .createParsed(
self.ctx,
atLoc: self.generateSourceLoc(node.atSign),
range: self.generateAttrSourceRange(node),
baseType: originalDecl.baseType.asNullable,
originalName: originalDecl.declName,
originalNameLoc: originalDecl.declNameLoc,
params: parameters
)
}
}
/// E.g.
/// ```
/// @differentiable(reverse, wrt: (self, 3) where T: U)
/// @differentiable(reverse, wrt: foo where T: U)
/// ```
func generateDifferentiableAttr(attribute node: AttributeSyntax) -> BridgedDifferentiableAttr? {
guard let args = node.arguments?.as(DifferentiableAttributeArgumentsSyntax.self) else {
fatalError("(compiler bug) invalid arguments for @differentiable attribute")
}
var differentiability: BridgedDifferentiabilityKind
if let kindSpecifier = args.kindSpecifier {
differentiability = self.generateDifferentiabilityKind(text: kindSpecifier.rawText)
} else {
differentiability = .normal
}
if differentiability == .normal {
// TODO: Diagnose "'@differentiable' has been renamed to '@differentiable(reverse)"
differentiability = .reverse
}
guard differentiability == .reverse || differentiability == .linear else {
// TODO: Diagnose.
fatalError("not supported kind for @differentiable attribute")
}
let parameters = self.generate(differentiabilityWithRespectToArgument: args.arguments)
let whereClause: BridgedTrailingWhereClause?
if let node = args.genericWhereClause {
whereClause = self.generate(genericWhereClause: node)
} else {
whereClause = nil
}
return .createParsed(
self.ctx,
atLoc: self.generateSourceLoc(node.atSign),
range: self.generateAttrSourceRange(node),
kind: differentiability,
params: parameters,
genericWhereClause: whereClause.asNullable
)
}
/// E.g:
/// ```
/// @_dynamicReplacement(for: member)
@@ -1893,6 +2063,37 @@ extension ASTGenVisitor {
)
}
/// E.g.:
/// ```
/// @transpose(of: foo(_:), wrt: self)
/// ```
func generateTransposeAttr(attribute node: AttributeSyntax) -> BridgedTransposeAttr? {
guard let args = node.arguments?.as(DerivativeAttributeArgumentsSyntax.self) else {
fatalError("(compiler bug) invalid arguments for @derivative attribute")
}
guard let originalDecl = self.generateDerivativeOriginalDecl(expr: args.originalDeclName) else {
return nil
}
if let accessorToken = args.accessorSpecifier {
// TODO: Diagnostics.
_ = accessorToken
fatalError("(compiler bug) unexpected accessor kind for @transpose attribute")
}
let parameters = self.generate(differentiabilityWithRespectToArgument: args.arguments)
return .createParsed(
self.ctx,
atLoc: self.generateSourceLoc(node.atSign),
range: self.generateAttrSourceRange(node),
baseType: originalDecl.baseType.asNullable,
originalName: originalDecl.declName,
originalNameLoc: originalDecl.declNameLoc,
params: parameters
)
}
/// E.g.:
/// ```
/// @_unavailableFromAsync

View File

@@ -378,7 +378,7 @@ extension ASTGenVisitor {
// MARK: - AbstractStorageDecl
extension ASTGenVisitor {
private func generate(accessorSpecifier specifier: TokenSyntax) -> BridgedAccessorKind? {
func generate(accessorSpecifier specifier: TokenSyntax) -> BridgedAccessorKind? {
switch specifier.keywordKind {
case .get:
return .get

View File

@@ -72,7 +72,8 @@ extension ASTGenVisitor {
return (self.generateConventionTypeAttr(attribute: node)?.asTypeAttribute)
.map(BridgedTypeOrCustomAttr.typeAttr(_:))
case .differentiable:
fatalError("unimplemented")
return (self.generateDifferentiableTypeAttr(attribute: node)?.asTypeAttribute)
.map(BridgedTypeOrCustomAttr.typeAttr(_:))
case .execution:
return (self.generateExecutionTypeAttr(attribute: node)?.asTypeAttribute)
.map(BridgedTypeOrCustomAttr.typeAttr(_:))
@@ -173,7 +174,53 @@ extension ASTGenVisitor {
clangTypeLoc: cTypeNameLoc ?? BridgedSourceLoc()
)
}
func generateDifferentiableTypeAttr(attribute node: AttributeSyntax) -> BridgedDifferentiableTypeAttr? {
let differentiability: BridgedDifferentiabilityKind
let differentiabilityLoc: BridgedSourceLoc
if let args = node.arguments {
guard let args = args.as(DifferentiableAttributeArgumentsSyntax.self) else {
fatalError("(compiler bug) invalid arguments for @differentiable attribute")
}
if let kindSpecifier = args.kindSpecifier {
differentiability = self.generateDifferentiabilityKind(text: kindSpecifier.rawText)
differentiabilityLoc = self.generateSourceLoc(kindSpecifier)
guard differentiability != .nonDifferentiable else {
// TODO: Diagnose
fatalError("invalid kind for @differentiable type attribute")
}
guard kindSpecifier.nextToken(viewMode: .fixedUp) == node.rightParen else {
// TODO: Diagnose
fatalError("only expeceted 'reverse' in @differentiable type attribute")
}
} else {
// TODO: Diagnose
fatalError("expected @differentiable(reverse)")
}
} else {
differentiability = .normal
differentiabilityLoc = nil
}
// Only 'reverse' is supported today.
guard differentiability == .reverse else {
// TODO: Diagnose
fatalError("Only @differentiable(reverse) is supported")
}
return .createParsed(
self.ctx,
atLoc: self.generateSourceLoc(node.atSign),
nameLoc: self.generateSourceLoc(node.attributeName),
parensRange: self.generateAttrParensRange(attribute: node),
kind: differentiability,
kindLoc: differentiabilityLoc
)
}
func generateExecutionTypeAttr(attribute node: AttributeSyntax) -> BridgedExecutionTypeAttr? {
let behaviorLoc = self.generateSourceLoc(node.arguments)
let behavior: BridgedExecutionTypeAttrExecutionKind? = self.generateSingleAttrOption(

109
test/ASTGen/autodiff.swift Normal file
View File

@@ -0,0 +1,109 @@
// RUN: %empty-directory(%t)
// RUNx: %target-swift-frontend-dump-parse \
// RUNx: -enable-experimental-feature ParserASTGen \
// RUNx: | %sanitize-address > %t/astgen.ast
// RUNx: %target-swift-frontend-dump-parse \
// RUNx: | %sanitize-address > %t/cpp-parser.ast
// RUNx: %diff -u %t/astgen.ast %t/cpp-parser.ast
// RUN: %target-typecheck-verify-swift
import _Differentiation
func testDifferentiableTypeAttr(_ fn: @escaping @differentiable(reverse) (Float) -> Float)
-> @differentiable(reverse) (Float) -> Float {
return fn
}
@differentiable(reverse)
func testDifferentiableSimple(_ x: Float) -> Float { return x * x }
@differentiable(reverse, wrt: arg1)
func testDifferentiableWRT1(arg1: Float, arg2: Float) -> Float { return arg1 }
@differentiable(reverse, wrt: (arg1, arg2))
func testDifferentiableWRT2(arg1: Float, arg2: Float) -> Float { return arg1 * arg2 }
@differentiable(reverse where T : Differentiable)
func testOnlyWhereClause<T : Numeric>(x: T) -> T { return x }
protocol DiffP {}
extension DiffP {
@differentiable(reverse, wrt: self where Self : Differentiable)
func testWhereClauseMethod() -> Self {
return self
}
}
func linearFunc(_ x: Float) -> Float { return x }
@transpose(of: linearFunc, wrt: 0)
func linearFuncTranspose(x: Float) -> Float { return x }
extension Float {
func getDouble() -> Double { return Double(self) }
@transpose(of: Float.getDouble, wrt: self)
static func structTranspose(v: Double) -> Float { return Float(v) }
}
struct DerivativeTest<T: Differentiable & AdditiveArithmetic>: Differentiable, AdditiveArithmetic {
typealias TangentVector = DerivativeTest<T.TangentVector>
static var zero: Self {
fatalError()
}
static func + (lhs: Self, rhs: Self) -> Self {
fatalError()
}
static func - (lhs: Self, rhs: Self) -> Self {
fatalError()
}
mutating func move(by offset: TangentVector) {
x.move(by: offset.x)
}
var x: T
static func staticMethod(_ x: Float) -> Float { 1.2 }
@derivative(of: staticMethod)
static func jvpStaticMethod(x: Float) -> (value: Float, differential: (Float) -> Float) {
return (x, { $0 })
}
func instanceMethod(_ x: T) -> T { x }
@derivative(of: instanceMethod)
func jvpInstanceMethod(x: T) -> (value: T, differential: (TangentVector, T.TangentVector) -> T.TangentVector) {
return (x, { $1 })
}
init(_ x: Float) { fatalError() }
init(_ x: T, y: Float) { fatalError() }
@derivative(of: init(_:y:))
static func vjpInit2(_ x: T, _ y: Float) -> (value: Self, pullback: (TangentVector) -> (T.TangentVector, Float)) {
return (.init(x, y: y), { _ in (.zero, .zero) })
}
var computedProperty: T {
get { x }
set { x = newValue }
}
// FIXME: SwiftParser parsed this attribute as:
// {type: 'computedProperty', originalName: 'get', accessor: null}
// But it should be:
// {type: null, originalName: 'computedProperty', accessor: 'get'}
// @derivative(of: computedProperty.get)
// func jvpProperty() -> (value: T, differential: (TangentVector) -> T.TangentVector) {
// fatalError()
// }
subscript(float float: Float) -> Float {
get { 1 }
set {}
}
@derivative(of: subscript(float:).get, wrt: self)
func vjpSubscriptLabeledGetter(float: Float) -> (value: Float, pullback: (Float) -> TangentVector) {
return (1, { _ in .zero })
}
}