[AutoDiff upstream] Add differentiability witness SILGen. (#30545)

Generate SIL differentiability witnesses from `@differentiable` and
`@derivative` declaration attributes.

Add SILGen utilities for:
- Emiting differentiability witnesses.
- Creating derivative function thunks, which are used as entries in
  differentiability witnesses.

When users register a custom derivative function, it is necessary to create a
thunk with the expected derivative type computed from the original function's
type. This is important for consistent typing and consistent differentiability
witness entry mangling.

See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details.

Resolves TF-1138.
This commit is contained in:
Dan Zheng
2020-03-21 02:05:04 -07:00
committed by GitHub
parent 7c5b4d1fdf
commit 24445dd2e2
9 changed files with 1166 additions and 4 deletions

View File

@@ -153,6 +153,25 @@ public:
Type SelfType,
ModuleDecl *Module);
/// Mangle the derivative function (JVP/VJP) for the given:
/// - Mangled original function name.
/// - Derivative function kind.
/// - Derivative function configuration: parameter/result indices and
/// derivative generic signature.
std::string
mangleAutoDiffDerivativeFunctionHelper(StringRef name,
AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config);
/// Mangle the linear map (differential/pullback) for the given:
/// - Mangled original function name.
/// - Linear map kind.
/// - Derivative function configuration: parameter/result indices and
/// derivative generic signature.
std::string mangleAutoDiffLinearMapHelper(StringRef name,
AutoDiffLinearMapKind kind,
AutoDiffConfig config);
/// Mangle a SIL differentiability witness key:
/// - Mangled original function name.
/// - Parameter indices.

View File

@@ -26,6 +26,7 @@
#include "swift/AST/TypeAlignments.h"
#include "swift/Basic/Range.h"
#include "swift/Basic/SourceLoc.h"
#include "llvm/ADT/StringExtras.h"
namespace swift {
@@ -95,6 +96,45 @@ struct DifferentiabilityWitnessFunctionKind {
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
};
/// SIL-level automatic differentiation indices. Consists of:
/// - Parameter indices: indices of parameters to differentiate with respect to.
/// - Result index: index of the result to differentiate from.
// TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
// `AutoDiffConfig` supports multiple result indices.
struct SILAutoDiffIndices {
/// The index of the dependent result to differentiate from.
unsigned source;
/// The indices for independent parameters to differentiate with respect to.
IndexSubset *parameters;
/*implicit*/ SILAutoDiffIndices(unsigned source, IndexSubset *parameters)
: source(source), parameters(parameters) {}
bool operator==(const SILAutoDiffIndices &other) const;
bool operator!=(const SILAutoDiffIndices &other) const {
return !(*this == other);
};
/// Returns true if `parameterIndex` is a differentiability parameter index.
bool isWrtParameter(unsigned parameterIndex) const {
return parameterIndex < parameters->getCapacity() &&
parameters->contains(parameterIndex);
}
void print(llvm::raw_ostream &s = llvm::outs()) const;
SWIFT_DEBUG_DUMP;
std::string mangle() const {
std::string result = "src_" + llvm::utostr(source) + "_wrt_";
interleave(
parameters->getIndices(),
[&](unsigned idx) { result += llvm::utostr(idx); },
[&] { result += '_'; });
return result;
}
};
/// Identifies an autodiff derivative function configuration:
/// - Parameter indices.
/// - Result indices.
@@ -110,6 +150,11 @@ struct AutoDiffConfig {
: parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenericSignature) {}
/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
// TODO(TF-913): This is a temporary shim for incremental removal of
// `SILAutoDiffIndices`. Eventually remove this.
SILAutoDiffIndices getSILAutoDiffIndices() const;
void print(llvm::raw_ostream &s = llvm::outs()) const;
SWIFT_DEBUG_DUMP;
};
@@ -282,6 +327,37 @@ void getFunctionSemanticResultTypes(
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
GenericEnvironment *genericEnv = nullptr);
/// Returns the lowered SIL parameter indices for the given AST parameter
/// indices and `AnyfunctionType`.
///
/// Notable lowering-related changes:
/// - AST tuple parameter types are exploded when lowered to SIL.
/// - AST curried `Self` parameter types become the last parameter when lowered
/// to SIL.
///
/// Examples:
///
/// AST function type: (A, B, C) -> R
/// AST parameter indices: 101, {A, C}
/// Lowered SIL function type: $(A, B, C) -> R
/// Lowered SIL parameter indices: 101
///
/// AST function type: (Self) -> (A, B, C) -> R
/// AST parameter indices: 1010, {Self, B}
/// Lowered SIL function type: $(A, B, C, Self) -> R
/// Lowered SIL parameter indices: 0101
///
/// AST function type: (A, (B, C), D) -> R
/// AST parameter indices: 110, {A, (B, C)}
/// Lowered SIL function type: $(A, B, C, D) -> R
/// Lowered SIL parameter indices: 1110
///
/// Note:
/// - The AST function type must not be curried unless it is a method.
/// Otherwise, the behavior is undefined.
IndexSubset *getLoweredParameterIndices(IndexSubset *astParameterIndices,
AnyFunctionType *functionType);
/// "Constrained" derivative generic signatures require all differentiability
/// parameters to conform to the `Differentiable` protocol.
///