mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
129 lines
5.5 KiB
C++
129 lines
5.5 KiB
C++
//===--- Thunk.h - Automatic differentiation thunks -----------*- C++ -*---===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Automatic differentiation thunk generation utilities.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_THUNK_H
|
|
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_THUNK_H
|
|
|
|
#include "swift/AST/AutoDiff.h"
|
|
#include "swift/Basic/LLVM.h"
|
|
#include "swift/SIL/SILBuilder.h"
|
|
|
|
namespace swift {
|
|
|
|
class SILOptFunctionBuilder;
|
|
class SILModule;
|
|
class SILLocation;
|
|
class SILValue;
|
|
class OpenedArchetypeType;
|
|
class GenericEnvironment;
|
|
class SubstitutionMap;
|
|
class ArchetypeType;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace autodiff {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Thunk helpers
|
|
//===----------------------------------------------------------------------===//
|
|
// These helpers are copied/adapted from SILGen. They should be refactored and
|
|
// moved to a shared location.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// The thunk kinds used in the differentiation transform.
|
|
enum class DifferentiationThunkKind {
|
|
/// A reabstraction thunk.
|
|
///
|
|
/// Reabstraction thunks transform a function-typed value to another one with
|
|
/// different parameter/result abstraction patterns. This is identical to the
|
|
/// thunks generated by SILGen.
|
|
Reabstraction,
|
|
|
|
/// An index subset thunk.
|
|
///
|
|
/// An index subset thunk is used transform JVP/VJPs into a version that is
|
|
/// "wrt" fewer differentiation parameters.
|
|
/// - Differentials of thunked JVPs use zero for non-requested differentiation
|
|
/// parameters.
|
|
/// - Pullbacks of thunked VJPs discard results for non-requested
|
|
/// differentiation parameters.
|
|
IndexSubset
|
|
};
|
|
|
|
CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig,
|
|
OpenedArchetypeType *openedExistential,
|
|
GenericEnvironment *&genericEnv,
|
|
SubstitutionMap &contextSubs,
|
|
SubstitutionMap &interfaceSubs,
|
|
ArchetypeType *&newArchetype);
|
|
|
|
/// Build the type of a function transformation thunk.
|
|
CanSILFunctionType buildThunkType(SILFunction *fn,
|
|
CanSILFunctionType &sourceType,
|
|
CanSILFunctionType &expectedType,
|
|
GenericEnvironment *&genericEnv,
|
|
SubstitutionMap &interfaceSubs,
|
|
bool withoutActuallyEscaping,
|
|
DifferentiationThunkKind thunkKind);
|
|
|
|
/// Get or create a reabstraction thunk from `fromType` to `toType`, to be
|
|
/// called in `caller`.
|
|
SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
|
|
SILModule &module, SILLocation loc,
|
|
SILFunction *caller,
|
|
CanSILFunctionType fromType,
|
|
CanSILFunctionType toType);
|
|
|
|
/// Reabstracts the given function-typed value `fn` to the target type `toType`.
|
|
/// Remaps substitutions using `remapSubstitutions`.
|
|
SILValue reabstractFunction(
|
|
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
|
|
SILValue fn, CanSILFunctionType toType,
|
|
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions);
|
|
|
|
/// Get or create a derivative function parameter index subset thunk from
|
|
/// `actualIndices` to `desiredIndices` for the given associated function
|
|
/// value and original function operand. Returns a pair of the parameter
|
|
/// index subset thunk and its interface substitution map (used to partially
|
|
/// apply the thunk).
|
|
/// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear
|
|
/// map returned by the derivative function.
|
|
std::pair<SILFunction *, SubstitutionMap>
|
|
getOrCreateSubsetParametersThunkForDerivativeFunction(
|
|
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
|
|
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
|
|
SILAutoDiffIndices actualIndices);
|
|
|
|
/// Get or create a derivative function parameter index subset thunk from
|
|
/// `actualIndices` to `desiredIndices` for the given associated function
|
|
/// value and original function operand. Returns a pair of the parameter
|
|
/// index subset thunk and its interface substitution map (used to partially
|
|
/// apply the thunk).
|
|
std::pair<SILFunction *, SubstitutionMap>
|
|
getOrCreateSubsetParametersThunkForLinearMap(
|
|
SILOptFunctionBuilder &fb, SILFunction *assocFn,
|
|
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
|
|
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
|
|
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);
|
|
|
|
} // end namespace autodiff
|
|
|
|
} // end namespace swift
|
|
|
|
#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_THUNK_H
|