Files
swift-mirror/lib/IRGen/GenDiffFunc.cpp
Dario Rexin 3cf40ea504 [IRGen] Re-introduce TypeLayout strings (#62059)
* Introduce TypeLayout Strings

Layout strings encode the structure of a type into a byte string that can be
interpreted by a runtime function to achieve a destroy or copy. Rather than
generating IR for a destroy/assignWithCopy/etc, we instead generate a layout
string which encodes enough information for a called runtime function to
perform the operation for us. Value witness functions tend to be quite large,
so this allows us to replace them with a single call instead. This gives us the
option of making a codesize/runtime cost trade off.

* Added Attribute @_GenerateLayoutBytecode

This marks a type definition that should use generic bytecode-based
value witnesses rather than generating the standard suite of
value witness functions. This should reduce the code size of the binary,
at the cost of interpreting the bytecode at runtime.

* Statically link in implementation

Summary:
This creates a library that stores the runtime functions so they can be
deployed to runtimes that do not implement bytecode layouts. Right now, that is
every runtime. Once these functions are added to the runtime itself, this
library can be used to deploy to older runtimes.

* Implement Destroy at Runtime Using LayoutStrings

If GenerateLayoutBytecode is enabled, create a layout string and use it
to call swift_generic_destroy.

* Add Resilient type and Archetype Support for BytecodeLayouts

Add Resilient type and Archetype Support to Bytecode Layouts

* Implement Bytecode assign/init with copy/take

Implements swift_generic_initialize and swift_generic_assign to allow copying
types using bytecode based witnesses.

* Add EnumTag Support

* Add IRGen Bytecode Layouts Test

Added a test to ensure layouts are correct and getting generated

* Implement BytecodeLayouts ObjC retain/release

* Fix for non-static alignments in aligned groups

* Disable MultiEnums

MultiEnums currently have some correctness issues with non-fixed multi-enum
types. Disabling them for now; a correct implementation will be attempted in
a follow-up patch.

* Fixes after merge

* More fixes

* Possible fix for native unowned

* Use TypeInfoBasedTypeLayoutEntry for all scalars when ForceStructTypeLayouts is disabled

* Remove @_GenerateBytecodeLayout attribute

* Fix typelayout_based_value_witness.swift

Co-authored-by: Gwen Mittertreiner <gwenm@fb.com>
Co-authored-by: Gwen Mittertreiner <gwen.mittertreiner@gmail.com>
2022-11-29 21:05:22 -08:00

400 lines
15 KiB
C++

//===- GenDiffFunc.cpp - Swift IR Generation For @differentiable Functions ===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements IR generation for `@differentiable` function types in
// Swift.
//
//===----------------------------------------------------------------------===//
#include "swift/AST/Decl.h"
#include "swift/AST/IRGenOptions.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/Types.h"
#include "swift/SIL/SILModule.h"
#include "swift/SIL/SILType.h"
#include "llvm/IR/DerivedTypes.h"
#include "Explosion.h"
#include "GenHeap.h"
#include "GenRecord.h"
#include "GenType.h"
#include "IRGenFunction.h"
#include "IRGenModule.h"
#include "IndirectTypeInfo.h"
#include "NonFixedTypeInfo.h"
#pragma clang diagnostic ignored "-Winconsistent-missing-override"
using namespace swift;
using namespace irgen;
//----------------------------------------------------------------------------//
// `@differentiable` (non-linear) function type info
//----------------------------------------------------------------------------//
namespace {
/// Describes one stored component (original, JVP or VJP) of a normal
/// `@differentiable` function value, together with the differentiability
/// indices needed to compute that component's SIL type.
class DifferentiableFuncFieldInfo final
    : public RecordField<DifferentiableFuncFieldInfo> {
public:
  DifferentiableFuncFieldInfo(
      NormalDifferentiableFunctionTypeComponent component, const TypeInfo &type,
      IndexSubset *parameterIndices, IndexSubset *resultIndices)
      : RecordField(type), component(component),
        parameterIndices(parameterIndices), resultIndices(resultIndices) {}

  /// Which component of the bundle this field stores (a component kind,
  /// not a numeric field index).
  const NormalDifferentiableFunctionTypeComponent component;
  /// Differentiability parameter indices, used to derive JVP/VJP types.
  IndexSubset *parameterIndices;
  /// Differentiability result indices, used to derive JVP/VJP types.
  IndexSubset *resultIndices;

  /// Returns this field's name: "original", "jvp" or "vjp".
  std::string getFieldName() const {
    switch (component) {
    case NormalDifferentiableFunctionTypeComponent::Original: return "original";
    case NormalDifferentiableFunctionTypeComponent::JVP:      return "jvp";
    case NormalDifferentiableFunctionTypeComponent::VJP:      return "vjp";
    }
    llvm_unreachable("invalid component type");
  }

  /// Computes the SIL type of this field inside the `@differentiable`
  /// function type `t`.
  ///
  /// The original component gets `t` with differentiability stripped. The
  /// JVP/VJP components get the matching autodiff derivative function type.
  SILType getType(IRGenModule &IGM, SILType t) const {
    auto originalFnTy =
        t.castTo<SILFunctionType>()->getWithoutDifferentiability();
    switch (component) {
    case NormalDifferentiableFunctionTypeComponent::Original:
      return SILType::getPrimitiveObjectType(originalFnTy);
    case NormalDifferentiableFunctionTypeComponent::JVP:
    case NormalDifferentiableFunctionTypeComponent::VJP:
      break;
    }
    // Any non-original component is a derivative function.
    auto derivativeKind = *component.getAsDerivativeFunctionKind();
    auto derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
        parameterIndices, resultIndices, derivativeKind, IGM.getSILTypes(),
        LookUpConformanceInModule(IGM.getSwiftModule()));
    return SILType::getPrimitiveObjectType(derivativeFnTy);
  }
};
/// Loadable type info for a normal `@differentiable` function value.
///
/// The value is laid out as a record whose fields are the function's
/// components (see `convertNormalDifferentiableFunctionType`). All field
/// offsets are fixed, so this record never has non-fixed offsets.
class DifferentiableFuncTypeInfo final
    : public RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo,
                            DifferentiableFuncFieldInfo> {
  using super = RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo,
                               DifferentiableFuncFieldInfo>;

public:
  DifferentiableFuncTypeInfo(ArrayRef<DifferentiableFuncFieldInfo> fields,
                             unsigned explosionSize, llvm::Type *ty, Size size,
                             SpareBitVector &&spareBits, Alignment align,
                             IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
      : super(fields, explosionSize, ty, size, std::move(spareBits), align,
              isPOD, alwaysFixedSize) {}

  /// Projects the address of `field` within the value stored at `addr`.
  Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
                              const DifferentiableFuncFieldInfo &field) const {
    return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
  }

  /// Unsupported: values of this type are always passed exploded.
  void initializeFromParams(IRGenFunction &IGF, Explosion &params, Address src,
                            SILType T, bool isOutlined) const override {
    llvm_unreachable("unexploded @differentiable function as argument?");
  }

  /// Adds every field to the Swift calling-convention aggregate lowering,
  /// each at `offset` plus that field's fixed byte offset.
  void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering,
                        Size offset) const override {
    for (const auto &field : getFields()) {
      auto fieldOffset = offset + field.getFixedByteOffset();
      cast<LoadableTypeInfo>(field.getTypeInfo())
          .addToAggLowering(IGM, lowering, fieldOffset);
    }
  }

  /// Builds the type layout entry for this record.
  ///
  /// Falls back to a TypeInfo-based entry unless struct type layouts are
  /// forced and every field is ABI-accessible. Otherwise it returns:
  /// - an empty entry when there are no fields,
  /// - the sole field's entry when there is one field,
  /// - an aligned group of the field entries otherwise.
  TypeLayoutEntry *buildTypeLayoutEntry(IRGenModule &IGM,
                                        SILType T) const override {
    if (!IGM.getOptions().ForceStructTypeLayouts || !areFieldsABIAccessible()) {
      return IGM.typeLayoutCache.getOrCreateTypeInfoBasedEntry(*this, T);
    }
    if (getFields().empty()) {
      return IGM.typeLayoutCache.getEmptyEntry();
    }
    std::vector<TypeLayoutEntry *> fields;
    // The number of entries is known up front; allocate once.
    fields.reserve(getFields().size());
    for (const auto &field : getFields()) {
      auto fieldTy = field.getType(IGM, T);
      fields.push_back(field.getTypeInfo().buildTypeLayoutEntry(IGM, fieldTy));
    }
    // A single field needs no group wrapper.
    if (fields.size() == 1) {
      return fields.front();
    }
    // NOTE(review): the trailing `1` is presumably the group's minimum
    // alignment; confirm against TypeLayoutCache::getOrCreateAlignedGroupEntry.
    return IGM.typeLayoutCache.getOrCreateAlignedGroupEntry(fields, 1);
  }

  /// No field has a non-fixed offset.
  llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; }
  /// No field has a non-fixed offset, whatever the concrete type `T`.
  llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const {
    return None;
  }
};
/// Record builder that lays out a normal `@differentiable` function type as a
/// loadable record of its components. Fixed and non-fixed layouts are never
/// produced.
class DifferentiableFuncTypeBuilder
    : public RecordTypeBuilder<DifferentiableFuncTypeBuilder,
                               DifferentiableFuncFieldInfo,
                               NormalDifferentiableFunctionTypeComponent> {
  /// The source function type with differentiability stripped.
  SILFunctionType *originalType;
  /// Differentiability parameter indices of the source function type.
  IndexSubset *parameterIndices;
  /// Differentiability result indices of the source function type.
  IndexSubset *resultIndices;

public:
  DifferentiableFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
      : RecordTypeBuilder(IGM),
        originalType(fnTy->getWithoutDifferentiability()),
        parameterIndices(fnTy->getDifferentiabilityParameterIndices()),
        resultIndices(fnTy->getDifferentiabilityResultIndices()) {
    // Accepted kinds: Reverse, plus (for now) Normal and Forward.
    // TODO: Ban 'Normal' and 'Forward'.
    assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Reverse ||
           fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal ||
           fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Forward);
  }

  /// Never reached: the layout is always loadable.
  TypeInfo *createFixed(ArrayRef<DifferentiableFuncFieldInfo> fields,
                        StructLayout &&layout) {
    llvm_unreachable("@differentiable functions are always loadable");
  }

  /// Creates the loadable type info from the computed struct layout.
  DifferentiableFuncTypeInfo *
  createLoadable(ArrayRef<DifferentiableFuncFieldInfo> fields,
                 StructLayout &&layout, unsigned explosionSize) {
    auto *storageTy = layout.getType();
    auto storageSize = layout.getSize();
    auto storageAlign = layout.getAlignment();
    return DifferentiableFuncTypeInfo::create(
        fields, explosionSize, storageTy, storageSize,
        std::move(layout.getSpareBits()), storageAlign, layout.isPOD(),
        layout.isAlwaysFixedSize());
  }

  /// Never reached: the layout is always loadable.
  TypeInfo *createNonFixed(ArrayRef<DifferentiableFuncFieldInfo> fields,
                           FieldsAreABIAccessible_t fieldsAccessible,
                           StructLayout &&layout) {
    llvm_unreachable("@differentiable functions are always loadable");
  }

  /// Creates the field info for `component`. The positional index is unused.
  DifferentiableFuncFieldInfo
  getFieldInfo(unsigned index,
               NormalDifferentiableFunctionTypeComponent component,
               const TypeInfo &fieldTI) {
    return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices,
                                       resultIndices);
  }

  /// Returns the SIL type stored for `component`: the original function type,
  /// or the matching autodiff derivative (JVP/VJP) function type.
  SILType getType(NormalDifferentiableFunctionTypeComponent component) {
    switch (component) {
    case NormalDifferentiableFunctionTypeComponent::Original:
      return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
    case NormalDifferentiableFunctionTypeComponent::JVP:
    case NormalDifferentiableFunctionTypeComponent::VJP:
      break;
    }
    auto derivativeKind = *component.getAsDerivativeFunctionKind();
    auto derivativeFnTy = originalType->getAutoDiffDerivativeFunctionType(
        parameterIndices, resultIndices, derivativeKind, IGM.getSILTypes(),
        LookUpConformanceInModule(IGM.getSwiftModule()));
    return SILType::getPrimitiveObjectType(derivativeFnTy);
  }

  /// Lays the fields out as a non-heap object using the universal strategy.
  StructLayout performLayout(ArrayRef<const TypeInfo *> fieldTypes) {
    return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject,
                        LayoutStrategy::Universal, fieldTypes);
  }
};
} // end anonymous namespace
//----------------------------------------------------------------------------//
// `@differentiable(_linear)` function type info
//----------------------------------------------------------------------------//
namespace {
/// Describes one stored component (original or transpose) of a
/// `@differentiable(_linear)` function value, together with the parameter
/// indices needed to compute the transpose type.
class LinearFuncFieldInfo final : public RecordField<LinearFuncFieldInfo> {
public:
  LinearFuncFieldInfo(LinearDifferentiableFunctionTypeComponent component,
                      const TypeInfo &type, IndexSubset *parameterIndices)
      : RecordField(type), component(component),
        parameterIndices(parameterIndices) {}

  /// Which component of the bundle this field stores (a component kind,
  /// not a numeric field index).
  const LinearDifferentiableFunctionTypeComponent component;
  /// Differentiability parameter indices, used to derive the transpose type.
  IndexSubset *parameterIndices;

  /// Returns this field's name: "original" or "transpose".
  std::string getFieldName() const {
    switch (component) {
    case LinearDifferentiableFunctionTypeComponent::Original:  return "original";
    case LinearDifferentiableFunctionTypeComponent::Transpose: return "transpose";
    }
    llvm_unreachable("invalid component type");
  }

  /// Computes the SIL type of this field inside the linear function type `t`.
  ///
  /// The original component gets `t` with differentiability stripped. The
  /// transpose component gets the autodiff transpose function type.
  SILType getType(IRGenModule &IGM, SILType t) const {
    auto originalFnTy =
        t.castTo<SILFunctionType>()->getWithoutDifferentiability();
    switch (component) {
    case LinearDifferentiableFunctionTypeComponent::Original:
      return SILType::getPrimitiveObjectType(originalFnTy);
    case LinearDifferentiableFunctionTypeComponent::Transpose: {
      auto transposeFnTy = originalFnTy->getAutoDiffTransposeFunctionType(
          parameterIndices, IGM.getSILTypes(),
          LookUpConformanceInModule(IGM.getSwiftModule()));
      return SILType::getPrimitiveObjectType(transposeFnTy);
    }
    }
    llvm_unreachable("invalid component type");
  }
};
/// Loadable type info for a `@differentiable(_linear)` function value.
///
/// The value is laid out as a record whose fields are the function's
/// components (see `convertLinearDifferentiableFunctionType`). All field
/// offsets are fixed, so this record never has non-fixed offsets.
class LinearFuncTypeInfo final
    : public RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo,
                            LinearFuncFieldInfo> {
  using super =
      RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo, LinearFuncFieldInfo>;

public:
  LinearFuncTypeInfo(ArrayRef<LinearFuncFieldInfo> fields,
                     unsigned explosionSize, llvm::Type *ty, Size size,
                     SpareBitVector &&spareBits, Alignment align, IsPOD_t isPOD,
                     IsFixedSize_t alwaysFixedSize)
      : super(fields, explosionSize, ty, size, std::move(spareBits), align,
              isPOD, alwaysFixedSize) {}

  /// Projects the address of `field` within the value stored at `addr`.
  Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
                              const LinearFuncFieldInfo &field) const {
    return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
  }

  /// Unsupported: values of this type are always passed exploded.
  void initializeFromParams(IRGenFunction &IGF, Explosion &params, Address src,
                            SILType T, bool isOutlined) const override {
    llvm_unreachable("unexploded @differentiable function as argument?");
  }

  /// Adds every field to the Swift calling-convention aggregate lowering,
  /// each at `offset` plus that field's fixed byte offset.
  void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering,
                        Size offset) const override {
    for (const auto &field : getFields()) {
      auto fieldOffset = offset + field.getFixedByteOffset();
      cast<LoadableTypeInfo>(field.getTypeInfo())
          .addToAggLowering(IGM, lowering, fieldOffset);
    }
  }

  /// Builds the type layout entry for this record.
  ///
  /// Falls back to a TypeInfo-based entry unless struct type layouts are
  /// forced and every field is ABI-accessible. Otherwise it returns:
  /// - an empty entry when there are no fields,
  /// - the sole field's entry when there is one field,
  /// - an aligned group of the field entries otherwise.
  TypeLayoutEntry *buildTypeLayoutEntry(IRGenModule &IGM,
                                        SILType T) const override {
    if (!IGM.getOptions().ForceStructTypeLayouts || !areFieldsABIAccessible()) {
      return IGM.typeLayoutCache.getOrCreateTypeInfoBasedEntry(*this, T);
    }
    if (getFields().empty()) {
      return IGM.typeLayoutCache.getEmptyEntry();
    }
    std::vector<TypeLayoutEntry *> fields;
    // The number of entries is known up front; allocate once.
    fields.reserve(getFields().size());
    for (const auto &field : getFields()) {
      auto fieldTy = field.getType(IGM, T);
      fields.push_back(field.getTypeInfo().buildTypeLayoutEntry(IGM, fieldTy));
    }
    // A single field needs no group wrapper.
    if (fields.size() == 1) {
      return fields.front();
    }
    // NOTE(review): the trailing `1` is presumably the group's minimum
    // alignment; confirm against TypeLayoutCache::getOrCreateAlignedGroupEntry.
    return IGM.typeLayoutCache.getOrCreateAlignedGroupEntry(fields, 1);
  }

  /// No field has a non-fixed offset.
  llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; }
  /// No field has a non-fixed offset, whatever the concrete type `T`.
  llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const {
    return None;
  }
};
/// Record builder that lays out a `@differentiable(_linear)` function type as
/// a loadable record of its components. Fixed and non-fixed layouts are never
/// produced.
class LinearFuncTypeBuilder
    : public RecordTypeBuilder<LinearFuncTypeBuilder, LinearFuncFieldInfo,
                               LinearDifferentiableFunctionTypeComponent> {
  /// The source function type with differentiability stripped.
  SILFunctionType *originalType;
  /// Differentiability parameter indices of the source function type.
  IndexSubset *parameterIndices;

public:
  LinearFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
      : RecordTypeBuilder(IGM),
        originalType(fnTy->getWithoutDifferentiability()),
        parameterIndices(fnTy->getDifferentiabilityParameterIndices()) {
    // Only linear function types are handled by this builder.
    assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear);
  }

  /// Never reached: the layout is always loadable.
  TypeInfo *createFixed(ArrayRef<LinearFuncFieldInfo> fields,
                        StructLayout &&layout) {
    llvm_unreachable("@differentiable functions are always loadable");
  }

  /// Creates the loadable type info from the computed struct layout.
  LinearFuncTypeInfo *createLoadable(ArrayRef<LinearFuncFieldInfo> fields,
                                     StructLayout &&layout,
                                     unsigned explosionSize) {
    auto *storageTy = layout.getType();
    auto storageSize = layout.getSize();
    auto storageAlign = layout.getAlignment();
    return LinearFuncTypeInfo::create(
        fields, explosionSize, storageTy, storageSize,
        std::move(layout.getSpareBits()), storageAlign, layout.isPOD(),
        layout.isAlwaysFixedSize());
  }

  /// Never reached: the layout is always loadable.
  TypeInfo *createNonFixed(ArrayRef<LinearFuncFieldInfo> fields,
                           FieldsAreABIAccessible_t fieldsAccessible,
                           StructLayout &&layout) {
    llvm_unreachable("@differentiable functions are always loadable");
  }

  /// Creates the field info for `field`. The positional index is unused.
  LinearFuncFieldInfo
  getFieldInfo(unsigned index, LinearDifferentiableFunctionTypeComponent field,
               const TypeInfo &fieldTI) {
    return LinearFuncFieldInfo(field, fieldTI, parameterIndices);
  }

  /// Returns the SIL type stored for `component`: the original function type,
  /// or the autodiff transpose function type.
  SILType getType(LinearDifferentiableFunctionTypeComponent component) {
    switch (component) {
    case LinearDifferentiableFunctionTypeComponent::Original:
      return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
    case LinearDifferentiableFunctionTypeComponent::Transpose: {
      auto transposeFnTy = originalType->getAutoDiffTransposeFunctionType(
          parameterIndices, IGM.getSILTypes(),
          LookUpConformanceInModule(IGM.getSwiftModule()));
      return SILType::getPrimitiveObjectType(transposeFnTy);
    }
    }
    llvm_unreachable("invalid component type");
  }

  /// Lays the fields out as a non-heap object using the universal strategy.
  StructLayout performLayout(ArrayRef<const TypeInfo *> fieldTypes) {
    return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject,
                        LayoutStrategy::Universal, fieldTypes);
  }
};
} // end anonymous namespace
//----------------------------------------------------------------------------//
// Type converter entry points
//----------------------------------------------------------------------------//
/// Converts a normal `@differentiable` function type into a record laid out
/// as [original, JVP, VJP], in that order.
const TypeInfo *
TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) {
  using Component = NormalDifferentiableFunctionTypeComponent;
  DifferentiableFuncTypeBuilder builder(IGM, type);
  return builder.layout(
      {Component::Original, Component::JVP, Component::VJP});
}
/// Converts a `@differentiable(_linear)` function type into a record laid out
/// as [original, transpose], in that order.
const TypeInfo *
TypeConverter::convertLinearDifferentiableFunctionType(SILFunctionType *type) {
  using Component = LinearDifferentiableFunctionTypeComponent;
  LinearFuncTypeBuilder builder(IGM, type);
  return builder.layout({Component::Original, Component::Transpose});
}