mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Differentiability witnesses are now keyed by the original function name, the differentiability kind, and the AutoDiff config (parameter indices, result indices, and derivative generic signature).
Updated SIL syntax:
```
differentiability-kind ::= 'forward' | 'reverse' | 'normal' | 'linear'
sil-differentiability-witness ::=
'sil_differentiability_witness'
sil-linkage?
'[' differentiability-kind ']'
'[' 'parameters' sil-differentiability-witness-function-index-list ']'
'[' 'results' sil-differentiability-witness-function-index-list ']'
generic-parameter-clause?
sil-function-name ':' sil-type
sil-differentiability-witness-body?
sil-instruction ::=
'differentiability_witness_function'
'[' sil-differentiability-witness-function-kind ']'
'[' differentiability-kind ']'
'[' 'parameters' sil-differentiability-witness-function-index-list ']'
'[' 'results' sil-differentiability-witness-function-index-list ']'
generic-parameter-clause?
sil-function-name ':' sil-type
```
```console
sil_differentiability_witness [reverse] [parameters 0] [results 0] <T where T: Differentiable> @foo : <T> $(T) -> T
differentiability_witness_function [vjp] [reverse] [parameters 0] [results 0] <T where T: Differentiable> @foo : $(T) -> T
```
New mangling:
```swift
global ::= global generic-signature? 'WJ' DIFFERENTIABILITY-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' // differentiability witness
```
```console
$s13test_mangling3fooyS2f_S2ftFWJrSpSr ---> reverse differentiability witness for test_mangling.foo(Swift.Float, Swift.Float, Swift.Float) -> Swift.Float with respect to parameters {0} and results {0}
```
Resolves rdar://74380324.
176 lines
6.8 KiB
C++
//===--- SILDifferentiabilityWitness.h - Differentiability witnesses ------===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines the SILDifferentiabilityWitness class, which maps an
// original SILFunction, a differentiability kind, and a derivative
// configuration (parameter indices, result indices, derivative generic
// signature) to derivative functions (JVP and VJP).
|
|
//
|
|
// SIL differentiability witnesses are generated from the `@differentiable`
|
|
// and `@derivative` AST declaration attributes.
|
|
//
|
|
// Differentiability witnesses are canonicalized by the SIL differentiation
|
|
// transform, which fills in missing derivative functions.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
|
|
#define SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
|
|
|
|
#include "swift/AST/Attr.h"
|
|
#include "swift/AST/AutoDiff.h"
|
|
#include "swift/AST/GenericSignature.h"
|
|
#include "swift/SIL/SILAllocated.h"
|
|
#include "swift/SIL/SILLinkage.h"
|
|
#include "llvm/ADT/ilist.h"
|
|
#include "llvm/ADT/ilist_node.h"
|
|
|
|
namespace swift {
|
|
|
|
class SILPrintContext;
|
|
|
|
class SILDifferentiabilityWitness
|
|
: public llvm::ilist_node<SILDifferentiabilityWitness>,
|
|
public SILAllocated<SILDifferentiabilityWitness> {
|
|
private:
|
|
/// The module which contains the differentiability witness.
|
|
SILModule &Module;
|
|
/// The linkage of the differentiability witness.
|
|
SILLinkage Linkage;
|
|
/// The original function.
|
|
SILFunction *OriginalFunction;
|
|
/// The differentiability kind.
|
|
DifferentiabilityKind Kind;
|
|
/// The derivative configuration: parameter indices, result indices, and
|
|
/// derivative generic signature (optional). The derivative generic signature
|
|
/// may contain same-type requirements such that all generic parameters are
|
|
/// bound to concrete types.
|
|
AutoDiffConfig Config;
|
|
/// The JVP (Jacobian-vector products) derivative function.
|
|
SILFunction *JVP;
|
|
/// The VJP (vector-Jacobian products) derivative function.
|
|
SILFunction *VJP;
|
|
/// Whether or not this differentiability witness is a declaration.
|
|
bool IsDeclaration;
|
|
/// Whether or not this differentiability witness is serialized, which allows
|
|
/// devirtualization from another module.
|
|
bool IsSerialized;
|
|
/// The AST `@differentiable` or `@derivative` attribute from which the
|
|
/// differentiability witness is generated. Used for diagnostics.
|
|
/// Null if the differentiability witness is parsed from SIL or if it is
|
|
/// deserialized.
|
|
const DeclAttribute *Attribute = nullptr;
|
|
|
|
SILDifferentiabilityWitness(
|
|
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
|
|
DifferentiabilityKind kind, IndexSubset *parameterIndices,
|
|
IndexSubset *resultIndices, GenericSignature derivativeGenSig,
|
|
SILFunction *jvp, SILFunction *vjp, bool isDeclaration, bool isSerialized,
|
|
const DeclAttribute *attribute)
|
|
: Module(module), Linkage(linkage), OriginalFunction(originalFunction),
|
|
Kind(kind),
|
|
Config(parameterIndices, resultIndices, derivativeGenSig.getPointer()),
|
|
JVP(jvp), VJP(vjp), IsDeclaration(isDeclaration),
|
|
IsSerialized(isSerialized), Attribute(attribute) {
|
|
assert(kind != DifferentiabilityKind::NonDifferentiable);
|
|
}
|
|
|
|
public:
|
|
static SILDifferentiabilityWitness *
|
|
createDeclaration(SILModule &module, SILLinkage linkage,
|
|
SILFunction *originalFunction, DifferentiabilityKind kind,
|
|
IndexSubset *parameterIndices, IndexSubset *resultIndices,
|
|
GenericSignature derivativeGenSig,
|
|
const DeclAttribute *attribute = nullptr);
|
|
|
|
static SILDifferentiabilityWitness *createDefinition(
|
|
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
|
|
DifferentiabilityKind kind, IndexSubset *parameterIndices,
|
|
IndexSubset *resultIndices, GenericSignature derivativeGenSig,
|
|
SILFunction *jvp, SILFunction *vjp, bool isSerialized,
|
|
const DeclAttribute *attribute = nullptr);
|
|
|
|
void convertToDefinition(SILFunction *jvp, SILFunction *vjp,
|
|
bool isSerialized);
|
|
|
|
SILDifferentiabilityWitnessKey getKey() const;
|
|
SILModule &getModule() const { return Module; }
|
|
SILLinkage getLinkage() const { return Linkage; }
|
|
SILFunction *getOriginalFunction() const { return OriginalFunction; }
|
|
DifferentiabilityKind getKind() const { return Kind; }
|
|
const AutoDiffConfig &getConfig() const { return Config; }
|
|
IndexSubset *getParameterIndices() const { return Config.parameterIndices; }
|
|
IndexSubset *getResultIndices() const { return Config.resultIndices; }
|
|
GenericSignature getDerivativeGenericSignature() const {
|
|
return Config.derivativeGenericSignature;
|
|
}
|
|
SILFunction *getJVP() const { return JVP; }
|
|
SILFunction *getVJP() const { return VJP; }
|
|
SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const {
|
|
switch (kind) {
|
|
case AutoDiffDerivativeFunctionKind::JVP:
|
|
return JVP;
|
|
case AutoDiffDerivativeFunctionKind::VJP:
|
|
return VJP;
|
|
}
|
|
llvm_unreachable("invalid derivative type");
|
|
}
|
|
void setJVP(SILFunction *jvp) { JVP = jvp; }
|
|
void setVJP(SILFunction *vjp) { VJP = vjp; }
|
|
void setDerivative(AutoDiffDerivativeFunctionKind kind,
|
|
SILFunction *derivative) {
|
|
switch (kind) {
|
|
case AutoDiffDerivativeFunctionKind::JVP:
|
|
JVP = derivative;
|
|
break;
|
|
case AutoDiffDerivativeFunctionKind::VJP:
|
|
VJP = derivative;
|
|
break;
|
|
}
|
|
}
|
|
bool isDeclaration() const { return IsDeclaration; }
|
|
bool isDefinition() const { return !IsDeclaration; }
|
|
bool isSerialized() const { return IsSerialized; }
|
|
const DeclAttribute *getAttribute() const { return Attribute; }
|
|
|
|
/// Verify that the differentiability witness is well-formed.
|
|
void verify(const SILModule &module) const;
|
|
|
|
void print(llvm::raw_ostream &os, bool verbose = false) const;
|
|
void dump() const;
|
|
};
|
|
|
|
} // end namespace swift
|
|
|
|
namespace llvm {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ilist_traits for SILDifferentiabilityWitness
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <>
|
|
struct ilist_traits<::swift::SILDifferentiabilityWitness>
|
|
: public ilist_node_traits<::swift::SILDifferentiabilityWitness> {
|
|
using SILDifferentiabilityWitness = ::swift::SILDifferentiabilityWitness;
|
|
|
|
public:
|
|
static void deleteNode(SILDifferentiabilityWitness *DW) {
|
|
DW->~SILDifferentiabilityWitness();
|
|
}
|
|
|
|
private:
|
|
void createNode(const SILDifferentiabilityWitness &);
|
|
};
|
|
|
|
} // namespace llvm
|
|
|
|
#endif // SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
|