Merge pull request #30711 from rxwei/differential-operators

[AutoDiff upstream] Add differential operators and some utilities.
This commit is contained in:
marcrasi
2020-04-01 10:11:35 -07:00
committed by GitHub
12 changed files with 652 additions and 40 deletions

View File

@@ -181,33 +181,53 @@ enum class ImplFunctionRepresentation {
Closure
};
enum class ImplFunctionDifferentiabilityKind {
NonDifferentiable,
Normal,
Linear
};
class ImplFunctionTypeFlags {
unsigned Rep : 3;
unsigned Pseudogeneric : 1;
unsigned Escaping : 1;
unsigned DifferentiabilityKind : 2;
public:
ImplFunctionTypeFlags() : Rep(0), Pseudogeneric(0), Escaping(0) {}
ImplFunctionTypeFlags()
: Rep(0), Pseudogeneric(0), Escaping(0), DifferentiabilityKind(0) {}
ImplFunctionTypeFlags(ImplFunctionRepresentation rep,
bool pseudogeneric, bool noescape)
: Rep(unsigned(rep)), Pseudogeneric(pseudogeneric), Escaping(noescape) {}
ImplFunctionTypeFlags(ImplFunctionRepresentation rep, bool pseudogeneric,
bool noescape,
ImplFunctionDifferentiabilityKind diffKind)
: Rep(unsigned(rep)), Pseudogeneric(pseudogeneric), Escaping(noescape),
DifferentiabilityKind(unsigned(diffKind)) {}
ImplFunctionTypeFlags
withRepresentation(ImplFunctionRepresentation rep) const {
return ImplFunctionTypeFlags(rep, Pseudogeneric, Escaping);
return ImplFunctionTypeFlags(
rep, Pseudogeneric, Escaping,
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
}
ImplFunctionTypeFlags
withEscaping() const {
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep),
Pseudogeneric, true);
return ImplFunctionTypeFlags(
ImplFunctionRepresentation(Rep), Pseudogeneric, true,
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
}
ImplFunctionTypeFlags
withPseudogeneric() const {
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep),
true, Escaping);
return ImplFunctionTypeFlags(
ImplFunctionRepresentation(Rep), true, Escaping,
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
}
ImplFunctionTypeFlags
withDifferentiabilityKind(ImplFunctionDifferentiabilityKind diffKind) const {
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep), Pseudogeneric,
Escaping, diffKind);
}
ImplFunctionRepresentation getRepresentation() const {
@@ -217,6 +237,10 @@ public:
bool isEscaping() const { return Escaping; }
bool isPseudogeneric() const { return Pseudogeneric; }
ImplFunctionDifferentiabilityKind getDifferentiabilityKind() const {
return ImplFunctionDifferentiabilityKind(DifferentiabilityKind);
}
};
#if SWIFT_OBJC_INTEROP
@@ -582,6 +606,14 @@ class TypeDecoder {
flags =
flags.withRepresentation(ImplFunctionRepresentation::Block);
}
} else if (child->getKind() == NodeKind::ImplDifferentiable) {
flags = flags.withDifferentiabilityKind(
ImplFunctionDifferentiabilityKind::Normal);
} else if (child->getKind() == NodeKind::ImplLinear) {
flags = flags.withDifferentiabilityKind(
ImplFunctionDifferentiabilityKind::Linear);
} else if (child->getKind() == NodeKind::ImplEscaping) {
flags = flags.withEscaping();
} else if (child->getKind() == NodeKind::ImplEscaping) {
flags = flags.withEscaping();
} else if (child->getKind() == NodeKind::ImplParameter) {