//===--- SILDifferentiabilityWitness.h - Differentiability witnesses ------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2020 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file defines the SILDifferentiabilityWitness class, which maps an // original SILFunction and derivative configuration (parameter indices, result // indices, derivative generic signature) to derivative functions (JVP and VJP). // // SIL differentiability witnesses are generated from the `@differentiable` // and `@derivative` AST declaration attributes. // // Differentiability witnesses are canonicalized by the SIL differentiation // transform, which fills in missing derivative functions. // //===----------------------------------------------------------------------===// #ifndef SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H #define SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H #include "swift/AST/Attr.h" #include "swift/AST/AutoDiff.h" #include "swift/AST/GenericSignature.h" #include "swift/SIL/SILAllocated.h" #include "swift/SIL/SILLinkage.h" #include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist_node.h" namespace swift { class SILPrintContext; class SILDifferentiabilityWitness : public llvm::ilist_node, public SILAllocated { private: /// The module which contains the differentiability witness. SILModule &Module; /// The linkage of the differentiability witness. SILLinkage Linkage; /// The original function. SILFunction *OriginalFunction; /// The differentiability kind. DifferentiabilityKind Kind; /// The derivative configuration: parameter indices, result indices, and /// derivative generic signature (optional). The derivative generic signature /// may contain same-type requirements such that all generic parameters are /// bound to concrete types. AutoDiffConfig Config; /// The JVP (Jacobian-vector products) derivative function. SILFunction *JVP; /// The VJP (vector-Jacobian products) derivative function. SILFunction *VJP; /// Whether or not this differentiability witness is a declaration. bool IsDeclaration; /// Whether or not this differentiability witness is serialized, which allows /// devirtualization from another module. bool IsSerialized; /// The AST `@differentiable` or `@derivative` attribute from which the /// differentiability witness is generated. Used for diagnostics. /// Null if the differentiability witness is parsed from SIL or if it is /// deserialized. const DeclAttribute *Attribute = nullptr; SILDifferentiabilityWitness( SILModule &module, SILLinkage linkage, SILFunction *originalFunction, DifferentiabilityKind kind, IndexSubset *parameterIndices, IndexSubset *resultIndices, GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isDeclaration, bool isSerialized, const DeclAttribute *attribute) : Module(module), Linkage(linkage), OriginalFunction(originalFunction), Kind(kind), Config(parameterIndices, resultIndices, derivativeGenSig.getPointer()), JVP(jvp), VJP(vjp), IsDeclaration(isDeclaration), IsSerialized(isSerialized), Attribute(attribute) { assert(kind != DifferentiabilityKind::NonDifferentiable); } public: static SILDifferentiabilityWitness * createDeclaration(SILModule &module, SILLinkage linkage, SILFunction *originalFunction, DifferentiabilityKind kind, IndexSubset *parameterIndices, IndexSubset *resultIndices, GenericSignature derivativeGenSig, const DeclAttribute *attribute = nullptr); static SILDifferentiabilityWitness *createDefinition( SILModule &module, SILLinkage linkage, SILFunction *originalFunction, DifferentiabilityKind kind, IndexSubset *parameterIndices, IndexSubset *resultIndices, GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp, bool isSerialized, const DeclAttribute *attribute = nullptr); void convertToDefinition(SILFunction *jvp, SILFunction *vjp, bool isSerialized); SILDifferentiabilityWitnessKey getKey() const; SILModule &getModule() const { return Module; } SILLinkage getLinkage() const { return Linkage; } SILFunction *getOriginalFunction() const { return OriginalFunction; } DifferentiabilityKind getKind() const { return Kind; } const AutoDiffConfig &getConfig() const { return Config; } IndexSubset *getParameterIndices() const { return Config.parameterIndices; } IndexSubset *getResultIndices() const { return Config.resultIndices; } GenericSignature getDerivativeGenericSignature() const { return Config.derivativeGenericSignature; } SILFunction *getJVP() const { return JVP; } SILFunction *getVJP() const { return VJP; } SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const { switch (kind) { case AutoDiffDerivativeFunctionKind::JVP: return JVP; case AutoDiffDerivativeFunctionKind::VJP: return VJP; } llvm_unreachable("invalid derivative type"); } void setJVP(SILFunction *jvp) { JVP = jvp; } void setVJP(SILFunction *vjp) { VJP = vjp; } void setDerivative(AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) { switch (kind) { case AutoDiffDerivativeFunctionKind::JVP: JVP = derivative; break; case AutoDiffDerivativeFunctionKind::VJP: VJP = derivative; break; } } bool isDeclaration() const { return IsDeclaration; } bool isDefinition() const { return !IsDeclaration; } bool isSerialized() const { return IsSerialized; } const DeclAttribute *getAttribute() const { return Attribute; } /// Verify that the differentiability witness is well-formed. void verify(const SILModule &module) const; void print(llvm::raw_ostream &os, bool verbose = false) const; void dump() const; }; } // end namespace swift namespace llvm { //===----------------------------------------------------------------------===// // ilist_traits for SILDifferentiabilityWitness //===----------------------------------------------------------------------===// template <> struct ilist_traits<::swift::SILDifferentiabilityWitness> : public ilist_node_traits<::swift::SILDifferentiabilityWitness> { using SILDifferentiabilityWitness = ::swift::SILDifferentiabilityWitness; public: static void deleteNode(SILDifferentiabilityWitness *DW) { DW->~SILDifferentiabilityWitness(); } private: void createNode(const SILDifferentiabilityWitness &); }; } // namespace llvm #endif // SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H