mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[ASTGen] Generate AutoDiff attributes
`@differentiable`, `@derivative` and `@transpose`
This commit is contained in:
@@ -33,6 +33,7 @@ template<typename T> class ArrayRef;
|
||||
}
|
||||
|
||||
namespace swift {
|
||||
enum class AccessorKind;
|
||||
class AvailabilityDomain;
|
||||
class Argument;
|
||||
class ASTContext;
|
||||
@@ -44,6 +45,7 @@ class DeclNameLoc;
|
||||
class DeclNameRef;
|
||||
class DiagnosticArgument;
|
||||
class DiagnosticEngine;
|
||||
enum class DifferentiabilityKind : uint8_t;
|
||||
class Fingerprint;
|
||||
class Identifier;
|
||||
class IfConfigClauseRangeInfo;
|
||||
@@ -55,6 +57,7 @@ enum class MacroRole : uint32_t;
|
||||
class MacroIntroducedDeclName;
|
||||
enum class MacroIntroducedDeclNameKind;
|
||||
enum class ParamSpecifier : uint8_t;
|
||||
class ParsedAutoDiffParameter;
|
||||
enum class PlatformKind : uint8_t;
|
||||
class ProtocolConformanceRef;
|
||||
class RegexLiteralPatternFeature;
|
||||
@@ -502,6 +505,13 @@ struct BridgedPatternBindingEntry {
|
||||
BridgedNullablePatternBindingInitializer initContext;
|
||||
};
|
||||
|
||||
enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedAccessorKind {
|
||||
#define ACCESSOR(ID) BridgedAccessorKind##ID,
|
||||
#include "swift/AST/AccessorKinds.def"
|
||||
};
|
||||
|
||||
swift::AccessorKind unbridged(BridgedAccessorKind kind);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MARK: Diagnostic Engine
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -745,6 +755,59 @@ struct BridgedAvailabilityDomain {
|
||||
bool isNull() const { return opaque == nullptr; };
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MARK: AutoDiff
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedDifferentiabilityKind {
|
||||
BridgedDifferentiabilityKindNonDifferentiable = 0,
|
||||
BridgedDifferentiabilityKindForward = 1,
|
||||
BridgedDifferentiabilityKindReverse = 2,
|
||||
BridgedDifferentiabilityKindNormal = 3,
|
||||
BridgedDifferentiabilityKindLinear = 4,
|
||||
};
|
||||
|
||||
swift::DifferentiabilityKind unbridged(BridgedDifferentiabilityKind cKind);
|
||||
|
||||
class BridgedParsedAutoDiffParameter {
|
||||
private:
|
||||
BridgedSourceLoc loc;
|
||||
enum Kind {
|
||||
Named,
|
||||
Ordered,
|
||||
Self,
|
||||
} kind;
|
||||
union Value {
|
||||
BridgedIdentifier name;
|
||||
unsigned index;
|
||||
|
||||
Value(BridgedIdentifier name) : name(name) {}
|
||||
Value(unsigned index) : index(index) {}
|
||||
Value() : name() {}
|
||||
} value;
|
||||
|
||||
BridgedParsedAutoDiffParameter(BridgedSourceLoc loc, Kind kind, Value value)
|
||||
: loc(loc), kind(kind), value(value) {}
|
||||
|
||||
public:
|
||||
SWIFT_NAME("forNamed(_:loc:)")
|
||||
static BridgedParsedAutoDiffParameter forNamed(BridgedIdentifier name,
|
||||
BridgedSourceLoc loc) {
|
||||
return BridgedParsedAutoDiffParameter(loc, Kind::Named, name);
|
||||
}
|
||||
SWIFT_NAME("forOrdered(_:loc:)")
|
||||
static BridgedParsedAutoDiffParameter forOrdered(size_t index,
|
||||
BridgedSourceLoc loc) {
|
||||
return BridgedParsedAutoDiffParameter(loc, Kind::Ordered, index);
|
||||
}
|
||||
SWIFT_NAME("forSelf(loc:)")
|
||||
static BridgedParsedAutoDiffParameter forSelf(BridgedSourceLoc loc) {
|
||||
return BridgedParsedAutoDiffParameter(loc, Kind::Self, {});
|
||||
}
|
||||
|
||||
swift::ParsedAutoDiffParameter unbridged() const;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MARK: DeclAttributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -879,6 +942,30 @@ BridgedCustomAttr BridgedCustomAttr_createParsed(
|
||||
BridgedNullableCustomAttributeInitializer cInitContext,
|
||||
BridgedNullableArgumentList cArgumentList);
|
||||
|
||||
SWIFT_NAME("BridgedDerivativeAttr.createParsed(_:atLoc:range:baseType:"
|
||||
"originalName:originalNameLoc:accessorKind:params:)")
|
||||
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
|
||||
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
|
||||
BridgedAccessorKind cAccessorKind, BridgedArrayRef cParams);
|
||||
|
||||
SWIFT_NAME("BridgedDerivativeAttr.createParsed(_:atLoc:range:baseType:"
|
||||
"originalName:originalNameLoc:params:)")
|
||||
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
|
||||
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
|
||||
BridgedArrayRef cParams);
|
||||
|
||||
SWIFT_NAME("BridgedDifferentiableAttr.createParsed(_:atLoc:range:kind:params:"
|
||||
"genericWhereClause:)")
|
||||
BridgedDifferentiableAttr BridgedDifferentiableAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedDifferentiabilityKind cKind,
|
||||
BridgedArrayRef cParams,
|
||||
BridgedNullableTrailingWhereClause cGenericWhereClause);
|
||||
|
||||
SWIFT_NAME("BridgedDocumentationAttr.createParsed(_:atLoc:range:metadata:"
|
||||
"accessLevel:)")
|
||||
BridgedDocumentationAttr BridgedDocumentationAttr_createParsed(
|
||||
@@ -1260,6 +1347,15 @@ BridgedSILGenNameAttr BridgedSILGenNameAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedStringRef cName, bool isRaw);
|
||||
|
||||
SWIFT_NAME(
|
||||
"BridgedTransposeAttr.createParsed(_:atLoc:range:baseType:originalName:"
|
||||
"originalNameLoc:params:)")
|
||||
BridgedTransposeAttr BridgedTransposeAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
|
||||
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
|
||||
BridgedArrayRef cParams);
|
||||
|
||||
SWIFT_NAME(
|
||||
"BridgedUnavailableFromAsyncAttr.createParsed(_:atLoc:range:message:)")
|
||||
BridgedUnavailableFromAsyncAttr BridgedUnavailableFromAsyncAttr_createParsed(
|
||||
@@ -1285,11 +1381,6 @@ enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedStaticSpelling {
|
||||
BridgedStaticSpellingClass
|
||||
};
|
||||
|
||||
enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedAccessorKind {
|
||||
#define ACCESSOR(ID) BridgedAccessorKind##ID,
|
||||
#include "swift/AST/AccessorKinds.def"
|
||||
};
|
||||
|
||||
struct BridgedAccessorRecord {
|
||||
BridgedSourceLoc lBraceLoc;
|
||||
BridgedArrayRef accessors;
|
||||
@@ -2438,6 +2529,13 @@ enum ENUM_EXTENSIBILITY_ATTR(closed) BridgedExecutionTypeAttrExecutionKind {
|
||||
BridgedExecutionTypeAttrExecutionKind_Caller
|
||||
};
|
||||
|
||||
SWIFT_NAME("BridgedDifferentiableTypeAttr.createParsed(_:atLoc:nameLoc:"
|
||||
"parensRange:kind:kindLoc:)")
|
||||
BridgedDifferentiableTypeAttr BridgedDifferentiableTypeAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,
|
||||
BridgedDifferentiabilityKind cKind, BridgedSourceLoc cKindLoc);
|
||||
|
||||
SWIFT_NAME("BridgedExecutionTypeAttr.createParsed(_:atLoc:nameLoc:parensRange:"
|
||||
"behavior:behaviorLoc:)")
|
||||
BridgedExecutionTypeAttr BridgedExecutionTypeAttr_createParsed(
|
||||
|
||||
@@ -14,12 +14,47 @@
|
||||
|
||||
#include "swift/AST/ASTContext.h"
|
||||
#include "swift/AST/Attr.h"
|
||||
#include "swift/AST/AutoDiff.h"
|
||||
#include "swift/AST/Expr.h"
|
||||
#include "swift/AST/Identifier.h"
|
||||
#include "swift/Basic/Assertions.h"
|
||||
|
||||
using namespace swift;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MARK: AutoDiff
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DifferentiabilityKind unbridged(BridgedDifferentiabilityKind cKind) {
|
||||
switch (cKind) {
|
||||
case BridgedDifferentiabilityKindNonDifferentiable:
|
||||
return DifferentiabilityKind::NonDifferentiable;
|
||||
case BridgedDifferentiabilityKindForward:
|
||||
return DifferentiabilityKind::Forward;
|
||||
case BridgedDifferentiabilityKindReverse:
|
||||
return DifferentiabilityKind::Reverse;
|
||||
case BridgedDifferentiabilityKindNormal:
|
||||
return DifferentiabilityKind::Normal;
|
||||
case BridgedDifferentiabilityKindLinear:
|
||||
return DifferentiabilityKind::Linear;
|
||||
}
|
||||
llvm_unreachable("unhandled enum value");
|
||||
}
|
||||
|
||||
ParsedAutoDiffParameter BridgedParsedAutoDiffParameter::unbridged() const {
|
||||
switch (kind) {
|
||||
case Kind::Named:
|
||||
return ParsedAutoDiffParameter::getNamedParameter(loc.unbridged(),
|
||||
value.name.unbridged());
|
||||
case Kind::Ordered:
|
||||
return ParsedAutoDiffParameter::getOrderedParameter(loc.unbridged(),
|
||||
value.index);
|
||||
case Kind::Self:
|
||||
return ParsedAutoDiffParameter::getSelfParameter(loc.unbridged());
|
||||
}
|
||||
llvm_unreachable("unhandled enum value");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MARK: DeclAttributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -221,6 +256,62 @@ BridgedCustomAttr BridgedCustomAttr_createParsed(
|
||||
cInitContext.unbridged(), cArgumentList.unbridged());
|
||||
}
|
||||
|
||||
BridgedDerivativeAttr BridgedDerivativeAttr_createParsedImpl(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
|
||||
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
|
||||
std::optional<BridgedAccessorKind> cAccessorKind, BridgedArrayRef cParams) {
|
||||
std::optional<AccessorKind> accessorKind;
|
||||
if (cAccessorKind)
|
||||
accessorKind = unbridged(*cAccessorKind);
|
||||
SmallVector<ParsedAutoDiffParameter, 2> params;
|
||||
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
|
||||
params.push_back(elem.unbridged());
|
||||
|
||||
return DerivativeAttr::create(cContext.unbridged(),
|
||||
/*implicit=*/false, cAtLoc.unbridged(),
|
||||
cRange.unbridged(), cBaseType.unbridged(),
|
||||
DeclNameRefWithLoc{cOriginalName.unbridged(),
|
||||
cOriginalNameLoc.unbridged(),
|
||||
accessorKind},
|
||||
params);
|
||||
}
|
||||
|
||||
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
|
||||
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
|
||||
BridgedAccessorKind cAccessorKind, BridgedArrayRef cParams) {
|
||||
return BridgedDerivativeAttr_createParsedImpl(
|
||||
cContext, cAtLoc, cRange, cBaseType, cOriginalName, cOriginalNameLoc,
|
||||
cAccessorKind, cParams);
|
||||
}
|
||||
|
||||
BridgedDerivativeAttr BridgedDerivativeAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
|
||||
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
|
||||
BridgedArrayRef cParams) {
|
||||
return BridgedDerivativeAttr_createParsedImpl(
|
||||
cContext, cAtLoc, cRange, cBaseType, cOriginalName, cOriginalNameLoc,
|
||||
/*cAccessorKind=*/std::nullopt, cParams);
|
||||
}
|
||||
|
||||
BridgedDifferentiableAttr BridgedDifferentiableAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedDifferentiabilityKind cKind,
|
||||
BridgedArrayRef cParams,
|
||||
BridgedNullableTrailingWhereClause cGenericWhereClause) {
|
||||
SmallVector<ParsedAutoDiffParameter, 2> params;
|
||||
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
|
||||
params.push_back(elem.unbridged());
|
||||
|
||||
return DifferentiableAttr::create(cContext.unbridged(), /*implicit=*/false,
|
||||
cAtLoc.unbridged(), cRange.unbridged(),
|
||||
unbridged(cKind), params,
|
||||
cGenericWhereClause.unbridged());
|
||||
}
|
||||
|
||||
BridgedDynamicReplacementAttr BridgedDynamicReplacementAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceLoc cAttrNameLoc, BridgedSourceLoc cLParenLoc,
|
||||
@@ -752,6 +843,24 @@ BridgedSILGenNameAttr BridgedSILGenNameAttr_createParsed(
|
||||
cRange.unbridged(), /*Implicit=*/false);
|
||||
}
|
||||
|
||||
BridgedTransposeAttr BridgedTransposeAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedNullableTypeRepr cBaseType,
|
||||
BridgedDeclNameRef cOriginalName, BridgedDeclNameLoc cOriginalNameLoc,
|
||||
BridgedArrayRef cParams) {
|
||||
SmallVector<ParsedAutoDiffParameter, 2> params;
|
||||
for (auto &elem : cParams.unbridged<BridgedParsedAutoDiffParameter>())
|
||||
params.push_back(elem.unbridged());
|
||||
|
||||
return TransposeAttr::create(
|
||||
cContext.unbridged(),
|
||||
/*implicit=*/false, cAtLoc.unbridged(), cRange.unbridged(),
|
||||
cBaseType.unbridged(),
|
||||
DeclNameRefWithLoc{cOriginalName.unbridged(), cOriginalNameLoc.unbridged(),
|
||||
/*AccessorKind=*/std::nullopt},
|
||||
params);
|
||||
}
|
||||
|
||||
BridgedUnavailableFromAsyncAttr BridgedUnavailableFromAsyncAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceRange cRange, BridgedStringRef cMessage) {
|
||||
|
||||
@@ -111,7 +111,7 @@ static StaticSpellingKind unbridged(BridgedStaticSpelling kind) {
|
||||
return static_cast<StaticSpellingKind>(kind);
|
||||
}
|
||||
|
||||
static AccessorKind unbridged(BridgedAccessorKind kind) {
|
||||
AccessorKind unbridged(BridgedAccessorKind kind) {
|
||||
return static_cast<AccessorKind>(kind);
|
||||
}
|
||||
|
||||
|
||||
@@ -78,6 +78,15 @@ BridgedConventionTypeAttr BridgedConventionTypeAttr_createParsed(
|
||||
{cClangType.unbridged(), cClangTypeLoc.unbridged()});
|
||||
}
|
||||
|
||||
BridgedDifferentiableTypeAttr BridgedDifferentiableTypeAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,
|
||||
BridgedDifferentiabilityKind cKind, BridgedSourceLoc cKindLoc) {
|
||||
return new (cContext.unbridged()) DifferentiableTypeAttr(
|
||||
cAtLoc.unbridged(), cNameLoc.unbridged(), cParensRange.unbridged(),
|
||||
{unbridged(cKind), cKindLoc.unbridged()});
|
||||
}
|
||||
|
||||
BridgedExecutionTypeAttr BridgedExecutionTypeAttr_createParsed(
|
||||
BridgedASTContext cContext, BridgedSourceLoc cAtLoc,
|
||||
BridgedSourceLoc cNameLoc, BridgedSourceRange cParensRange,
|
||||
|
||||
@@ -126,9 +126,9 @@ extension ASTGenVisitor {
|
||||
case .cDecl:
|
||||
return handle(self.generateCDeclAttr(attribute: node)?.asDeclAttribute)
|
||||
case .derivative:
|
||||
fatalError("unimplemented")
|
||||
return handle(self.generateDerivativeAttr(attribute: node)?.asDeclAttribute)
|
||||
case .differentiable:
|
||||
fatalError("unimplemented")
|
||||
return handle(self.generateDifferentiableAttr(attribute: node)?.asDeclAttribute)
|
||||
case .dynamicReplacement:
|
||||
return handle(self.generateDynamicReplacementAttr(attribute: node)?.asDeclAttribute)
|
||||
case .documentation:
|
||||
@@ -182,7 +182,7 @@ extension ASTGenVisitor {
|
||||
case .swiftNativeObjCRuntimeBase:
|
||||
return handle(self.generateSwiftNativeObjCRuntimeBaseAttr(attribute: node)?.asDeclAttribute)
|
||||
case .transpose:
|
||||
fatalError("unimplemented")
|
||||
return handle(self.generateTransposeAttr(attribute: node)?.asDeclAttribute)
|
||||
case .typeEraser:
|
||||
fatalError("unimplemented")
|
||||
case .unavailableFromAsync:
|
||||
@@ -553,6 +553,176 @@ extension ASTGenVisitor {
|
||||
)
|
||||
}
|
||||
|
||||
struct GeneratedDerivativeOriginalDecl {
|
||||
var baseType: BridgedTypeRepr?
|
||||
var declName: BridgedDeclNameRef
|
||||
var declNameLoc: BridgedDeclNameLoc
|
||||
}
|
||||
|
||||
func generateDerivativeOriginalDecl(expr: ExprSyntax) -> GeneratedDerivativeOriginalDecl? {
|
||||
var baseType: BridgedTypeRepr?
|
||||
var declName: BridgedDeclNameRef
|
||||
var declNameLoc: BridgedDeclNameLoc
|
||||
|
||||
if let declrefExpr = expr.as(DeclReferenceExprSyntax.self) {
|
||||
baseType = nil
|
||||
(declName, declNameLoc) = self.generateDeclNameRef(declReferenceExpr: declrefExpr)
|
||||
|
||||
} else if let memberExpr = expr.as(MemberAccessExprSyntax.self),
|
||||
let baseExpr = memberExpr.base {
|
||||
guard let _baseType = self.generateTypeRepr(expr: baseExpr) else {
|
||||
// TODO: Diagnose.
|
||||
fatalError("invalid type expression for @derivative qualified decl name")
|
||||
}
|
||||
baseType = _baseType
|
||||
(declName, declNameLoc) = self.generateDeclNameRef(declReferenceExpr: memberExpr.declName)
|
||||
|
||||
} else {
|
||||
// TODO: Diagnosse.
|
||||
fatalError("invalid expression for @derivative original decl name")
|
||||
}
|
||||
|
||||
return GeneratedDerivativeOriginalDecl(
|
||||
baseType: baseType,
|
||||
declName: declName,
|
||||
declNameLoc: declNameLoc
|
||||
)
|
||||
}
|
||||
|
||||
func generateDifferentiabilityKind(text: SyntaxText) -> BridgedDifferentiabilityKind {
|
||||
switch text {
|
||||
case "reverse": return .reverse
|
||||
case "wrt", "withRespectTo": return .normal
|
||||
case "_liner": return .linear
|
||||
case "_forward": return .forward
|
||||
default: return .nonDifferentiable
|
||||
}
|
||||
}
|
||||
|
||||
func generate(differentiabilityArgument node: DifferentiabilityArgumentSyntax) -> BridgedParsedAutoDiffParameter {
|
||||
let loc = self.generateSourceLoc(node)
|
||||
switch node.argument.rawTokenKind {
|
||||
case .identifier:
|
||||
return .forNamed(self.generateIdentifier(node.argument), loc: loc)
|
||||
|
||||
case .integerLiteral:
|
||||
guard let index = Int(node.argument.text) else {
|
||||
// TODO: Diagnose
|
||||
fatalError("(compiler bug) invalid integer literal token text")
|
||||
}
|
||||
return .forOrdered(index, loc: loc)
|
||||
|
||||
case .keyword where node.argument.rawText == "self":
|
||||
return .forSelf(loc: loc)
|
||||
|
||||
default:
|
||||
// TODO: Diagnose
|
||||
fatalError("(compiler bug) invalid token for 'wrt:' argument")
|
||||
}
|
||||
}
|
||||
|
||||
func generate(differentiabilityWithRespectToArgument node: DifferentiabilityWithRespectToArgumentSyntax?) -> BridgedArrayRef {
|
||||
guard let node else {
|
||||
return BridgedArrayRef()
|
||||
}
|
||||
switch node.arguments {
|
||||
case .argument(let node): // Single argument e.g. 'wrt: foo'
|
||||
return CollectionOfOne(self.generate(differentiabilityArgument: node)).bridgedArray(in: self)
|
||||
case .argumentList(let node): // Multiple arguments e.g. 'wrt: (self, 2)'
|
||||
return node.arguments.lazy.map(self.generate(differentiabilityArgument:)).bridgedArray(in: self)
|
||||
}
|
||||
}
|
||||
|
||||
/// E.g.
|
||||
/// ```
|
||||
/// @derivative(of: foo(arg:), wrt: self)
|
||||
/// ```
|
||||
func generateDerivativeAttr(attribute node: AttributeSyntax) -> BridgedDerivativeAttr? {
|
||||
guard let args = node.arguments?.as(DerivativeAttributeArgumentsSyntax.self) else {
|
||||
fatalError("(compiler bug) invalid arguments for @derivative attribute")
|
||||
}
|
||||
guard let originalDecl = self.generateDerivativeOriginalDecl(expr: args.originalDeclName) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
let accessorKind: BridgedAccessorKind?
|
||||
if let accessorToken = args.accessorSpecifier {
|
||||
accessorKind = self.generate(accessorSpecifier: accessorToken)
|
||||
} else {
|
||||
accessorKind = nil
|
||||
}
|
||||
|
||||
let parameters = self.generate(differentiabilityWithRespectToArgument: args.arguments)
|
||||
|
||||
if let accessorKind {
|
||||
return .createParsed(
|
||||
self.ctx,
|
||||
atLoc: self.generateSourceLoc(node.atSign),
|
||||
range: self.generateAttrSourceRange(node),
|
||||
baseType: originalDecl.baseType.asNullable,
|
||||
originalName: originalDecl.declName,
|
||||
originalNameLoc: originalDecl.declNameLoc,
|
||||
accessorKind: accessorKind,
|
||||
params: parameters
|
||||
)
|
||||
} else {
|
||||
return .createParsed(
|
||||
self.ctx,
|
||||
atLoc: self.generateSourceLoc(node.atSign),
|
||||
range: self.generateAttrSourceRange(node),
|
||||
baseType: originalDecl.baseType.asNullable,
|
||||
originalName: originalDecl.declName,
|
||||
originalNameLoc: originalDecl.declNameLoc,
|
||||
params: parameters
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// E.g.
|
||||
/// ```
|
||||
/// @differentiable(reverse, wrt: (self, 3) where T: U)
|
||||
/// @differentiable(reverse, wrt: foo where T: U)
|
||||
/// ```
|
||||
func generateDifferentiableAttr(attribute node: AttributeSyntax) -> BridgedDifferentiableAttr? {
|
||||
guard let args = node.arguments?.as(DifferentiableAttributeArgumentsSyntax.self) else {
|
||||
fatalError("(compiler bug) invalid arguments for @differentiable attribute")
|
||||
}
|
||||
|
||||
var differentiability: BridgedDifferentiabilityKind
|
||||
if let kindSpecifier = args.kindSpecifier {
|
||||
differentiability = self.generateDifferentiabilityKind(text: kindSpecifier.rawText)
|
||||
} else {
|
||||
differentiability = .normal
|
||||
}
|
||||
if differentiability == .normal {
|
||||
// TODO: Diagnose "'@differentiable' has been renamed to '@differentiable(reverse)"
|
||||
differentiability = .reverse
|
||||
}
|
||||
guard differentiability == .reverse || differentiability == .linear else {
|
||||
// TODO: Diagnose.
|
||||
fatalError("not supported kind for @differentiable attribute")
|
||||
}
|
||||
|
||||
let parameters = self.generate(differentiabilityWithRespectToArgument: args.arguments)
|
||||
|
||||
let whereClause: BridgedTrailingWhereClause?
|
||||
if let node = args.genericWhereClause {
|
||||
whereClause = self.generate(genericWhereClause: node)
|
||||
} else {
|
||||
whereClause = nil
|
||||
}
|
||||
|
||||
return .createParsed(
|
||||
self.ctx,
|
||||
atLoc: self.generateSourceLoc(node.atSign),
|
||||
range: self.generateAttrSourceRange(node),
|
||||
kind: differentiability,
|
||||
params: parameters,
|
||||
genericWhereClause: whereClause.asNullable
|
||||
)
|
||||
}
|
||||
|
||||
/// E.g:
|
||||
/// ```
|
||||
/// @_dynamicReplacement(for: member)
|
||||
@@ -1893,6 +2063,37 @@ extension ASTGenVisitor {
|
||||
)
|
||||
}
|
||||
|
||||
/// E.g.:
|
||||
/// ```
|
||||
/// @transpose(of: foo(_:), wrt: self)
|
||||
/// ```
|
||||
func generateTransposeAttr(attribute node: AttributeSyntax) -> BridgedTransposeAttr? {
|
||||
guard let args = node.arguments?.as(DerivativeAttributeArgumentsSyntax.self) else {
|
||||
fatalError("(compiler bug) invalid arguments for @derivative attribute")
|
||||
}
|
||||
guard let originalDecl = self.generateDerivativeOriginalDecl(expr: args.originalDeclName) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
if let accessorToken = args.accessorSpecifier {
|
||||
// TODO: Diagnostics.
|
||||
_ = accessorToken
|
||||
fatalError("(compiler bug) unexpected accessor kind for @transpose attribute")
|
||||
}
|
||||
|
||||
let parameters = self.generate(differentiabilityWithRespectToArgument: args.arguments)
|
||||
|
||||
return .createParsed(
|
||||
self.ctx,
|
||||
atLoc: self.generateSourceLoc(node.atSign),
|
||||
range: self.generateAttrSourceRange(node),
|
||||
baseType: originalDecl.baseType.asNullable,
|
||||
originalName: originalDecl.declName,
|
||||
originalNameLoc: originalDecl.declNameLoc,
|
||||
params: parameters
|
||||
)
|
||||
}
|
||||
|
||||
/// E.g.:
|
||||
/// ```
|
||||
/// @_unavailableFromAsync
|
||||
|
||||
@@ -378,7 +378,7 @@ extension ASTGenVisitor {
|
||||
// MARK: - AbstractStorageDecl
|
||||
|
||||
extension ASTGenVisitor {
|
||||
private func generate(accessorSpecifier specifier: TokenSyntax) -> BridgedAccessorKind? {
|
||||
func generate(accessorSpecifier specifier: TokenSyntax) -> BridgedAccessorKind? {
|
||||
switch specifier.keywordKind {
|
||||
case .get:
|
||||
return .get
|
||||
|
||||
@@ -72,7 +72,8 @@ extension ASTGenVisitor {
|
||||
return (self.generateConventionTypeAttr(attribute: node)?.asTypeAttribute)
|
||||
.map(BridgedTypeOrCustomAttr.typeAttr(_:))
|
||||
case .differentiable:
|
||||
fatalError("unimplemented")
|
||||
return (self.generateDifferentiableTypeAttr(attribute: node)?.asTypeAttribute)
|
||||
.map(BridgedTypeOrCustomAttr.typeAttr(_:))
|
||||
case .execution:
|
||||
return (self.generateExecutionTypeAttr(attribute: node)?.asTypeAttribute)
|
||||
.map(BridgedTypeOrCustomAttr.typeAttr(_:))
|
||||
@@ -173,7 +174,53 @@ extension ASTGenVisitor {
|
||||
clangTypeLoc: cTypeNameLoc ?? BridgedSourceLoc()
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
func generateDifferentiableTypeAttr(attribute node: AttributeSyntax) -> BridgedDifferentiableTypeAttr? {
|
||||
let differentiability: BridgedDifferentiabilityKind
|
||||
let differentiabilityLoc: BridgedSourceLoc
|
||||
if let args = node.arguments {
|
||||
guard let args = args.as(DifferentiableAttributeArgumentsSyntax.self) else {
|
||||
fatalError("(compiler bug) invalid arguments for @differentiable attribute")
|
||||
}
|
||||
|
||||
if let kindSpecifier = args.kindSpecifier {
|
||||
differentiability = self.generateDifferentiabilityKind(text: kindSpecifier.rawText)
|
||||
differentiabilityLoc = self.generateSourceLoc(kindSpecifier)
|
||||
|
||||
guard differentiability != .nonDifferentiable else {
|
||||
// TODO: Diagnose
|
||||
fatalError("invalid kind for @differentiable type attribute")
|
||||
}
|
||||
|
||||
guard kindSpecifier.nextToken(viewMode: .fixedUp) == node.rightParen else {
|
||||
// TODO: Diagnose
|
||||
fatalError("only expeceted 'reverse' in @differentiable type attribute")
|
||||
}
|
||||
} else {
|
||||
// TODO: Diagnose
|
||||
fatalError("expected @differentiable(reverse)")
|
||||
}
|
||||
} else {
|
||||
differentiability = .normal
|
||||
differentiabilityLoc = nil
|
||||
}
|
||||
|
||||
// Only 'reverse' is supported today.
|
||||
guard differentiability == .reverse else {
|
||||
// TODO: Diagnose
|
||||
fatalError("Only @differentiable(reverse) is supported")
|
||||
}
|
||||
|
||||
return .createParsed(
|
||||
self.ctx,
|
||||
atLoc: self.generateSourceLoc(node.atSign),
|
||||
nameLoc: self.generateSourceLoc(node.attributeName),
|
||||
parensRange: self.generateAttrParensRange(attribute: node),
|
||||
kind: differentiability,
|
||||
kindLoc: differentiabilityLoc
|
||||
)
|
||||
}
|
||||
|
||||
func generateExecutionTypeAttr(attribute node: AttributeSyntax) -> BridgedExecutionTypeAttr? {
|
||||
let behaviorLoc = self.generateSourceLoc(node.arguments)
|
||||
let behavior: BridgedExecutionTypeAttrExecutionKind? = self.generateSingleAttrOption(
|
||||
|
||||
109
test/ASTGen/autodiff.swift
Normal file
109
test/ASTGen/autodiff.swift
Normal file
@@ -0,0 +1,109 @@
|
||||
// RUN: %empty-directory(%t)
|
||||
|
||||
// RUNx: %target-swift-frontend-dump-parse \
|
||||
// RUNx: -enable-experimental-feature ParserASTGen \
|
||||
// RUNx: | %sanitize-address > %t/astgen.ast
|
||||
|
||||
// RUNx: %target-swift-frontend-dump-parse \
|
||||
// RUNx: | %sanitize-address > %t/cpp-parser.ast
|
||||
|
||||
// RUNx: %diff -u %t/astgen.ast %t/cpp-parser.ast
|
||||
|
||||
// RUN: %target-typecheck-verify-swift
|
||||
|
||||
import _Differentiation
|
||||
|
||||
func testDifferentiableTypeAttr(_ fn: @escaping @differentiable(reverse) (Float) -> Float)
|
||||
-> @differentiable(reverse) (Float) -> Float {
|
||||
return fn
|
||||
}
|
||||
|
||||
@differentiable(reverse)
|
||||
func testDifferentiableSimple(_ x: Float) -> Float { return x * x }
|
||||
@differentiable(reverse, wrt: arg1)
|
||||
func testDifferentiableWRT1(arg1: Float, arg2: Float) -> Float { return arg1 }
|
||||
@differentiable(reverse, wrt: (arg1, arg2))
|
||||
func testDifferentiableWRT2(arg1: Float, arg2: Float) -> Float { return arg1 * arg2 }
|
||||
@differentiable(reverse where T : Differentiable)
|
||||
func testOnlyWhereClause<T : Numeric>(x: T) -> T { return x }
|
||||
protocol DiffP {}
|
||||
extension DiffP {
|
||||
@differentiable(reverse, wrt: self where Self : Differentiable)
|
||||
func testWhereClauseMethod() -> Self {
|
||||
return self
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func linearFunc(_ x: Float) -> Float { return x }
|
||||
@transpose(of: linearFunc, wrt: 0)
|
||||
func linearFuncTranspose(x: Float) -> Float { return x }
|
||||
|
||||
extension Float {
|
||||
func getDouble() -> Double { return Double(self) }
|
||||
|
||||
@transpose(of: Float.getDouble, wrt: self)
|
||||
static func structTranspose(v: Double) -> Float { return Float(v) }
|
||||
}
|
||||
|
||||
struct DerivativeTest<T: Differentiable & AdditiveArithmetic>: Differentiable, AdditiveArithmetic {
|
||||
|
||||
typealias TangentVector = DerivativeTest<T.TangentVector>
|
||||
|
||||
static var zero: Self {
|
||||
fatalError()
|
||||
}
|
||||
static func + (lhs: Self, rhs: Self) -> Self {
|
||||
fatalError()
|
||||
}
|
||||
static func - (lhs: Self, rhs: Self) -> Self {
|
||||
fatalError()
|
||||
}
|
||||
mutating func move(by offset: TangentVector) {
|
||||
x.move(by: offset.x)
|
||||
}
|
||||
|
||||
var x: T
|
||||
|
||||
static func staticMethod(_ x: Float) -> Float { 1.2 }
|
||||
@derivative(of: staticMethod)
|
||||
static func jvpStaticMethod(x: Float) -> (value: Float, differential: (Float) -> Float) {
|
||||
return (x, { $0 })
|
||||
}
|
||||
|
||||
func instanceMethod(_ x: T) -> T { x }
|
||||
@derivative(of: instanceMethod)
|
||||
func jvpInstanceMethod(x: T) -> (value: T, differential: (TangentVector, T.TangentVector) -> T.TangentVector) {
|
||||
return (x, { $1 })
|
||||
}
|
||||
|
||||
init(_ x: Float) { fatalError() }
|
||||
init(_ x: T, y: Float) { fatalError() }
|
||||
|
||||
@derivative(of: init(_:y:))
|
||||
static func vjpInit2(_ x: T, _ y: Float) -> (value: Self, pullback: (TangentVector) -> (T.TangentVector, Float)) {
|
||||
return (.init(x, y: y), { _ in (.zero, .zero) })
|
||||
}
|
||||
|
||||
var computedProperty: T {
|
||||
get { x }
|
||||
set { x = newValue }
|
||||
}
|
||||
// FIXME: SwiftParser parsed this attribute as:
|
||||
// {type: 'computedProperty', originalName: 'get', accessor: null}
|
||||
// But it should be:
|
||||
// {type: null, originalName: 'computedProperty', accessor: 'get'}
|
||||
// @derivative(of: computedProperty.get)
|
||||
// func jvpProperty() -> (value: T, differential: (TangentVector) -> T.TangentVector) {
|
||||
// fatalError()
|
||||
// }
|
||||
|
||||
subscript(float float: Float) -> Float {
|
||||
get { 1 }
|
||||
set {}
|
||||
}
|
||||
@derivative(of: subscript(float:).get, wrt: self)
|
||||
func vjpSubscriptLabeledGetter(float: Float) -> (value: Float, pullback: (Float) -> TangentVector) {
|
||||
return (1, { _ in .zero })
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user