swift-mirror/include/swift/SILOptimizer/Differentiation/LinearMapInfo.h
Anton Korobeynikov c7a216058f [AutoDiff] First cut of coroutines differentiation (#71461)
This PR implements the first set of changes required to support autodiff for coroutines. It is mostly targeted at `_modify` accessors in the standard library (and beyond), but the overall implementation is quite generic.

There are some implementation specifics and known limitations:
 - Only `@yield_once` coroutines are naturally supported
 - The VJP is itself a coroutine: it yields the results *and* returns a pullback closure as a normal return value. This allows us to capture values produced in the resume part of a coroutine (this is required for defers and other cleanups / commits)
 - The pullback is also a coroutine. We assume that a coroutine cannot abort, and therefore we execute the original coroutine in reverse: from the return, via the yield, and then back to the entry
 - It seems there is no semantically sane way to support `_read` coroutines (as we would need to "accept" adjoints via yields); therefore, only coroutines with `inout` yields (`_modify` accessors) are supported. Pullbacks of such coroutines take the adjoint buffer as an input argument, yield this buffer (to accumulate adjoint values in the caller), and finally return the adjoints indirectly.
 - Coroutines (as opposed to normal functions) are not first-class values: there is no AST type for them, and one cannot, e.g., store them in tuples. So wherever an AST type is required, we have to work around this.
 - As there is no AST type for coroutines, there is no way to register a custom derivative for a coroutine. So far, only compiler-produced derivatives are supported
 - There is a lot in common with normal function applies, but there are still subtle yet important differences. I tried to organize the code to enable code reuse, but this was not always possible, so some code duplication remains
 - The order in which pullback closures are produced in the VJP is a bit different: for normal applies, the VJP produces both the value and the pullback closure via a single nested VJP apply. This is no longer the case with coroutine VJPs: yielded values are produced at the `begin_apply` site and the pullback closure is available only from `end_apply`, so we need to track the order in which pullbacks are produced (and arrange consumption of the values accordingly, effectively delaying them)
 - Along the way, some complementary changes were required, e.g., in the mangler / demangler
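The VJP/pullback shape described in the list above can be sketched in plain C++. This is a minimal sketch under an explicit assumption: a `yield_once` coroutine is modeled as a function that receives the caller's access body as a callback, which runs between the yield and the resume. All names (`InoutBody`, `PullbackCoroutine`, `modifyOriginal`, `modifyVJP`) and the squaring "commit" step are hypothetical; this illustrates the calling convention only and is not the SIL the compiler emits.

```cpp
#include <cassert>
#include <functional>

// Hypothetical model: a `yield_once` coroutine is a function that receives
// the caller's access body as a callback. The callback runs between the
// yield and the resume, and may mutate the yielded value in place.
using InoutBody = std::function<void(double &)>;

// Pullback coroutine: takes the adjoint buffer (inout) and the caller's
// reverse-mode access body, to which it "yields" the buffer.
using PullbackCoroutine = std::function<void(double &, const InoutBody &)>;

// Original `_modify`-like coroutine.
void modifyOriginal(double &x, const InoutBody &body) {
  assert(body && "access body must not be empty");
  body(x);    // yield `x` inout to the caller
  x = x * x;  // resume part: a hypothetical "commit" step
}

// VJP coroutine: yields like the original *and* returns the pullback as a
// normal return value, so the pullback can capture resume-part values.
PullbackCoroutine modifyVJP(double &x, const InoutBody &body) {
  assert(body && "access body must not be empty");
  body(x);                // yield `x` inout, as the original does
  double resumed = x;     // only known once the caller has resumed us
  x = resumed * resumed;  // resume part
  return [resumed](double &adj, const InoutBody &callerBody) {
    adj *= 2 * resumed;  // reverse of the resume part: d(v * v)/dv = 2v
    callerBody(adj);     // yield the adjoint buffer to the caller
    // Reverse of the entry part: nothing to undo in this example.
  };
}
```

For a caller whose access body is `v = 3 * v`, the modeled pullback first undoes the resume part, then yields the buffer so the caller can apply its own adjoint (`a *= 3`), and only then would undo the entry part, matching the "return, via yield, back to entry" order. The capture of `resumed` is the reason the pullback is returned rather than yielded: that value does not exist until after the yield.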

This patch covers the generation of derivatives up to the SIL level; however, this is not sufficient on its own, as codegen of `partial_apply` of a coroutine is completely broken. The fix for this will be submitted separately, as it is not directly autodiff-related.

---------

Co-authored-by: Andrew Savonichev <andrew.savonichev@gmail.com>
Co-authored-by: Richard Wei <rxwei@apple.com>
2024-04-04 17:24:55 -07:00

//===--- LinearMapInfo.h --------------------------------------*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Linear map tuple and branching trace enum information for differentiation.
//
//===----------------------------------------------------------------------===//
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H
#include "swift/AST/AutoDiff.h"
#include "swift/AST/SynthesizedFileUnit.h"
#include "swift/SIL/ApplySite.h"
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
#include "llvm/ADT/DenseMap.h"
namespace swift {
class SILFunction;
class SILLoopInfo;
namespace autodiff {
class ADContext;
/// Linear map tuple and branching trace enum information for an original
/// function and derivative function (JVP or VJP).
///
/// Linear map tuples contain all callee linear maps produced in a JVP/VJP
/// basic block. A linear map tuple is created for each basic block in the
/// original function, and a linear map tuple element is created for every
/// active `apply` / `begin_apply` in the original basic block.
///
/// Branching trace enums model the control flow graph of the original function.
/// A branching trace enum is created for each basic block in the original
/// function, and a branching trace enum case is created for every basic block
/// predecessor/successor. This supports control flow differentiation: JVP/VJP
/// functions build branching trace enums to record an execution trace. Indirect
/// branching trace enums are created for basic blocks that are in loops.
///
/// Linear map tuple values and branching trace enum values are constructed in
/// JVP/VJP functions and consumed in pullback/differential functions.
class LinearMapInfo {
private:
  /// The linear map kind.
  AutoDiffLinearMapKind kind;

  /// The original function.
  SILFunction *const original;

  /// The derivative function.
  SILFunction *const derivative;

  /// Activity info of the original function.
  const DifferentiableActivityInfo &activityInfo;

  /// The original function's loop info.
  SILLoopInfo *loopInfo;

  /// Differentiation indices of the function.
  const AutoDiffConfig config;

  /// Mapping from original basic blocks to linear map tuple types.
  llvm::DenseMap<SILBasicBlock *, TupleType *> linearMapTuples;

  /// Mapping from original basic blocks to branching trace enums.
  /// For pullbacks: these are predecessor enums.
  /// For differentials: these are successor enums.
  llvm::DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;

  /// Mapping from `apply` / `begin_apply` instructions in the original
  /// function to the corresponding linear map tuple type index.
  llvm::DenseMap<FullApplySite, unsigned> linearMapIndexMap;

  /// Mapping from predecessor-successor basic block pairs in the original
  /// function to the corresponding branching trace enum case.
  llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
      branchingTraceEnumCases;

  /// A synthesized file unit.
  SynthesizedFileUnit &synthesizedFile;

  /// A type converter, used to compute tuple/enum SIL types.
  Lowering::TypeConverter &typeConverter;

  /// True if a heap-allocated context is required, e.g. when the original
  /// function contains loops.
  bool heapAllocatedContext = false;

private:
  /// Remaps the given type into the derivative function's context.
  SILType remapTypeInDerivative(SILType ty);

  /// Retrieves the file unit that contains implicit declarations in the
  /// current Swift module.
  SynthesizedFileUnit &getSynthesizedFile() { return synthesizedFile; }

  /// Creates an enum declaration with the given JVP/VJP generic signature,
  /// whose cases represent the predecessors/successors of the given original
  /// block.
  EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB,
                                     CanGenericSignature genericSig);

  /// Populates the branching trace enum declaration of the given original
  /// block with its cases. The loop info is used to identify blocks that are
  /// in loops, whose branching trace enums are indirect.
  void populateBranchingTraceDecl(SILBasicBlock *originalBB,
                                  SILLoopInfo *loopInfo);

  /// Given an `apply` / `begin_apply` instruction, conditionally gets a linear
  /// map tuple field AST type for its linear map function if it is active.
  Type getLinearMapType(ADContext &context, FullApplySite fai);

  /// Generates linear map tuple and branching trace enum declarations for the
  /// given function. Linear map tuples are populated with linear map fields
  /// and a branching trace enum field.
  void generateDifferentiationDataStructures(ADContext &context,
                                             SILFunction *derivative);

public:
  /// Returns true if the given apply site should be differentiated.
  bool shouldDifferentiateApplySite(FullApplySite applySite);

  /// Returns true if the given instruction should be differentiated.
  bool shouldDifferentiateInstruction(SILInstruction *inst);

  LinearMapInfo(const LinearMapInfo &) = delete;
  LinearMapInfo &operator=(const LinearMapInfo &) = delete;

  explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
                         SILFunction *original, SILFunction *derivative,
                         const AutoDiffConfig &config,
                         const DifferentiableActivityInfo &activityInfo,
                         SILLoopInfo *loopInfo);

  /// Returns the linear map tuple associated with the given original block.
  TupleType *getLinearMapTupleType(SILBasicBlock *origBB) const {
    return linearMapTuples.lookup(origBB);
  }

  /// Returns the lowered SIL type of the linear map tuple associated with the
  /// given original block.
  SILType getLinearMapTupleLoweredType(SILBasicBlock *origBB) const {
    auto derivativeGenSig =
        derivative->getLoweredFunctionType()->getSubstGenericSignature();
    auto linMapTupleType =
        getLinearMapTupleType(origBB)->getReducedType(derivativeGenSig);
    Lowering::AbstractionPattern pattern(derivativeGenSig, linMapTupleType);
    return typeConverter.getLoweredType(pattern, linMapTupleType,
                                        TypeExpansionContext::minimal());
  }

  /// Returns the branching trace enum associated with the given original block.
  EnumDecl *getBranchingTraceDecl(SILBasicBlock *origBB) const {
    return branchingTraceDecls.lookup(origBB);
  }

  /// Returns the lowered SIL type of the branching trace enum associated with
  /// the given original block.
  SILType getBranchingTraceEnumLoweredType(SILBasicBlock *origBB) const {
    auto *traceDecl = getBranchingTraceDecl(origBB);
    auto traceDeclType =
        traceDecl->getDeclaredInterfaceType()->getCanonicalType();
    Lowering::AbstractionPattern pattern(
        derivative->getLoweredFunctionType()->getSubstGenericSignature(),
        traceDeclType);
    return typeConverter.getLoweredType(pattern, traceDeclType,
                                        TypeExpansionContext::minimal());
  }

  /// Returns the enum element in the given successor block's branching trace
  /// enum corresponding to the given predecessor block.
  EnumElementDecl *
  lookUpBranchingTraceEnumElement(SILBasicBlock *origPredBB,
                                  SILBasicBlock *origSuccBB) const {
    assert(origPredBB->getParent() == original);
    return branchingTraceEnumCases.lookup({origPredBB, origSuccBB});
  }

  /// Finds the linear map index in the linear map tuple for the given
  /// `apply` / `begin_apply` instruction in the original function.
  unsigned lookUpLinearMapIndex(FullApplySite fas) const {
    assert(fas->getFunction() == original);
    auto lookup = linearMapIndexMap.find(fas);
    assert(lookup != linearMapIndexMap.end() &&
           "No linear map field corresponding to the given `apply`");
    return lookup->getSecond();
  }

  /// Returns the AST type of the linear map tuple element for the given
  /// `apply` / `begin_apply` instruction in the original function.
  Type lookUpLinearMapType(FullApplySite fas) const {
    unsigned idx = lookUpLinearMapIndex(fas);
    return getLinearMapTupleType(fas->getParent())->getElement(idx).getType();
  }

  /// Returns true if a heap-allocated context is required, e.g. when the
  /// original function contains loops.
  bool hasHeapAllocatedContext() const {
    return heapAllocatedContext;
  }
};
} // end namespace autodiff
} // end namespace swift
#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H
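To illustrate how the linear map tuples (`linearMapTuples`) and branching trace enums (`branchingTraceDecls`) managed by `LinearMapInfo` fit together, here is a hand-written C++ model of what a VJP records for a small function with a diamond-shaped CFG. It is a simplified sketch under stated assumptions: `std::function` stands in for pullback closures, `std::variant` stands in for the synthesized branching trace enum, the entry block's (empty) linear map tuple and the single-case enums of `bb1`/`bb2` are omitted, and all type names are hypothetical rather than the names the compiler synthesizes.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <variant>

// Hypothetical model of the data a VJP records for this Swift function:
//   func f(_ x: Double) -> Double {
//     if x > 0 { return x * x }  // bb1
//     return 3 * x               // bb2
//   }                            // bb0 -> bb1 | bb2, both -> bb3 (exit)

// The pullback closure of one active apply.
using Pullback = std::function<double(double)>;

// Linear map tuples: one per original block, one element per active apply.
struct LinearMapsBB1 { Pullback squarePB; };  // pullback of `x * x`
struct LinearMapsBB2 { Pullback scalePB; };   // pullback of `3 * x`

// Branching trace enum of bb3 (a predecessor enum, as used by pullbacks):
// one case per predecessor, carrying that predecessor's linear map tuple.
struct BB3FromBB1 { LinearMapsBB1 maps; };
struct BB3FromBB2 { LinearMapsBB2 maps; };
using BranchingTraceBB3 = std::variant<BB3FromBB1, BB3FromBB2>;

// VJP: computes the original result and records the path that was taken.
std::pair<double, BranchingTraceBB3> vjp(double x) {
  if (x > 0) {                                         // bb0 -> bb1
    LinearMapsBB1 maps{[x](double v) { return 2 * x * v; }};
    return {x * x, BB3FromBB1{maps}};                  // bb1 -> bb3
  }
  LinearMapsBB2 maps{[](double v) { return 3 * v; }};  // bb0 -> bb2
  return {3 * x, BB3FromBB2{maps}};                    // bb2 -> bb3
}

// Pullback: starts at the exit block and switches on the recorded trace.
double pullback(const BranchingTraceBB3 &trace, double seed) {
  if (const auto *fromBB1 = std::get_if<BB3FromBB1>(&trace))
    return fromBB1->maps.squarePB(seed);
  assert(std::holds_alternative<BB3FromBB2>(trace));
  return std::get<BB3FromBB2>(trace).maps.scalePB(seed);
}
```

In this model, the variant's active alternative plays the role of the enum case that `lookUpBranchingTraceEnumElement(bb1, bb3)` or `lookUpBranchingTraceEnumElement(bb2, bb3)` would identify, and each struct member corresponds to a tuple element located via `lookUpLinearMapIndex`. Loops are where such a plain value model breaks down: a block in a loop would need a trace enum whose payload transitively contains itself, which is why the class comment makes branching trace enums for blocks in loops indirect, and why loops are the example given for requiring a heap-allocated context.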