This PR implements the first set of changes required to support autodiff for coroutines. It mostly targets `_modify` accessors in the standard library (and beyond), but the overall implementation is quite generic.

Some specifics of the implementation and known limitations:

- Only `@yield_once` coroutines are naturally supported.
- The VJP is a coroutine itself: it yields the results *and* returns a pullback closure as a normal return. This allows us to capture values produced in the resume part of a coroutine (required for `defer`s and other cleanups / commits).
- The pullback is also a coroutine. We assume a coroutine cannot abort, so we execute the original coroutine in reverse: from the return, through the yield, and back to the entry.
- There seems to be no semantically sane way to support `_read` coroutines (we would need to "accept" adjoints via yields), so only coroutines with `inout` yields are supported (`_modify` accessors). Pullbacks of such coroutines take an adjoint buffer as an input argument, yield this buffer (so the caller can accumulate adjoint values), and finally return the adjoints indirectly.
- Coroutines (as opposed to normal functions) are not first-class values: there is no AST type for them, and one cannot e.g. store them into tuples. So everywhere an AST type is required, we have to work around it.
- Because there is no AST type for coroutines, there is no way to register a custom derivative for a coroutine. So far only compiler-produced derivatives are supported.
- There is a lot in common with normal function `apply`s, but there are subtle yet important differences. I tried to organize the code to enable reuse, but it was not always possible, so some code duplication remains.
- The order in which pullback closures are produced in the VJP is a bit different: for normal `apply`s, the VJP produces both the value and the pullback closure via a single nested VJP apply. This is no longer the case for coroutine VJPs: yielded values are produced at the `begin_apply` site and the pullback closure only becomes available at `end_apply`, so we need to track the order in which pullbacks are produced (and arrange consumption of the values accordingly, effectively delaying them).
- Along the way, some complementary changes were required in e.g. the mangler / demangler.

This patch covers the generation of derivatives up to the SIL level. However, it is not enough on its own, as codegen of `partial_apply` of a coroutine is completely broken. The fix for this will be submitted separately, as it is not directly autodiff-related.

---------

Co-authored-by: Andrew Savonichev <andrew.savonichev@gmail.com>
Co-authored-by: Richard Wei <rxwei@apple.com>
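The header shown below defines `IndexSubset`, the uniqued data structure autodiff uses to represent sets of indices such as the differentiability parameter indices a derivative is taken with respect to. A minimal sketch of building and querying such a subset, assuming an `ASTContext &ctx` is available; the helper name `makeExampleParameterIndices` is purely illustrative:

```cpp
#include "swift/AST/IndexSubset.h"

#include <cassert>

using namespace swift;

// Hypothetical helper, for illustration only: build the index subset {0, 2}
// over a capacity of 3, e.g. "differentiate with respect to parameters 0
// and 2 of a 3-parameter function".
static IndexSubset *makeExampleParameterIndices(ASTContext &ctx) {
  IndexSubset *paramIndices =
      IndexSubset::get(ctx, /*capacity*/ 3, /*indices*/ {0, 2});
  assert(paramIndices->contains(0) && !paramIndices->contains(1));
  assert(paramIndices->getNumIndices() == 2);
  assert(paramIndices->getString() == "SUS"); // 'S' = contained, 'U' = not
  return paramIndices;
}
```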
//===--- IndexSubset.h - Fixed-size subset of indices ---------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file defines the `IndexSubset` class and support logic.
//
//===----------------------------------------------------------------------===//

#ifndef SWIFT_AST_INDEXSUBSET_H
#define SWIFT_AST_INDEXSUBSET_H

#include "swift/Basic/Debug.h"
#include "swift/Basic/LLVM.h"
#include "swift/Basic/Range.h"
#include "swift/Basic/STLExtras.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/raw_ostream.h"

namespace swift {

class ASTContext;

/// An efficient index subset data structure, uniqued in `ASTContext`.
/// Stores a bit vector representing set indices and a total capacity.
class IndexSubset : public llvm::FoldingSetNode {
public:
  typedef uint64_t BitWord;

  static constexpr unsigned bitWordSize = sizeof(BitWord);
  static constexpr unsigned numBitsPerBitWord = bitWordSize * 8;

  static std::pair<unsigned, unsigned>
  getBitWordIndexAndOffset(unsigned index) {
    auto bitWordIndex = index / numBitsPerBitWord;
    auto bitWordOffset = index % numBitsPerBitWord;
    return {bitWordIndex, bitWordOffset};
  }

  static unsigned getNumBitWordsNeededForCapacity(unsigned capacity) {
    if (capacity == 0) return 0;
    return capacity / numBitsPerBitWord + 1;
  }

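  // For illustration (BitWord is 64 bits wide): a capacity of 70 needs
  // getNumBitWordsNeededForCapacity(70) == 2 bit words, and index 69 lives in
  // word 1 at bit offset 5, i.e. getBitWordIndexAndOffset(69) == {1, 5}.
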
private:
  /// The total capacity of the index subset; the largest valid index is
  /// `capacity - 1`.
  unsigned capacity;
  /// The number of bit words in the index subset.
  unsigned numBitWords;

  static unsigned getNumBytesNeededForCapacity(unsigned capacity) {
    return getNumBitWordsNeededForCapacity(capacity) * bitWordSize;
  }

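  // The bit words are tail-allocated immediately after the `IndexSubset`
  // object itself (the uniquing factory `IndexSubset::get` in ASTContext.cpp
  // allocates the extra bytes), which is why the data pointer is `this + 1`.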
  BitWord *getBitWordsData() {
    return reinterpret_cast<BitWord *>(this + 1);
  }

  const BitWord *getBitWordsData() const {
    return reinterpret_cast<const BitWord *>(this + 1);
  }

  ArrayRef<BitWord> getBitWords() const {
    return {getBitWordsData(), getNumBitWords()};
  }

  BitWord getBitWord(unsigned i) const {
    return getBitWordsData()[i];
  }

  BitWord &getBitWord(unsigned i) {
    return getBitWordsData()[i];
  }

  MutableArrayRef<BitWord> getMutableBitWords() {
    return {const_cast<BitWord *>(getBitWordsData()), getNumBitWords()};
  }

  explicit IndexSubset(const SmallBitVector &indices)
      : capacity((unsigned)indices.size()),
        numBitWords(getNumBitWordsNeededForCapacity(capacity)) {
    std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0);
    for (auto i : indices.set_bits()) {
      unsigned bitWordIndex, offset;
      std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i);
      getBitWord(bitWordIndex) |= (1ull << offset);
    }
  }

public:
  IndexSubset() = delete;
  IndexSubset(const IndexSubset &) = delete;
  IndexSubset &operator=(const IndexSubset &) = delete;

  // Defined in ASTContext.cpp.
  static IndexSubset *get(ASTContext &ctx, const SmallBitVector &indices);

  static IndexSubset *get(ASTContext &ctx, unsigned capacity,
                          ArrayRef<unsigned> indices) {
    SmallBitVector indicesBitVec(capacity, false);
    for (auto index : indices) {
      assert(index < capacity);
      indicesBitVec.set(index);
    }
    return IndexSubset::get(ctx, indicesBitVec);
  }

  static IndexSubset *getDefault(ASTContext &ctx, unsigned capacity,
                                 bool includeAll = false) {
    return get(ctx, SmallBitVector(capacity, includeAll));
  }

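  /// Returns the subset containing exactly the indices in the half-open range
  /// `[start, end)`; for example, `getFromRange(ctx, 5, 1, 4)` yields {1, 2, 3}.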
  static IndexSubset *getFromRange(ASTContext &ctx, unsigned capacity,
                                   unsigned start, unsigned end) {
    assert(start < capacity);
    assert(end <= capacity);
    SmallBitVector bitVec(capacity);
    bitVec.set(start, end);
    return get(ctx, bitVec);
  }

  /// Creates an index subset corresponding to the given string generated by
  /// `getString()`. If the string is invalid, returns nullptr.
  static IndexSubset *getFromString(ASTContext &ctx, StringRef string);

  /// Returns the number of bit words used to store the index subset.
  // Note: Use `getCapacity()` to get the total index subset capacity.
  // This is public only for unit testing
  // (in unittests/AST/IndexSubsetTests.cpp).
  unsigned getNumBitWords() const {
    return numBitWords;
  }

  /// Returns the capacity of the index subset.
  unsigned getCapacity() const {
    return capacity;
  }

  /// Returns a textual string description of these indices.
  ///
  /// It has the format `[SU]+`, where the total number of characters is equal
  /// to the capacity, and where "S" means that the corresponding index is
  /// contained and "U" means that the corresponding index is not.
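  /// For example, the subset {0, 2} with capacity 3 is printed as "SUS".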
  std::string getString() const;

  class iterator;

  iterator begin() const {
    return iterator(this);
  }

  iterator end() const {
    return iterator(this, (int)capacity);
  }

  /// Returns an iterator range of indices in the index subset.
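  /// For example, assuming `subset` is an `IndexSubset *`:
  /// \code
  ///   for (unsigned index : subset->getIndices())
  ///     llvm::outs() << index << "\n";
  /// \endcode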
  iterator_range<iterator> getIndices() const {
    return make_range(begin(), end());
  }

  /// Returns the number of indices in the index subset.
  unsigned getNumIndices() const {
    return (unsigned)std::distance(begin(), end());
  }

  SmallBitVector getBitVector() const {
    SmallBitVector indicesBitVec(capacity, false);
    for (auto index : getIndices())
      indicesBitVec.set(index);
    return indicesBitVec;
  }

  bool contains(unsigned index) const {
    unsigned bitWordIndex, offset;
    std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(index);
    return getBitWord(bitWordIndex) & (1ull << offset);
  }

  bool isEmpty() const {
    return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; });
  }

  bool equals(IndexSubset *other) const {
    return capacity == other->getCapacity() &&
           getBitWords().equals(other->getBitWords());
  }

  bool isSubsetOf(IndexSubset *other) const;
  bool isSupersetOf(IndexSubset *other) const;

  IndexSubset *adding(unsigned index, ASTContext &ctx) const;
  IndexSubset *extendingCapacity(ASTContext &ctx,
                                 unsigned newCapacity) const;

  void Profile(llvm::FoldingSetNodeID &id) const {
    id.AddInteger(capacity);
    for (auto index : getIndices())
      id.AddInteger(index);
  }

  void print(llvm::raw_ostream &s = llvm::outs()) const;
  SWIFT_DEBUG_DUMPER(dump());

  int findNext(int startIndex) const;
  int findFirst() const { return findNext(-1); }
  int findPrevious(int endIndex) const;
  int findLast() const { return findPrevious(capacity); }

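  /// A forward iterator over the indices contained in the subset, visited in
  /// increasing order.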
  class iterator {
  public:
    typedef unsigned value_type;
    typedef unsigned difference_type;
    typedef unsigned * pointer;
    typedef unsigned & reference;
    typedef std::forward_iterator_tag iterator_category;

  private:
    const IndexSubset *parent;
    int current = 0;

    void advance() {
      assert(current != -1 && "Trying to advance past end.");
      current = parent->findNext(current);
    }

  public:
    iterator(const IndexSubset *parent, int current)
        : parent(parent), current(current) {}
    explicit iterator(const IndexSubset *parent)
        : iterator(parent, parent->findFirst()) {}
    iterator(const iterator &) = default;

    iterator operator++(int) {
      auto prev = *this;
      advance();
      return prev;
    }

    iterator &operator++() {
      advance();
      return *this;
    }

    unsigned operator*() const { return current; }

    bool operator==(const iterator &other) const {
      assert(parent == other.parent &&
             "Comparing iterators from different IndexSubsets");
      return current == other.current;
    }

    bool operator!=(const iterator &other) const {
      assert(parent == other.parent &&
             "Comparing iterators from different IndexSubsets");
      return current != other.current;
    }
  };
};

}

#endif // SWIFT_AST_INDEXSUBSET_H