Merge pull request #35811 from rxwei/69980056-differentiable-reverse

[AutoDiff] Add '@differentiable(reverse)'.
This commit is contained in:
Richard Wei
2021-02-08 04:32:27 -08:00
committed by GitHub
149 changed files with 1674 additions and 1536 deletions

View File

@@ -2118,6 +2118,9 @@ void Serializer::writeASTBlockEntity(const DeclContext *DC) {
}
}
#define SIMPLE_CASE(TYPENAME, VALUE) \
case swift::TYPENAME::VALUE: return uint8_t(serialization::TYPENAME::VALUE);
static ForeignErrorConventionKind getRawStableForeignErrorConventionKind(
ForeignErrorConvention::Kind kind) {
switch (kind) {
@@ -2167,14 +2170,28 @@ static uint8_t getRawStableVarDeclIntroducer(swift::VarDecl::Introducer intr) {
static uint8_t getRawStableAutoDiffDerivativeFunctionKind(
swift::AutoDiffDerivativeFunctionKind kind) {
switch (kind) {
case swift::AutoDiffDerivativeFunctionKind::JVP:
return uint8_t(serialization::AutoDiffDerivativeFunctionKind::JVP);
case swift::AutoDiffDerivativeFunctionKind::VJP:
return uint8_t(serialization::AutoDiffDerivativeFunctionKind::VJP);
SIMPLE_CASE(AutoDiffDerivativeFunctionKind, JVP)
SIMPLE_CASE(AutoDiffDerivativeFunctionKind, VJP)
}
llvm_unreachable("bad derivative function kind");
}
/// Translate from the AST differentiability kind enum to the Serialization enum
/// values, which are guaranteed to be stable.
static uint8_t getRawStableDifferentiabilityKind(
swift::DifferentiabilityKind diffKind) {
switch (diffKind) {
SIMPLE_CASE(DifferentiabilityKind, NonDifferentiable)
SIMPLE_CASE(DifferentiabilityKind, Forward)
SIMPLE_CASE(DifferentiabilityKind, Reverse)
SIMPLE_CASE(DifferentiabilityKind, Normal)
SIMPLE_CASE(DifferentiabilityKind, Linear)
}
llvm_unreachable("bad differentiability kind");
}
#undef SIMPLE_CASE
/// Returns true if the declaration of \p decl depends on \p problemContext
/// based on lexical nesting.
///
@@ -2541,7 +2558,7 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
DifferentiableDeclAttrLayout::emitRecord(
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(),
attr->isLinear(),
getRawStableDifferentiabilityKind(attr->getDifferentiabilityKind()),
S.addGenericSignatureRef(attr->getDerivativeGenericSignature()),
paramIndicesVector);
return;
@@ -3938,18 +3955,6 @@ static uint8_t getRawStableFunctionTypeRepresentation(
llvm_unreachable("bad calling convention");
}
/// Translate from the AST differentiability kind enum to the Serialization enum
/// values, which are guaranteed to be stable.
static uint8_t getRawStableDifferentiabilityKind(
swift::DifferentiabilityKind diffKind) {
switch (diffKind) {
SIMPLE_CASE(DifferentiabilityKind, NonDifferentiable)
SIMPLE_CASE(DifferentiabilityKind, Normal)
SIMPLE_CASE(DifferentiabilityKind, Linear)
}
llvm_unreachable("bad differentiability kind");
}
/// Translate from the AST function representation enum to the Serialization enum
/// values, which are guaranteed to be stable.
static uint8_t getRawStableSILFunctionTypeRepresentation(