mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
* Introduce TypeLayout Strings Layout strings encode the structure of a type into a byte string that can be interpreted by a runtime function to achieve a destroy or copy. Rather than generating ir for a destroy/assignWithCopy/etc, we instead generate a layout string which encodes enough information for a called runtime function to perform the operation for us. Value witness functions tend to be quite large, so this allows us to replace them with a single call instead. This gives us the option of making a codesize/runtime cost trade off. * Added Attribute @_GenerateLayoutBytecode This marks a type definition that should use generic bytecode based value witnesses rather than generating the standard suite of value witness functions. This should reduce the codesize of the binary for a runtime interpretation of the bytecode cost. * Statically link in implementation Summary: This creates a library to store the runtime functions in to deploy to runtimes that do not implement bytecode layouts. Right now, that is everything. Once these are added to the runtime itself, it can be used to deploy to old runtimes. * Implement Destroy at Runtime Using LayoutStrings If GenerateLayoutBytecode is enabled, Create a layout string and use it to call swift_generic_destroy * Add Resilient type and Archetype Support for BytecodeLayouts Add Resilient type and Archetype Support to Bytecode Layouts * Implement Bytecode assign/init with copy/take Implements swift_generic_initialize and swift_generic_assign to allow copying types using bytecode based witnesses. * Add EnumTag Support * Add IRGen Bytecode Layouts Test Added a test to ensure layouts are correct and getting generated * Implement BytecodeLayouts ObjC retain/release * Fix for Non static alignments in aligned groups * Disable MultiEnums MultiEnums currently have some correctness issues with non fixed multienum types. 
Disabling them for now then going to attempt a correct implementation in a follow-up patch * Fixes after merge * More fixes * Possible fix for native unowned * Use TypeInfoBasedTypeLayoutEntry for all scalars when ForceStructTypeLayouts is disabled * Remove @_GenerateBytecodeLayout attribute * Fix typelayout_based_value_witness.swift Co-authored-by: Gwen Mittertreiner <gwenm@fb.com> Co-authored-by: Gwen Mittertreiner <gwen.mittertreiner@gmail.com>
400 lines
15 KiB
C++
400 lines
15 KiB
C++
//===- GenDiffFunc.cpp - Swift IR Generation For @differentiable Functions ===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements IR generation for `@differentiable` function types in
|
|
// Swift.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "swift/AST/Decl.h"
|
|
#include "swift/AST/IRGenOptions.h"
|
|
#include "swift/AST/Pattern.h"
|
|
#include "swift/AST/Types.h"
|
|
#include "swift/SIL/SILModule.h"
|
|
#include "swift/SIL/SILType.h"
|
|
#include "llvm/IR/DerivedTypes.h"
|
|
|
|
#include "Explosion.h"
|
|
#include "GenHeap.h"
|
|
#include "GenRecord.h"
|
|
#include "GenType.h"
|
|
#include "IRGenFunction.h"
|
|
#include "IRGenModule.h"
|
|
#include "IndirectTypeInfo.h"
|
|
#include "NonFixedTypeInfo.h"
|
|
|
|
#pragma clang diagnostic ignored "-Winconsistent-missing-override"
|
|
|
|
using namespace swift;
|
|
using namespace irgen;
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// `@differentiable` (non-linear) function type info
|
|
//----------------------------------------------------------------------------//
|
|
|
|
namespace {
|
|
class DifferentiableFuncFieldInfo final
|
|
: public RecordField<DifferentiableFuncFieldInfo> {
|
|
public:
|
|
DifferentiableFuncFieldInfo(
|
|
NormalDifferentiableFunctionTypeComponent component, const TypeInfo &type,
|
|
IndexSubset *parameterIndices, IndexSubset *resultIndices)
|
|
: RecordField(type), component(component),
|
|
parameterIndices(parameterIndices), resultIndices(resultIndices) {}
|
|
|
|
/// The field index.
|
|
const NormalDifferentiableFunctionTypeComponent component;
|
|
|
|
/// The parameter indices.
|
|
IndexSubset *parameterIndices;
|
|
/// The result indices.
|
|
IndexSubset *resultIndices;
|
|
|
|
std::string getFieldName() const {
|
|
switch (component) {
|
|
case NormalDifferentiableFunctionTypeComponent::Original:
|
|
return "original";
|
|
case NormalDifferentiableFunctionTypeComponent::JVP:
|
|
return "jvp";
|
|
case NormalDifferentiableFunctionTypeComponent::VJP:
|
|
return "vjp";
|
|
}
|
|
llvm_unreachable("invalid component type");
|
|
}
|
|
|
|
SILType getType(IRGenModule &IGM, SILType t) const {
|
|
auto fnTy = t.castTo<SILFunctionType>();
|
|
auto origFnTy = fnTy->getWithoutDifferentiability();
|
|
if (component == NormalDifferentiableFunctionTypeComponent::Original)
|
|
return SILType::getPrimitiveObjectType(origFnTy);
|
|
auto kind = *component.getAsDerivativeFunctionKind();
|
|
auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType(
|
|
parameterIndices, resultIndices, kind, IGM.getSILTypes(),
|
|
LookUpConformanceInModule(IGM.getSwiftModule()));
|
|
return SILType::getPrimitiveObjectType(assocTy);
|
|
}
|
|
};
|
|
|
|
class DifferentiableFuncTypeInfo final
|
|
: public RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo,
|
|
DifferentiableFuncFieldInfo> {
|
|
using super = RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo,
|
|
DifferentiableFuncFieldInfo>;
|
|
|
|
public:
|
|
DifferentiableFuncTypeInfo(ArrayRef<DifferentiableFuncFieldInfo> fields,
|
|
unsigned explosionSize, llvm::Type *ty, Size size,
|
|
SpareBitVector &&spareBits, Alignment align,
|
|
IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
|
|
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
|
|
isPOD, alwaysFixedSize) {}
|
|
|
|
Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
|
|
const DifferentiableFuncFieldInfo &field) const {
|
|
return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
|
|
}
|
|
|
|
void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src,
|
|
SILType T, bool isOutlined) const override {
|
|
llvm_unreachable("unexploded @differentiable function as argument?");
|
|
}
|
|
|
|
void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering,
|
|
Size offset) const override {
|
|
for (auto &field : getFields()) {
|
|
auto fieldOffset = offset + field.getFixedByteOffset();
|
|
cast<LoadableTypeInfo>(field.getTypeInfo())
|
|
.addToAggLowering(IGM, lowering, fieldOffset);
|
|
}
|
|
}
|
|
|
|
TypeLayoutEntry *buildTypeLayoutEntry(IRGenModule &IGM,
|
|
SILType T) const override {
|
|
if (!IGM.getOptions().ForceStructTypeLayouts || !areFieldsABIAccessible()) {
|
|
return IGM.typeLayoutCache.getOrCreateTypeInfoBasedEntry(*this, T);
|
|
}
|
|
|
|
if (getFields().empty()) {
|
|
return IGM.typeLayoutCache.getEmptyEntry();
|
|
}
|
|
|
|
std::vector<TypeLayoutEntry *> fields;
|
|
for (auto &field : getFields()) {
|
|
auto fieldTy = field.getType(IGM, T);
|
|
fields.push_back(field.getTypeInfo().buildTypeLayoutEntry(IGM, fieldTy));
|
|
}
|
|
|
|
if (fields.size() == 1) {
|
|
return fields[0];
|
|
}
|
|
|
|
return IGM.typeLayoutCache.getOrCreateAlignedGroupEntry(fields, 1);
|
|
}
|
|
|
|
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; }
|
|
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const {
|
|
return None;
|
|
}
|
|
};
|
|
|
|
class DifferentiableFuncTypeBuilder
|
|
: public RecordTypeBuilder<DifferentiableFuncTypeBuilder,
|
|
DifferentiableFuncFieldInfo,
|
|
NormalDifferentiableFunctionTypeComponent> {
|
|
|
|
SILFunctionType *originalType;
|
|
IndexSubset *parameterIndices;
|
|
IndexSubset *resultIndices;
|
|
|
|
public:
|
|
DifferentiableFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
|
|
: RecordTypeBuilder(IGM),
|
|
originalType(fnTy->getWithoutDifferentiability()),
|
|
parameterIndices(fnTy->getDifferentiabilityParameterIndices()),
|
|
resultIndices(fnTy->getDifferentiabilityResultIndices()) {
|
|
// TODO: Ban 'Normal' and 'Forward'.
|
|
assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Reverse ||
|
|
fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal ||
|
|
fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Forward);
|
|
}
|
|
|
|
TypeInfo *createFixed(ArrayRef<DifferentiableFuncFieldInfo> fields,
|
|
StructLayout &&layout) {
|
|
llvm_unreachable("@differentiable functions are always loadable");
|
|
}
|
|
|
|
DifferentiableFuncTypeInfo *
|
|
createLoadable(ArrayRef<DifferentiableFuncFieldInfo> fields,
|
|
StructLayout &&layout, unsigned explosionSize) {
|
|
return DifferentiableFuncTypeInfo::create(
|
|
fields, explosionSize, layout.getType(), layout.getSize(),
|
|
std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(),
|
|
layout.isAlwaysFixedSize());
|
|
}
|
|
|
|
TypeInfo *createNonFixed(ArrayRef<DifferentiableFuncFieldInfo> fields,
|
|
FieldsAreABIAccessible_t fieldsAccessible,
|
|
StructLayout &&layout) {
|
|
llvm_unreachable("@differentiable functions are always loadable");
|
|
}
|
|
|
|
DifferentiableFuncFieldInfo
|
|
getFieldInfo(unsigned index,
|
|
NormalDifferentiableFunctionTypeComponent component,
|
|
const TypeInfo &fieldTI) {
|
|
return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices,
|
|
resultIndices);
|
|
}
|
|
|
|
SILType getType(NormalDifferentiableFunctionTypeComponent component) {
|
|
if (component == NormalDifferentiableFunctionTypeComponent::Original)
|
|
return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
|
|
auto kind = *component.getAsDerivativeFunctionKind();
|
|
auto assocTy = originalType->getAutoDiffDerivativeFunctionType(
|
|
parameterIndices, resultIndices, kind, IGM.getSILTypes(),
|
|
LookUpConformanceInModule(IGM.getSwiftModule()));
|
|
return SILType::getPrimitiveObjectType(assocTy);
|
|
}
|
|
|
|
StructLayout performLayout(ArrayRef<const TypeInfo *> fieldTypes) {
|
|
return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject,
|
|
LayoutStrategy::Universal, fieldTypes);
|
|
}
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// `@differentiable(_linear)` function type info
|
|
//----------------------------------------------------------------------------//
|
|
|
|
namespace {
|
|
/// Describes one field of a lowered linear differentiable function record:
/// the original function or its transpose.
class LinearFuncFieldInfo final : public RecordField<LinearFuncFieldInfo> {
public:
  LinearFuncFieldInfo(LinearDifferentiableFunctionTypeComponent component,
                      const TypeInfo &type, IndexSubset *parameterIndices)
      : RecordField(type), component(component),
        parameterIndices(parameterIndices) {}

  /// The function component (original or transpose) stored in this field.
  const LinearDifferentiableFunctionTypeComponent component;

  /// The parameter indices.
  IndexSubset *parameterIndices;

  /// Returns the name of this field: "original" or "transpose".
  std::string getFieldName() const {
    switch (component) {
    case LinearDifferentiableFunctionTypeComponent::Original:
      return "original";
    case LinearDifferentiableFunctionTypeComponent::Transpose:
      return "transpose";
    }
    llvm_unreachable("invalid component type");
  }

  /// Computes the SIL type of this field from the enclosing function type `t`,
  /// which must be a `SILFunctionType`.
  ///
  /// The original component maps to `t` with differentiability removed; the
  /// transpose component maps to the transpose function type derived from
  /// that original type using the stored parameter indices.
  SILType getType(IRGenModule &IGM, SILType t) const {
    auto fnTy = t.castTo<SILFunctionType>();
    auto origFnTy = fnTy->getWithoutDifferentiability();
    switch (component) {
    case LinearDifferentiableFunctionTypeComponent::Original:
      return SILType::getPrimitiveObjectType(origFnTy);
    case LinearDifferentiableFunctionTypeComponent::Transpose:
      auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType(
          parameterIndices, IGM.getSILTypes(),
          LookUpConformanceInModule(IGM.getSwiftModule()));
      return SILType::getPrimitiveObjectType(transposeTy);
    }
    llvm_unreachable("invalid component type");
  }
};
|
|
|
|
class LinearFuncTypeInfo final
|
|
: public RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo,
|
|
LinearFuncFieldInfo> {
|
|
using super =
|
|
RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo, LinearFuncFieldInfo>;
|
|
|
|
public:
|
|
LinearFuncTypeInfo(ArrayRef<LinearFuncFieldInfo> fields,
|
|
unsigned explosionSize, llvm::Type *ty, Size size,
|
|
SpareBitVector &&spareBits, Alignment align, IsPOD_t isPOD,
|
|
IsFixedSize_t alwaysFixedSize)
|
|
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
|
|
isPOD, alwaysFixedSize) {}
|
|
|
|
Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
|
|
const LinearFuncFieldInfo &field) const {
|
|
return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
|
|
}
|
|
|
|
void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src,
|
|
SILType T, bool isOutlined) const override {
|
|
llvm_unreachable("unexploded @differentiable function as argument?");
|
|
}
|
|
|
|
void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering,
|
|
Size offset) const override {
|
|
for (auto &field : getFields()) {
|
|
auto fieldOffset = offset + field.getFixedByteOffset();
|
|
cast<LoadableTypeInfo>(field.getTypeInfo())
|
|
.addToAggLowering(IGM, lowering, fieldOffset);
|
|
}
|
|
}
|
|
|
|
TypeLayoutEntry *buildTypeLayoutEntry(IRGenModule &IGM,
|
|
SILType T) const override {
|
|
if (!IGM.getOptions().ForceStructTypeLayouts || !areFieldsABIAccessible()) {
|
|
return IGM.typeLayoutCache.getOrCreateTypeInfoBasedEntry(*this, T);
|
|
}
|
|
|
|
if (getFields().empty()) {
|
|
return IGM.typeLayoutCache.getEmptyEntry();
|
|
}
|
|
|
|
std::vector<TypeLayoutEntry *> fields;
|
|
for (auto &field : getFields()) {
|
|
auto fieldTy = field.getType(IGM, T);
|
|
fields.push_back(field.getTypeInfo().buildTypeLayoutEntry(IGM, fieldTy));
|
|
}
|
|
|
|
if (fields.size() == 1) {
|
|
return fields[0];
|
|
}
|
|
|
|
return IGM.typeLayoutCache.getOrCreateAlignedGroupEntry(fields, 1);
|
|
}
|
|
|
|
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; }
|
|
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const {
|
|
return None;
|
|
}
|
|
};
|
|
|
|
class LinearFuncTypeBuilder
|
|
: public RecordTypeBuilder<LinearFuncTypeBuilder, LinearFuncFieldInfo,
|
|
LinearDifferentiableFunctionTypeComponent> {
|
|
|
|
SILFunctionType *originalType;
|
|
IndexSubset *parameterIndices;
|
|
|
|
public:
|
|
LinearFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
|
|
: RecordTypeBuilder(IGM),
|
|
originalType(fnTy->getWithoutDifferentiability()),
|
|
parameterIndices(fnTy->getDifferentiabilityParameterIndices()) {
|
|
assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear);
|
|
}
|
|
|
|
TypeInfo *createFixed(ArrayRef<LinearFuncFieldInfo> fields,
|
|
StructLayout &&layout) {
|
|
llvm_unreachable("@differentiable functions are always loadable");
|
|
}
|
|
|
|
LinearFuncTypeInfo *createLoadable(ArrayRef<LinearFuncFieldInfo> fields,
|
|
StructLayout &&layout,
|
|
unsigned explosionSize) {
|
|
return LinearFuncTypeInfo::create(
|
|
fields, explosionSize, layout.getType(), layout.getSize(),
|
|
std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(),
|
|
layout.isAlwaysFixedSize());
|
|
}
|
|
|
|
TypeInfo *createNonFixed(ArrayRef<LinearFuncFieldInfo> fields,
|
|
FieldsAreABIAccessible_t fieldsAccessible,
|
|
StructLayout &&layout) {
|
|
llvm_unreachable("@differentiable functions are always loadable");
|
|
}
|
|
|
|
LinearFuncFieldInfo
|
|
getFieldInfo(unsigned index, LinearDifferentiableFunctionTypeComponent field,
|
|
const TypeInfo &fieldTI) {
|
|
return LinearFuncFieldInfo(field, fieldTI, parameterIndices);
|
|
}
|
|
|
|
SILType getType(LinearDifferentiableFunctionTypeComponent component) {
|
|
switch (component) {
|
|
case LinearDifferentiableFunctionTypeComponent::Original:
|
|
return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
|
|
case LinearDifferentiableFunctionTypeComponent::Transpose:
|
|
auto transposeTy = originalType->getAutoDiffTransposeFunctionType(
|
|
parameterIndices, IGM.getSILTypes(),
|
|
LookUpConformanceInModule(IGM.getSwiftModule()));
|
|
return SILType::getPrimitiveObjectType(transposeTy);
|
|
}
|
|
llvm_unreachable("invalid component type");
|
|
}
|
|
|
|
StructLayout performLayout(ArrayRef<const TypeInfo *> fieldTypes) {
|
|
return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject,
|
|
LayoutStrategy::Universal, fieldTypes);
|
|
}
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Type converter entry points
|
|
//----------------------------------------------------------------------------//
|
|
|
|
/// Lowers a `@differentiable` function type to a record type info with three
/// fields, in order: the original function, the JVP, and the VJP.
const TypeInfo *
TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) {
  DifferentiableFuncTypeBuilder builder(IGM, type);
  return builder.layout({NormalDifferentiableFunctionTypeComponent::Original,
                         NormalDifferentiableFunctionTypeComponent::JVP,
                         NormalDifferentiableFunctionTypeComponent::VJP});
}
|
|
|
|
/// Lowers a linear differentiable function type to a record type info with
/// two fields, in order: the original function and its transpose.
const TypeInfo *
TypeConverter::convertLinearDifferentiableFunctionType(SILFunctionType *type) {
  LinearFuncTypeBuilder builder(IGM, type);
  return builder.layout({LinearDifferentiableFunctionTypeComponent::Original,
                         LinearDifferentiableFunctionTypeComponent::Transpose});
}
|