mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
[AutoDiff upstream] Add differentiability witness SILGen. (#30545)
Generate SIL differentiability witnesses from `@differentiable` and `@derivative` declaration attributes. Add SILGen utilities for: - Emiting differentiability witnesses. - Creating derivative function thunks, which are used as entries in differentiability witnesses. When users register a custom derivative function, it is necessary to create a thunk with the expected derivative type computed from the original function's type. This is important for consistent typing and consistent differentiability witness entry mangling. See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details. Resolves TF-1138.
This commit is contained in:
@@ -153,6 +153,25 @@ public:
|
||||
Type SelfType,
|
||||
ModuleDecl *Module);
|
||||
|
||||
/// Mangle the derivative function (JVP/VJP) for the given:
|
||||
/// - Mangled original function name.
|
||||
/// - Derivative function kind.
|
||||
/// - Derivative function configuration: parameter/result indices and
|
||||
/// derivative generic signature.
|
||||
std::string
|
||||
mangleAutoDiffDerivativeFunctionHelper(StringRef name,
|
||||
AutoDiffDerivativeFunctionKind kind,
|
||||
AutoDiffConfig config);
|
||||
|
||||
/// Mangle the linear map (differential/pullback) for the given:
|
||||
/// - Mangled original function name.
|
||||
/// - Linear map kind.
|
||||
/// - Derivative function configuration: parameter/result indices and
|
||||
/// derivative generic signature.
|
||||
std::string mangleAutoDiffLinearMapHelper(StringRef name,
|
||||
AutoDiffLinearMapKind kind,
|
||||
AutoDiffConfig config);
|
||||
|
||||
/// Mangle a SIL differentiability witness key:
|
||||
/// - Mangled original function name.
|
||||
/// - Parameter indices.
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "swift/AST/TypeAlignments.h"
|
||||
#include "swift/Basic/Range.h"
|
||||
#include "swift/Basic/SourceLoc.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
||||
namespace swift {
|
||||
|
||||
@@ -95,6 +96,45 @@ struct DifferentiabilityWitnessFunctionKind {
|
||||
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
|
||||
};
|
||||
|
||||
/// SIL-level automatic differentiation indices. Consists of:
|
||||
/// - Parameter indices: indices of parameters to differentiate with respect to.
|
||||
/// - Result index: index of the result to differentiate from.
|
||||
// TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
|
||||
// `AutoDiffConfig` supports multiple result indices.
|
||||
struct SILAutoDiffIndices {
|
||||
/// The index of the dependent result to differentiate from.
|
||||
unsigned source;
|
||||
/// The indices for independent parameters to differentiate with respect to.
|
||||
IndexSubset *parameters;
|
||||
|
||||
/*implicit*/ SILAutoDiffIndices(unsigned source, IndexSubset *parameters)
|
||||
: source(source), parameters(parameters) {}
|
||||
|
||||
bool operator==(const SILAutoDiffIndices &other) const;
|
||||
|
||||
bool operator!=(const SILAutoDiffIndices &other) const {
|
||||
return !(*this == other);
|
||||
};
|
||||
|
||||
/// Returns true if `parameterIndex` is a differentiability parameter index.
|
||||
bool isWrtParameter(unsigned parameterIndex) const {
|
||||
return parameterIndex < parameters->getCapacity() &&
|
||||
parameters->contains(parameterIndex);
|
||||
}
|
||||
|
||||
void print(llvm::raw_ostream &s = llvm::outs()) const;
|
||||
SWIFT_DEBUG_DUMP;
|
||||
|
||||
std::string mangle() const {
|
||||
std::string result = "src_" + llvm::utostr(source) + "_wrt_";
|
||||
interleave(
|
||||
parameters->getIndices(),
|
||||
[&](unsigned idx) { result += llvm::utostr(idx); },
|
||||
[&] { result += '_'; });
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
/// Identifies an autodiff derivative function configuration:
|
||||
/// - Parameter indices.
|
||||
/// - Result indices.
|
||||
@@ -110,6 +150,11 @@ struct AutoDiffConfig {
|
||||
: parameterIndices(parameterIndices), resultIndices(resultIndices),
|
||||
derivativeGenericSignature(derivativeGenericSignature) {}
|
||||
|
||||
/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
|
||||
// TODO(TF-913): This is a temporary shim for incremental removal of
|
||||
// `SILAutoDiffIndices`. Eventually remove this.
|
||||
SILAutoDiffIndices getSILAutoDiffIndices() const;
|
||||
|
||||
void print(llvm::raw_ostream &s = llvm::outs()) const;
|
||||
SWIFT_DEBUG_DUMP;
|
||||
};
|
||||
@@ -282,6 +327,37 @@ void getFunctionSemanticResultTypes(
|
||||
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
|
||||
GenericEnvironment *genericEnv = nullptr);
|
||||
|
||||
/// Returns the lowered SIL parameter indices for the given AST parameter
|
||||
/// indices and `AnyfunctionType`.
|
||||
///
|
||||
/// Notable lowering-related changes:
|
||||
/// - AST tuple parameter types are exploded when lowered to SIL.
|
||||
/// - AST curried `Self` parameter types become the last parameter when lowered
|
||||
/// to SIL.
|
||||
///
|
||||
/// Examples:
|
||||
///
|
||||
/// AST function type: (A, B, C) -> R
|
||||
/// AST parameter indices: 101, {A, C}
|
||||
/// Lowered SIL function type: $(A, B, C) -> R
|
||||
/// Lowered SIL parameter indices: 101
|
||||
///
|
||||
/// AST function type: (Self) -> (A, B, C) -> R
|
||||
/// AST parameter indices: 1010, {Self, B}
|
||||
/// Lowered SIL function type: $(A, B, C, Self) -> R
|
||||
/// Lowered SIL parameter indices: 0101
|
||||
///
|
||||
/// AST function type: (A, (B, C), D) -> R
|
||||
/// AST parameter indices: 110, {A, (B, C)}
|
||||
/// Lowered SIL function type: $(A, B, C, D) -> R
|
||||
/// Lowered SIL parameter indices: 1110
|
||||
///
|
||||
/// Note:
|
||||
/// - The AST function type must not be curried unless it is a method.
|
||||
/// Otherwise, the behavior is undefined.
|
||||
IndexSubset *getLoweredParameterIndices(IndexSubset *astParameterIndices,
|
||||
AnyFunctionType *functionType);
|
||||
|
||||
/// "Constrained" derivative generic signatures require all differentiability
|
||||
/// parameters to conform to the `Differentiable` protocol.
|
||||
///
|
||||
|
||||
Reference in New Issue
Block a user