mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
This PR implements the first set of changes required to support autodiff for coroutines. It is mostly targeted at `_modify` accessors in the standard library (and beyond), but the overall implementation is quite generic. There are some implementation specifics and known limitations: - Only `@yield_once` coroutines are naturally supported. - The VJP is a coroutine itself: it yields the results *and* returns a pullback closure as a normal return. This allows us to capture values produced in the resume part of a coroutine (this is required for defers and other cleanups / commits). - The pullback is a coroutine; we assume that the coroutine cannot abort, and therefore we execute the original coroutine in reverse: from the return, via the yield, and then back to the entry. - There seems to be no semantically sane way to support `_read` coroutines (as we would need to "accept" adjoints via yields), therefore only coroutines with inout yields are supported (`_modify` accessors). Pullbacks of such coroutines take an adjoint buffer as an input argument, yield this buffer (to accumulate adjoint values in the caller), and finally return the adjoints indirectly. - Coroutines (as opposed to normal functions) are not first-class values: there is no AST type for them, so one cannot, e.g., store them in tuples, etc. So, everywhere an AST type is required, we have to hack around it. - As there is no AST type for coroutines, there is no way to register a custom derivative for a coroutine. So far only compiler-produced derivatives are supported. - There is a lot in common with normal function applies, but there are still subtle but important differences. I tried to organize the code to enable code reuse, but that was not always possible, so some code duplication can be seen. - The order in which pullback closures are produced in the VJP is a bit different: for normal applies, the VJP produces both the value and the pullback closure via a single nested VJP apply. 
This is no longer the case with coroutine VJPs: yielded values are produced at the `begin_apply` site and the pullback closure is available only from `end_apply`, so we need to track the order in which pullbacks are produced (and arrange consumption of the values accordingly – effectively delaying them). - Along the way, some complementary changes were required in, e.g., the mangler / demangler. This patch covers the generation of derivatives up to the SIL level; however, that is not sufficient, as codegen of `partial_apply` of a coroutine is completely broken. The fix for this will be submitted separately, as it is not directly autodiff-related. --------- Co-authored-by: Andrew Savonichev <andrew.savonichev@gmail.com> Co-authored-by: Richard Wei <rxwei@apple.com>
286 lines
10 KiB
C++
286 lines
10 KiB
C++
//===--- ASTDemangler.h - Swift AST symbol demangling -----------*- C++ -*-===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Defines a builder concept for the TypeDecoder and MetadataReader which builds
|
|
// AST Types, and a utility function wrapper which takes a mangled string and
|
|
// feeds it through the TypeDecoder instance.
|
|
//
|
|
// The RemoteAST library defines a MetadataReader instance that uses this
|
|
// concept, together with some additional utilities.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef SWIFT_AST_ASTDEMANGLER_H
|
|
#define SWIFT_AST_ASTDEMANGLER_H
|
|
|
|
#include "swift/AST/Types.h"
|
|
#include "swift/Demangling/Demangler.h"
|
|
#include "swift/Demangling/NamespaceMacros.h"
|
|
#include "swift/Demangling/TypeDecoder.h"
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include <optional>
|
|
|
|
namespace swift {
|
|
|
|
// Forward declaration: within this header TypeDecl is only ever named through
// a pointer (e.g. the return type of getTypeDeclForMangling), so the complete
// type is not required here.
class TypeDecl;
|
|
|
|
namespace Demangle {
|
|
SWIFT_BEGIN_INLINE_NAMESPACE
|
|
|
|
/// Decodes the mangled type name \p mangling into an AST Type using \p ctx.
///
/// Per the file header, this is the utility wrapper that feeds a mangled
/// string through a TypeDecoder instance.
///
/// \param ctx        The AST context the resulting type belongs to.
/// \param mangling   The mangled type string to decode.
/// \param genericSig Defaults to a default-constructed (empty) signature.
///                   NOTE(review): presumably the generic context used to
///                   interpret generic parameters in \p mangling; confirm in
///                   the out-of-line definition.
/// NOTE(review): the failure result (e.g. a null Type) is not visible here.
Type getTypeForMangling(ASTContext &ctx,
                        llvm::StringRef mangling,
                        GenericSignature genericSig = GenericSignature());
|
|
|
|
/// Counterpart of getTypeForMangling that yields a TypeDecl * for
/// \p mangling instead of a Type.
///
/// \param ctx        The AST context in which to look the declaration up.
/// \param mangling   The mangled string to decode.
/// \param genericSig Defaults to a default-constructed (empty) signature.
/// NOTE(review): presumably returns the declaration named by the mangling and
/// null on failure; confirm in the out-of-line definition.
TypeDecl *getTypeDeclForMangling(ASTContext &ctx,
                                 llvm::StringRef mangling,
                                 GenericSignature genericSig = GenericSignature());
|
|
|
|
/// Produces a TypeDecl * for the USR string \p usr.
///
/// \param ctx        The AST context in which to look the declaration up.
/// \param usr        The USR to resolve.
/// \param genericSig Defaults to a default-constructed (empty) signature.
/// NOTE(review): presumably the USR embeds a type mangling that is resolved
/// like getTypeDeclForMangling does; confirm in the out-of-line definition.
TypeDecl *getTypeDeclForUSR(ASTContext &ctx,
                            llvm::StringRef usr,
                            GenericSignature genericSig = GenericSignature());
|
|
|
|
/// An implementation of MetadataReader's BuilderType concept that
/// just finds and builds things in the AST.
///
/// A builder is bound to one ASTContext for its whole lifetime. Apart from the
/// constructor and a few trivial accessors, the members below are only
/// declared in this header; their definitions live out of line.
class ASTBuilder {
  ASTContext &Ctx;
  Demangle::NodeFactory Factory;

  /// The notional context in which we're writing and type-checking code.
  /// Created lazily.
  DeclContext *NotionalDC = nullptr;

  /// The depth and index of each parameter pack in the current generic
  /// signature. We need this because the mangling for a type parameter
  /// doesn't record whether it is a pack or not; we find the correct
  /// depth and index in this array, and use its pack-ness.
  llvm::SmallVector<std::pair<unsigned, unsigned>, 2> ParameterPacks;

  /// For saving and restoring generic parameters.
  llvm::SmallVector<decltype(ParameterPacks), 2> ParameterPackStack;

  /// This builder doesn't perform "on the fly" substitutions, so we preserve
  /// all pack expansions. We still need an active expansion stack though,
  /// for the dummy implementation of these methods:
  /// - beginPackExpansion()
  /// - advancePackExpansion()
  /// - createExpandedPackElement()
  /// - endPackExpansion()
  llvm::SmallVector<Type, 2> ActivePackExpansions;

public:
  // Result types produced by this builder.
  // NOTE(review): presumably these are the associated types the BuilderType
  // concept requires of TypeDecoder / MetadataReader builders; confirm
  // against swift/Demangling/TypeDecoder.h.
  using BuiltType = swift::Type;
  using BuiltTypeDecl = swift::GenericTypeDecl *; // nominal or type alias
  using BuiltProtocolDecl = swift::ProtocolDecl *;
  using BuiltGenericSignature = swift::GenericSignature;
  using BuiltRequirement = swift::Requirement;
  using BuiltInverseRequirement = swift::InverseRequirement;
  using BuiltSubstitutionMap = swift::SubstitutionMap;

  // NOTE(review): presumably tells the decoder this builder does not need
  // parent generic-context shapes precomputed; confirm in TypeDecoder.h.
  static constexpr bool needsToPrecomputeParentGenericContextShapes = false;

  /// Creates a builder over \p ctx. The (depth, index) of every generic
  /// parameter of \p genericSig that is a parameter pack is recorded in
  /// ParameterPacks, because manglings do not encode pack-ness.
  explicit ASTBuilder(ASTContext &ctx, GenericSignature genericSig)
      : Ctx(ctx) {
    for (auto *paramTy : genericSig.getGenericParams()) {
      if (paramTy->isParameterPack())
        ParameterPacks.emplace_back(paramTy->getDepth(), paramTy->getIndex());
    }
  }

  /// Returns the context this builder was constructed with.
  ASTContext &getASTContext() { return Ctx; }
  /// Returns the notional DeclContext (see NotionalDC); defined out of line.
  DeclContext *getNotionalDC();

  /// Returns the node factory owned by this builder.
  Demangle::NodeFactory &getNodeFactory() { return Factory; }

  Type decodeMangledType(NodePointer node, bool forRequirement = true);
  Type createBuiltinType(StringRef builtinName, StringRef mangledName);

  TypeDecl *createTypeDecl(NodePointer node);

  GenericTypeDecl *createTypeDecl(StringRef mangledName, bool &typeAlias);

  GenericTypeDecl *createTypeDecl(NodePointer node,
                                  bool &typeAlias);

  ProtocolDecl *createProtocolDecl(NodePointer node);

  Type createNominalType(GenericTypeDecl *decl);

  Type createNominalType(GenericTypeDecl *decl, Type parent);

  Type createTypeAliasType(GenericTypeDecl *decl, Type parent);

  Type createBoundGenericType(GenericTypeDecl *decl, ArrayRef<Type> args);

  Type resolveOpaqueType(NodePointer opaqueDescriptor,
                         ArrayRef<ArrayRef<Type>> args,
                         unsigned ordinal);

  Type createBoundGenericType(GenericTypeDecl *decl, ArrayRef<Type> args,
                              Type parent);

  Type createTupleType(ArrayRef<Type> eltTypes, ArrayRef<StringRef> labels);

  Type createPackType(ArrayRef<Type> eltTypes);

  Type createSILPackType(ArrayRef<Type> eltTypes, bool isElementAddress);

  // Pack-expansion hooks. As documented on ActivePackExpansions, this builder
  // only provides dummy implementations of these four methods.
  size_t beginPackExpansion(Type countType);

  void advancePackExpansion(size_t index);

  Type createExpandedPackElement(Type patternType);

  void endPackExpansion();

  Type createFunctionType(
      ArrayRef<Demangle::FunctionParam<Type>> params,
      Type output, FunctionTypeFlags flags, ExtendedFunctionTypeFlags extFlags,
      FunctionMetadataDifferentiabilityKind diffKind, Type globalActor,
      Type thrownError);

  Type createImplFunctionType(
      Demangle::ImplParameterConvention calleeConvention,
      Demangle::ImplCoroutineKind coroutineKind,
      ArrayRef<Demangle::ImplFunctionParam<Type>> params,
      ArrayRef<Demangle::ImplFunctionYield<Type>> yields,
      ArrayRef<Demangle::ImplFunctionResult<Type>> results,
      std::optional<Demangle::ImplFunctionResult<Type>> errorResult,
      ImplFunctionTypeFlags flags);

  Type createProtocolCompositionType(ArrayRef<ProtocolDecl *> protocols,
                                     Type superclass,
                                     bool isClassBound,
                                     bool forRequirement = true);

  Type createProtocolTypeFromDecl(ProtocolDecl *protocol);

  Type createConstrainedExistentialType(
      Type base,
      ArrayRef<BuiltRequirement> constraints,
      ArrayRef<BuiltInverseRequirement> inverseRequirements);

  Type createSymbolicExtendedExistentialType(NodePointer shapeNode,
                                             ArrayRef<Type> genArgs);

  Type createExistentialMetatypeType(
      Type instance,
      std::optional<Demangle::ImplMetatypeRepresentation> repr = std::nullopt);

  Type createMetatypeType(
      Type instance,
      std::optional<Demangle::ImplMetatypeRepresentation> repr = std::nullopt);

  // NOTE(review): presumably save/restore ParameterPacks via
  // ParameterPackStack ("for saving and restoring generic parameters");
  // confirm in the out-of-line definitions.
  void pushGenericParams(ArrayRef<std::pair<unsigned, unsigned>> parameterPacks);
  void popGenericParams();

  Type createGenericTypeParameterType(unsigned depth, unsigned index);

  Type createDependentMemberType(StringRef member, Type base);

  Type createDependentMemberType(StringRef member, Type base,
                                 ProtocolDecl *protocol);

  // One create<Name>StorageType(Type) declaration per REF_STORAGE entry in
  // ReferenceStorage.def. The backslash must be immediately followed by the
  // declaration line so that it forms the macro body.
#define REF_STORAGE(Name, ...) \
  Type create##Name##StorageType(Type base);
#include "swift/AST/ReferenceStorage.def"

  Type createSILBoxType(Type base);
  using BuiltSILBoxField = llvm::PointerIntPair<Type, 1>;
  using BuiltSubstitution = std::pair<Type, Type>;
  using BuiltLayoutConstraint = swift::LayoutConstraint;
  Type createSILBoxTypeWithLayout(
      ArrayRef<BuiltSILBoxField> Fields,
      ArrayRef<BuiltSubstitution> Substitutions,
      ArrayRef<BuiltRequirement> Requirements,
      ArrayRef<BuiltInverseRequirement> inverseRequirements);

  /// Returns whether \p type is an existential type (forwards to
  /// TypeBase::isExistentialType()).
  bool isExistential(Type type) {
    return type->isExistentialType();
  }

  Type createObjCClassType(StringRef name);

  Type createBoundGenericObjCClassType(StringRef name, ArrayRef<Type> args);

  ProtocolDecl *createObjCProtocolDecl(StringRef name);

  Type createDynamicSelfType(Type selfType);

  Type createForeignClassType(StringRef mangledName);

  Type getUnnamedForeignClassType();

  Type getOpaqueType();

  Type createOptionalType(Type base);

  Type createArrayType(Type base);

  Type createDictionaryType(Type key, Type value);

  Type createParenType(Type base);

  BuiltGenericSignature
  createGenericSignature(ArrayRef<BuiltType> params,
                         ArrayRef<BuiltRequirement> requirements);

  BuiltSubstitutionMap createSubstitutionMap(BuiltGenericSignature sig,
                                             ArrayRef<BuiltType> replacements);

  Type subst(Type subject, const BuiltSubstitutionMap &Subs) const;

  LayoutConstraint getLayoutConstraint(LayoutConstraintKind kind);
  LayoutConstraint getLayoutConstraintWithSizeAlign(LayoutConstraintKind kind,
                                                    unsigned size,
                                                    unsigned alignment);

  InverseRequirement createInverseRequirement(
      Type subject, InvertibleProtocolKind kind);

private:
  bool validateParentType(TypeDecl *decl, Type parent);
  CanGenericSignature demangleGenericSignature(
      NominalTypeDecl *nominalDecl,
      NodePointer node);
  DeclContext *findDeclContext(NodePointer node);
  ModuleDecl *findModule(NodePointer node);
  Demangle::NodePointer findModuleNode(NodePointer node);

  enum class ForeignModuleKind {
    Imported,
    SynthesizedByImporter
  };

  std::optional<ForeignModuleKind> getForeignModuleKind(NodePointer node);

  GenericTypeDecl *findTypeDecl(DeclContext *dc,
                                Identifier name,
                                Identifier privateDiscriminator,
                                Demangle::Node::Kind kind);
  GenericTypeDecl *findForeignTypeDecl(StringRef name,
                                       StringRef relatedEntityKind,
                                       ForeignModuleKind lookupKind,
                                       Demangle::Node::Kind kind);

  static GenericTypeDecl *getAcceptableTypeDeclCandidate(ValueDecl *decl,
                                                        Demangle::Node::Kind kind);
};
|
|
|
|
SWIFT_END_INLINE_NAMESPACE
|
|
} // namespace Demangle
|
|
|
|
} // namespace swift
|
|
|
|
#endif // SWIFT_AST_ASTDEMANGLER_H
|