mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Merge pull request #30711 from rxwei/differential-operators
[AutoDiff upstream] Add differential operators and some utilities.
This commit is contained in:
@@ -181,33 +181,53 @@ enum class ImplFunctionRepresentation {
|
||||
Closure
|
||||
};
|
||||
|
||||
enum class ImplFunctionDifferentiabilityKind {
|
||||
NonDifferentiable,
|
||||
Normal,
|
||||
Linear
|
||||
};
|
||||
|
||||
class ImplFunctionTypeFlags {
|
||||
unsigned Rep : 3;
|
||||
unsigned Pseudogeneric : 1;
|
||||
unsigned Escaping : 1;
|
||||
unsigned DifferentiabilityKind : 2;
|
||||
|
||||
public:
|
||||
ImplFunctionTypeFlags() : Rep(0), Pseudogeneric(0), Escaping(0) {}
|
||||
ImplFunctionTypeFlags()
|
||||
: Rep(0), Pseudogeneric(0), Escaping(0), DifferentiabilityKind(0) {}
|
||||
|
||||
ImplFunctionTypeFlags(ImplFunctionRepresentation rep,
|
||||
bool pseudogeneric, bool noescape)
|
||||
: Rep(unsigned(rep)), Pseudogeneric(pseudogeneric), Escaping(noescape) {}
|
||||
ImplFunctionTypeFlags(ImplFunctionRepresentation rep, bool pseudogeneric,
|
||||
bool noescape,
|
||||
ImplFunctionDifferentiabilityKind diffKind)
|
||||
: Rep(unsigned(rep)), Pseudogeneric(pseudogeneric), Escaping(noescape),
|
||||
DifferentiabilityKind(unsigned(diffKind)) {}
|
||||
|
||||
ImplFunctionTypeFlags
|
||||
withRepresentation(ImplFunctionRepresentation rep) const {
|
||||
return ImplFunctionTypeFlags(rep, Pseudogeneric, Escaping);
|
||||
return ImplFunctionTypeFlags(
|
||||
rep, Pseudogeneric, Escaping,
|
||||
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
|
||||
}
|
||||
|
||||
ImplFunctionTypeFlags
|
||||
withEscaping() const {
|
||||
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep),
|
||||
Pseudogeneric, true);
|
||||
return ImplFunctionTypeFlags(
|
||||
ImplFunctionRepresentation(Rep), Pseudogeneric, true,
|
||||
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
|
||||
}
|
||||
|
||||
ImplFunctionTypeFlags
|
||||
withPseudogeneric() const {
|
||||
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep),
|
||||
true, Escaping);
|
||||
return ImplFunctionTypeFlags(
|
||||
ImplFunctionRepresentation(Rep), true, Escaping,
|
||||
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
|
||||
}
|
||||
|
||||
ImplFunctionTypeFlags
|
||||
withDifferentiabilityKind(ImplFunctionDifferentiabilityKind diffKind) const {
|
||||
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep), Pseudogeneric,
|
||||
Escaping, diffKind);
|
||||
}
|
||||
|
||||
ImplFunctionRepresentation getRepresentation() const {
|
||||
@@ -217,6 +237,10 @@ public:
|
||||
bool isEscaping() const { return Escaping; }
|
||||
|
||||
bool isPseudogeneric() const { return Pseudogeneric; }
|
||||
|
||||
ImplFunctionDifferentiabilityKind getDifferentiabilityKind() const {
|
||||
return ImplFunctionDifferentiabilityKind(DifferentiabilityKind);
|
||||
}
|
||||
};
|
||||
|
||||
#if SWIFT_OBJC_INTEROP
|
||||
@@ -582,6 +606,14 @@ class TypeDecoder {
|
||||
flags =
|
||||
flags.withRepresentation(ImplFunctionRepresentation::Block);
|
||||
}
|
||||
} else if (child->getKind() == NodeKind::ImplDifferentiable) {
|
||||
flags = flags.withDifferentiabilityKind(
|
||||
ImplFunctionDifferentiabilityKind::Normal);
|
||||
} else if (child->getKind() == NodeKind::ImplLinear) {
|
||||
flags = flags.withDifferentiabilityKind(
|
||||
ImplFunctionDifferentiabilityKind::Linear);
|
||||
} else if (child->getKind() == NodeKind::ImplEscaping) {
|
||||
flags = flags.withEscaping();
|
||||
} else if (child->getKind() == NodeKind::ImplEscaping) {
|
||||
flags = flags.withEscaping();
|
||||
} else if (child->getKind() == NodeKind::ImplParameter) {
|
||||
|
||||
Reference in New Issue
Block a user