Files
swift-mirror/include/swift/AST/AutoDiff.h
Dan Zheng 24445dd2e2 [AutoDiff upstream] Add differentiability witness SILGen. (#30545)
Generate SIL differentiability witnesses from `@differentiable` and
`@derivative` declaration attributes.

Add SILGen utilities for:
- Emitting differentiability witnesses.
- Creating derivative function thunks, which are used as entries in
  differentiability witnesses.

When users register a custom derivative function, it is necessary to create a
thunk with the expected derivative type computed from the original function's
type. This is important for consistent typing and consistent differentiability
witness entry mangling.

See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details.

Resolves TF-1138.
2020-03-21 02:05:04 -07:00

507 lines
17 KiB
C++

//===--- AutoDiff.h - Swift automatic differentiation utilities -----------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file defines utilities for automatic differentiation.
//
//===----------------------------------------------------------------------===//
#ifndef SWIFT_AST_AUTODIFF_H
#define SWIFT_AST_AUTODIFF_H
#include <cstdint>
#include "swift/AST/GenericSignature.h"
#include "swift/AST/Identifier.h"
#include "swift/AST/IndexSubset.h"
#include "swift/AST/Type.h"
#include "swift/AST/TypeAlignments.h"
#include "swift/Basic/Range.h"
#include "swift/Basic/SourceLoc.h"
#include "llvm/ADT/StringExtras.h"
namespace swift {
class AnyFunctionType;
class SILFunctionType;
class TupleType;
/// A function type differentiability kind.
///
/// Stored as a `uint8_t`; the numeric value of every kind is spelled out
/// explicitly instead of relying on implicit enumerator numbering.
enum class DifferentiabilityKind : uint8_t {
// The function type is not differentiable.
NonDifferentiable = 0,
// NOTE(review): presumably ordinary (not necessarily linear)
// differentiability — confirm against the users of this enum.
Normal = 1,
// NOTE(review): presumably differentiable and additionally linear — confirm.
Linear = 2
};
/// The kind of a linear map.
///
/// A thin wrapper around the `innerty` enum that converts implicitly both from
/// and to the raw enum value.
struct AutoDiffLinearMapKind {
enum innerty : uint8_t {
// The differential function.
Differential = 0,
// The pullback function.
Pullback = 1
} rawValue;
// Defaulted: default-initialization leaves `rawValue` indeterminate, so assign
// a kind before reading it.
AutoDiffLinearMapKind() = default;
// Implicit conversion from the raw enum value.
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
// Implicit conversion back to the raw enum value.
operator innerty() const { return rawValue; }
};
/// The kind of a derivative function.
///
/// A thin wrapper around the `innerty` enum that converts implicitly both from
/// and to the raw enum value. The raw values line up with
/// `AutoDiffLinearMapKind` (`JVP`/`Differential` are 0, `VJP`/`Pullback` are
/// 1), which is what makes the casts between the two kinds below valid.
struct AutoDiffDerivativeFunctionKind {
  enum innerty : uint8_t {
    // The Jacobian-vector products function.
    JVP = 0,
    // The vector-Jacobian products function.
    VJP = 1
  } rawValue;

  // Defaulted: default-initialization leaves `rawValue` indeterminate.
  AutoDiffDerivativeFunctionKind() = default;

  // Implicit conversion from the raw enum value.
  AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}

  // Implicit conversion from the linear map kind that has the same raw value.
  AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
      : rawValue(static_cast<innerty>(linMapKind.rawValue)) {}

  // Constructs a kind from a string. Defined out of line; the accepted
  // spellings are not visible in this header.
  explicit AutoDiffDerivativeFunctionKind(StringRef string);

  // Implicit conversion back to the raw enum value.
  operator innerty() const { return rawValue; }

  /// Returns the linear map kind that has the same raw value as this kind.
  ///
  /// `const`-qualified so that it is callable on const objects and through
  /// const references; it only reads `rawValue`.
  AutoDiffLinearMapKind getLinearMapKind() const {
    return static_cast<AutoDiffLinearMapKind::innerty>(rawValue);
  }
};
/// The kind of a differentiability witness function.
///
/// A thin wrapper around the `innerty` enum that converts implicitly both from
/// and to the raw enum value.
struct DifferentiabilityWitnessFunctionKind {
enum innerty : uint8_t {
// The Jacobian-vector products function.
JVP = 0,
// The vector-Jacobian products function.
VJP = 1,
// The transpose function.
Transpose = 2
} rawValue;
// Defaulted: default-initialization leaves `rawValue` indeterminate.
DifferentiabilityWitnessFunctionKind() = default;
// Implicit conversion from the raw enum value.
DifferentiabilityWitnessFunctionKind(innerty rawValue) : rawValue(rawValue) {}
// Converts a raw integer with an unchecked cast; no range validation is
// performed here, so callers are responsible for passing a valid value.
explicit DifferentiabilityWitnessFunctionKind(unsigned rawValue)
: rawValue(static_cast<innerty>(rawValue)) {}
// Constructs a kind from a name. Defined out of line; the accepted names are
// not visible in this header.
explicit DifferentiabilityWitnessFunctionKind(StringRef name);
// Implicit conversion back to the raw enum value.
operator innerty() const { return rawValue; }
// Defined out of line.
// NOTE(review): presumably yields a value for `JVP`/`VJP` and no value for
// `Transpose` — confirm against the definition.
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
};
/// SIL-level automatic differentiation indices. Consists of:
/// - Parameter indices: indices of parameters to differentiate with respect to.
/// - Result index: index of the result to differentiate from.
// TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
// `AutoDiffConfig` supports multiple result indices.
struct SILAutoDiffIndices {
  /// The index of the dependent result to differentiate from.
  unsigned source;
  /// The indices for independent parameters to differentiate with respect to.
  IndexSubset *parameters;

  /*implicit*/ SILAutoDiffIndices(unsigned source, IndexSubset *parameters)
      : source(source), parameters(parameters) {}

  /// Equality is defined out of line; inequality is its negation.
  bool operator==(const SILAutoDiffIndices &other) const;
  bool operator!=(const SILAutoDiffIndices &other) const {
    return !(*this == other);
  }

  /// Returns true if `parameterIndex` is a differentiability parameter index.
  ///
  /// The capacity comparison is evaluated first, so `contains` is only queried
  /// for indices below the subset's capacity.
  bool isWrtParameter(unsigned parameterIndex) const {
    return parameterIndex < parameters->getCapacity() &&
           parameters->contains(parameterIndex);
  }

  void print(llvm::raw_ostream &s = llvm::outs()) const;
  SWIFT_DEBUG_DUMP;

  /// Returns a string of the form `src_<source>_wrt_<i0>_<i1>...`: the source
  /// index followed by the parameter indices with `_` emitted between
  /// consecutive indices.
  std::string mangle() const {
    std::string result = "src_" + llvm::utostr(source) + "_wrt_";
    interleave(
        parameters->getIndices(),
        [&](unsigned idx) { result += llvm::utostr(idx); },
        [&] { result += '_'; });
    return result;
  }
};
/// Identifies an autodiff derivative function configuration:
/// - Parameter indices.
/// - Result indices.
/// - Derivative generic signature (optional).
struct AutoDiffConfig {
// Parameter indices of the configuration.
// NOTE(review): presumably the parameters differentiated with respect to, as
// in `SILAutoDiffIndices::parameters` — confirm.
IndexSubset *parameterIndices;
// Result indices of the configuration.
IndexSubset *resultIndices;
// Derivative generic signature. Optional: the `DenseMapInfo<AutoDiffConfig>`
// specialization in this file explicitly handles a null signature.
GenericSignature derivativeGenericSignature;
// Memberwise constructor; stores the three components as given.
/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
IndexSubset *resultIndices,
GenericSignature derivativeGenericSignature)
: parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenericSignature) {}
/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
// TODO(TF-913): This is a temporary shim for incremental removal of
// `SILAutoDiffIndices`. Eventually remove this.
SILAutoDiffIndices getSILAutoDiffIndices() const;
// Prints this configuration to `s`, which defaults to `llvm::outs()`.
// Defined out of line; the output format is not visible in this header.
void print(llvm::raw_ostream &s = llvm::outs()) const;
SWIFT_DEBUG_DUMP;
};
/// A semantic function result type: either a formal function result type or
/// an `inout` parameter type. Used in derivative function type calculation.
struct AutoDiffSemanticFunctionResultType {
// The result type itself.
Type type;
// Distinguishes the two cases described above.
// NOTE(review): presumably true when `type` comes from an `inout` parameter
// and false for a formal result — confirm.
bool isInout;
};
/// Key for caching SIL derivative function types.
///
/// The `DenseMapInfo<SILAutoDiffDerivativeFunctionKey>` specialization in this
/// file compares and hashes all six fields, so two keys are equal only when
/// every field matches.
struct SILAutoDiffDerivativeFunctionKey {
// NOTE(review): presumably the type of the original (non-derivative)
// function — confirm.
SILFunctionType *originalType;
// Parameter indices component of the key.
IndexSubset *parameterIndices;
// Result indices component of the key.
IndexSubset *resultIndices;
// Derivative function kind (JVP or VJP).
AutoDiffDerivativeFunctionKind kind;
// Canonical generic signature component of the key.
// NOTE(review): presumably the derivative function's generic signature —
// confirm.
CanGenericSignature derivativeFnGenSig;
// NOTE(review): presumably whether the derivative type is being computed for
// a reabstraction thunk — confirm.
bool isReabstractionThunk;
};
/// A parsed autodiff parameter: a source location together with a reference to
/// a parameter either by name (`Named`), by index (`Ordered`), or as `Self`.
class ParsedAutoDiffParameter {
public:
  enum class Kind { Named, Ordered, Self };

private:
  SourceLoc loc;
  Kind kind;
  // Payload for the parameter; which member is meaningful depends on `kind`.
  union Value {
    struct { Identifier name; } Named;
    struct { unsigned index; } Ordered;
    struct {} self;
    Value(Identifier name) : Named({name}) {}
    Value(unsigned index) : Ordered({index}) {}
    Value() {}
  } value;

public:
  ParsedAutoDiffParameter(SourceLoc loc, Kind kind, Value value)
      : loc(loc), kind(kind), value(value) {}

  ParsedAutoDiffParameter(SourceLoc loc, Kind kind, unsigned index)
      : loc(loc), kind(kind), value(index) {}

  /// Creates a parameter that refers to a parameter by name.
  static ParsedAutoDiffParameter getNamedParameter(SourceLoc loc,
                                                   Identifier name) {
    return { loc, Kind::Named, name };
  }

  /// Creates a parameter that refers to a parameter by index.
  static ParsedAutoDiffParameter getOrderedParameter(SourceLoc loc,
                                                     unsigned index) {
    return { loc, Kind::Ordered, index };
  }

  /// Creates a parameter that refers to `Self`; it carries no name or index
  /// of interest.
  static ParsedAutoDiffParameter getSelfParameter(SourceLoc loc) {
    return { loc, Kind::Self, {} };
  }

  /// Returns the parameter name. Only valid for `Kind::Named`.
  Identifier getName() const {
    assert(kind == Kind::Named);
    return value.Named.name;
  }

  /// Returns the parameter index. Only valid for `Kind::Ordered`; asserted for
  /// consistency with `getName`.
  unsigned getIndex() const {
    assert(kind == Kind::Ordered);
    return value.Ordered.index;
  }

  Kind getKind() const {
    return kind;
  }

  SourceLoc getLoc() const {
    return loc;
  }

  /// Returns true if `other` refers to the same parameter: the kinds match
  /// and, for `Named`/`Ordered`, so do the names/indices. Source locations are
  /// not compared.
  bool isEqual(const ParsedAutoDiffParameter &other) const {
    if (getKind() != other.getKind())
      return false;
    if (getKind() == Kind::Named)
      return getName() == other.getName();
    // Ordered parameters are equal when their indices are; without this case
    // an ordered parameter would not even compare equal to itself.
    if (getKind() == Kind::Ordered)
      return getIndex() == other.getIndex();
    return getKind() == Kind::Self;
  }
};
/// The tangent space of a type.
///
/// For `Differentiable`-conforming types:
/// - The tangent space is the `TangentVector` associated type.
///
/// For tuple types:
/// - The tangent space is a tuple of the elements' tangent space types, for the
/// elements that have a tangent space.
///
/// Other types have no tangent space.
class TangentSpace {
public:
/// A tangent space kind.
enum class Kind {
/// The `TangentVector` associated type of a `Differentiable`-conforming
/// type.
TangentVector,
/// A product of tangent spaces as a tuple.
Tuple
};
private:
Kind kind;
union Value {
// TangentVector
Type tangentVectorType;
// Tuple
TupleType *tupleType;
Value(Type tangentVectorType) : tangentVectorType(tangentVectorType) {}
Value(TupleType *tupleType) : tupleType(tupleType) {}
} value;
TangentSpace(Kind kind, Value value) : kind(kind), value(value) {}
public:
TangentSpace() = delete;
static TangentSpace getTangentVector(Type tangentVectorType) {
return {Kind::TangentVector, tangentVectorType};
}
static TangentSpace getTuple(TupleType *tupleTy) {
return {Kind::Tuple, tupleTy};
}
bool isTangentVector() const { return kind == Kind::TangentVector; }
bool isTuple() const { return kind == Kind::Tuple; }
Kind getKind() const { return kind; }
Type getTangentVector() const {
assert(kind == Kind::TangentVector);
return value.tangentVectorType;
}
TupleType *getTuple() const {
assert(kind == Kind::Tuple);
return value.tupleType;
}
/// Get the tangent space type.
Type getType() const;
/// Get the tangent space canonical type.
CanType getCanonicalType() const;
/// Get the underlying nominal type declaration of the tangent space type.
NominalTypeDecl *getNominal() const;
};
/// The key type used for uniquing `SILDifferentiabilityWitness` in
/// `SILModule`: original function name, parameter indices, result indices, and
/// derivative generic signature.
///
/// The `StringRef` member of the pair is the original function name; the
/// `AutoDiffConfig` member carries the parameter indices, result indices, and
/// derivative generic signature.
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
/// Automatic differentiation utility namespace.
namespace autodiff {
/// Given a function type, collects its semantic result types in type order
/// into `result`: first, the formal result type (if non-`Void`), followed by
/// `inout` parameter types.
///
/// The function type may have at most two parameter lists.
///
/// Remaps the original semantic result using `genericEnv`, if specified.
///
/// Each collected entry is an `AutoDiffSemanticFunctionResultType`, i.e. a
/// type paired with an `isInout` flag. `genericEnv` defaults to `nullptr`,
/// meaning "not specified".
void getFunctionSemanticResultTypes(
AnyFunctionType *functionType,
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
GenericEnvironment *genericEnv = nullptr);
/// Returns the lowered SIL parameter indices for the given AST parameter
/// indices and `AnyFunctionType`.
///
/// Notable lowering-related changes:
/// - AST tuple parameter types are exploded when lowered to SIL.
/// - AST curried `Self` parameter types become the last parameter when lowered
/// to SIL.
///
/// Examples:
///
/// In the examples below, a set of indices is written as a bit string read
/// from left to right: the i-th digit is 1 when parameter i is included. The
/// braces list the selected parameters.
///
/// AST function type: (A, B, C) -> R
/// AST parameter indices: 101, {A, C}
/// Lowered SIL function type: $(A, B, C) -> R
/// Lowered SIL parameter indices: 101
///
/// AST function type: (Self) -> (A, B, C) -> R
/// AST parameter indices: 1010, {Self, B}
/// Lowered SIL function type: $(A, B, C, Self) -> R
/// Lowered SIL parameter indices: 0101
///
/// AST function type: (A, (B, C), D) -> R
/// AST parameter indices: 110, {A, (B, C)}
/// Lowered SIL function type: $(A, B, C, D) -> R
/// Lowered SIL parameter indices: 1110
///
/// Note:
/// - The AST function type must not be curried unless it is a method.
/// Otherwise, the behavior is undefined.
IndexSubset *getLoweredParameterIndices(IndexSubset *astParameterIndices,
AnyFunctionType *functionType);
/// "Constrained" derivative generic signatures require all differentiability
/// parameters to conform to the `Differentiable` protocol.
///
/// "Constrained" transpose generic signatures additionally require all
/// linearity parameters to satisfy `Self == Self.TangentVector`.
///
/// Returns the "constrained" derivative/transpose generic signature given:
/// - An original SIL function type.
/// - Differentiability parameter indices.
/// - A possibly "unconstrained" derivative generic signature.
///
/// `isTranspose` defaults to `false`.
/// NOTE(review): presumably `isTranspose == true` selects the transpose
/// constraints described above, and `lookupConformance` is the callback used
/// to look up protocol conformances — confirm against the definition.
GenericSignature getConstrainedDerivativeGenericSignature(
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
bool isTranspose = false);
} // end namespace autodiff
} // end namespace swift
namespace llvm {
using swift::AutoDiffConfig;
using swift::AutoDiffDerivativeFunctionKind;
using swift::CanGenericSignature;
using swift::GenericSignature;
using swift::IndexSubset;
using swift::SILAutoDiffDerivativeFunctionKey;
using swift::SILFunctionType;
template <typename T> struct DenseMapInfo;
/// `DenseMapInfo` for `AutoDiffConfig`, so that configurations can be used as
/// `DenseMap` keys. Index subsets are compared and hashed by pointer; the
/// derivative generic signature is compared and hashed in canonical form.
template <> struct DenseMapInfo<AutoDiffConfig> {
  static AutoDiffConfig getEmptyKey() {
    // Both index subsets share the `void *` empty-key sentinel. The
    // `derivativeGenericSignature` component must be `nullptr` so that
    // `getHashValue` and `isEqual` do not try to call
    // `GenericSignatureImpl::getCanonicalSignature()` on an invalid pointer.
    auto *sentinel =
        static_cast<IndexSubset *>(llvm::DenseMapInfo<void *>::getEmptyKey());
    return {sentinel, sentinel, nullptr};
  }

  static AutoDiffConfig getTombstoneKey() {
    // Same construction as the empty key, using the tombstone sentinel. The
    // signature component is `nullptr` for the reason given in `getEmptyKey`.
    auto *sentinel = static_cast<IndexSubset *>(
        llvm::DenseMapInfo<void *>::getTombstoneKey());
    return {sentinel, sentinel, nullptr};
  }

  static unsigned getHashValue(const AutoDiffConfig &Val) {
    // A null signature is hashed as null; otherwise hash the canonical form.
    auto canonicalSig =
        Val.derivativeGenericSignature
            ? Val.derivativeGenericSignature->getCanonicalSignature()
            : nullptr;
    return hash_combine(
        ~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
        DenseMapInfo<void *>::getHashValue(Val.resultIndices),
        DenseMapInfo<GenericSignature>::getHashValue(canonicalSig));
  }

  static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
    // Canonicalize both signatures up front; a null signature stays null.
    auto lhsCanonicalSig =
        LHS.derivativeGenericSignature
            ? LHS.derivativeGenericSignature->getCanonicalSignature()
            : nullptr;
    auto rhsCanonicalSig =
        RHS.derivativeGenericSignature
            ? RHS.derivativeGenericSignature->getCanonicalSignature()
            : nullptr;
    if (LHS.parameterIndices != RHS.parameterIndices)
      return false;
    if (LHS.resultIndices != RHS.resultIndices)
      return false;
    return DenseMapInfo<GenericSignature>::isEqual(lhsCanonicalSig,
                                                   rhsCanonicalSig);
  }
};
template <> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
static AutoDiffDerivativeFunctionKind getEmptyKey() {
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
DenseMapInfo<unsigned>::getEmptyKey());
}
static AutoDiffDerivativeFunctionKind getTombstoneKey() {
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
DenseMapInfo<unsigned>::getTombstoneKey());
}
static unsigned getHashValue(const AutoDiffDerivativeFunctionKind &Val) {
return DenseMapInfo<unsigned>::getHashValue(Val);
}
static bool isEqual(const AutoDiffDerivativeFunctionKind &LHS,
const AutoDiffDerivativeFunctionKind &RHS) {
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(LHS) ==
static_cast<AutoDiffDerivativeFunctionKind::innerty>(RHS);
}
};
/// `DenseMapInfo` for `SILAutoDiffDerivativeFunctionKey`. Every field of the
/// key participates in both equality and hashing; pointer-like fields are
/// compared by identity.
template <> struct DenseMapInfo<SILAutoDiffDerivativeFunctionKey> {
  /// Field-by-field equality. Keys are taken by const reference, matching the
  /// other `DenseMapInfo` specializations in this file and avoiding a copy of
  /// both keys on every comparison.
  static bool isEqual(const SILAutoDiffDerivativeFunctionKey &lhs,
                      const SILAutoDiffDerivativeFunctionKey &rhs) {
    return lhs.originalType == rhs.originalType &&
           lhs.parameterIndices == rhs.parameterIndices &&
           lhs.resultIndices == rhs.resultIndices &&
           lhs.kind.rawValue == rhs.kind.rawValue &&
           lhs.derivativeFnGenSig == rhs.derivativeFnGenSig &&
           lhs.isReabstractionThunk == rhs.isReabstractionThunk;
  }

  /// Builds the empty key from the empty keys of each component's
  /// `DenseMapInfo`. The `unsigned` sentinel is cast to the `uint8_t`-backed
  /// kind enum and to `bool` for the last two scalar fields.
  static inline SILAutoDiffDerivativeFunctionKey getEmptyKey() {
    return {DenseMapInfo<SILFunctionType *>::getEmptyKey(),
            DenseMapInfo<IndexSubset *>::getEmptyKey(),
            DenseMapInfo<IndexSubset *>::getEmptyKey(),
            static_cast<AutoDiffDerivativeFunctionKind::innerty>(
                DenseMapInfo<unsigned>::getEmptyKey()),
            CanGenericSignature(DenseMapInfo<GenericSignature>::getEmptyKey()),
            static_cast<bool>(DenseMapInfo<unsigned>::getEmptyKey())};
  }

  /// Builds the tombstone key from the tombstone keys of each component's
  /// `DenseMapInfo`, mirroring `getEmptyKey`.
  static inline SILAutoDiffDerivativeFunctionKey getTombstoneKey() {
    return {
        DenseMapInfo<SILFunctionType *>::getTombstoneKey(),
        DenseMapInfo<IndexSubset *>::getTombstoneKey(),
        DenseMapInfo<IndexSubset *>::getTombstoneKey(),
        static_cast<AutoDiffDerivativeFunctionKind::innerty>(
            DenseMapInfo<unsigned>::getTombstoneKey()),
        CanGenericSignature(DenseMapInfo<GenericSignature>::getTombstoneKey()),
        static_cast<bool>(DenseMapInfo<unsigned>::getTombstoneKey())};
  }

  /// Combines the hashes of all six fields; the kind and the thunk flag are
  /// widened to `unsigned` before hashing.
  static unsigned getHashValue(const SILAutoDiffDerivativeFunctionKey &Val) {
    return hash_combine(
        DenseMapInfo<SILFunctionType *>::getHashValue(Val.originalType),
        DenseMapInfo<IndexSubset *>::getHashValue(Val.parameterIndices),
        DenseMapInfo<IndexSubset *>::getHashValue(Val.resultIndices),
        DenseMapInfo<unsigned>::getHashValue(
            static_cast<unsigned>(Val.kind.rawValue)),
        DenseMapInfo<GenericSignature>::getHashValue(Val.derivativeFnGenSig),
        DenseMapInfo<unsigned>::getHashValue(
            static_cast<unsigned>(Val.isReabstractionThunk)));
  }
};
} // end namespace llvm
#endif // SWIFT_AST_AUTODIFF_H