mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Merge pull request #35811 from rxwei/69980056-differentiable-reverse
[AutoDiff] Add '@differentiable(reverse)'.
This commit is contained in:
@@ -2118,6 +2118,9 @@ void Serializer::writeASTBlockEntity(const DeclContext *DC) {
|
||||
}
|
||||
}
|
||||
|
||||
#define SIMPLE_CASE(TYPENAME, VALUE) \
|
||||
case swift::TYPENAME::VALUE: return uint8_t(serialization::TYPENAME::VALUE);
|
||||
|
||||
static ForeignErrorConventionKind getRawStableForeignErrorConventionKind(
|
||||
ForeignErrorConvention::Kind kind) {
|
||||
switch (kind) {
|
||||
@@ -2167,14 +2170,28 @@ static uint8_t getRawStableVarDeclIntroducer(swift::VarDecl::Introducer intr) {
|
||||
static uint8_t getRawStableAutoDiffDerivativeFunctionKind(
|
||||
swift::AutoDiffDerivativeFunctionKind kind) {
|
||||
switch (kind) {
|
||||
case swift::AutoDiffDerivativeFunctionKind::JVP:
|
||||
return uint8_t(serialization::AutoDiffDerivativeFunctionKind::JVP);
|
||||
case swift::AutoDiffDerivativeFunctionKind::VJP:
|
||||
return uint8_t(serialization::AutoDiffDerivativeFunctionKind::VJP);
|
||||
SIMPLE_CASE(AutoDiffDerivativeFunctionKind, JVP)
|
||||
SIMPLE_CASE(AutoDiffDerivativeFunctionKind, VJP)
|
||||
}
|
||||
llvm_unreachable("bad derivative function kind");
|
||||
}
|
||||
|
||||
/// Translate from the AST differentiability kind enum to the Serialization enum
|
||||
/// values, which are guaranteed to be stable.
|
||||
static uint8_t getRawStableDifferentiabilityKind(
|
||||
swift::DifferentiabilityKind diffKind) {
|
||||
switch (diffKind) {
|
||||
SIMPLE_CASE(DifferentiabilityKind, NonDifferentiable)
|
||||
SIMPLE_CASE(DifferentiabilityKind, Forward)
|
||||
SIMPLE_CASE(DifferentiabilityKind, Reverse)
|
||||
SIMPLE_CASE(DifferentiabilityKind, Normal)
|
||||
SIMPLE_CASE(DifferentiabilityKind, Linear)
|
||||
}
|
||||
llvm_unreachable("bad differentiability kind");
|
||||
}
|
||||
|
||||
#undef SIMPLE_CASE
|
||||
|
||||
/// Returns true if the declaration of \p decl depends on \p problemContext
|
||||
/// based on lexical nesting.
|
||||
///
|
||||
@@ -2541,7 +2558,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
|
||||
|
||||
DifferentiableDeclAttrLayout::emitRecord(
|
||||
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(),
|
||||
attr->isLinear(),
|
||||
getRawStableDifferentiabilityKind(attr->getDifferentiabilityKind()),
|
||||
S.addGenericSignatureRef(attr->getDerivativeGenericSignature()),
|
||||
paramIndicesVector);
|
||||
return;
|
||||
@@ -3938,18 +3955,6 @@ static uint8_t getRawStableFunctionTypeRepresentation(
|
||||
llvm_unreachable("bad calling convention");
|
||||
}
|
||||
|
||||
/// Translate from the AST differentiability kind enum to the Serialization enum
|
||||
/// values, which are guaranteed to be stable.
|
||||
static uint8_t getRawStableDifferentiabilityKind(
|
||||
swift::DifferentiabilityKind diffKind) {
|
||||
switch (diffKind) {
|
||||
SIMPLE_CASE(DifferentiabilityKind, NonDifferentiable)
|
||||
SIMPLE_CASE(DifferentiabilityKind, Normal)
|
||||
SIMPLE_CASE(DifferentiabilityKind, Linear)
|
||||
}
|
||||
llvm_unreachable("bad differentiability kind");
|
||||
}
|
||||
|
||||
/// Translate from the AST function representation enum to the Serialization enum
|
||||
/// values, which are guaranteed to be stable.
|
||||
static uint8_t getRawStableSILFunctionTypeRepresentation(
|
||||
|
||||
Reference in New Issue
Block a user