//===- GenDiffFunc.cpp - Swift IR Generation For @differentiable Functions ===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements IR generation for `@differentiable` function types in // Swift. // //===----------------------------------------------------------------------===// #include "swift/AST/Decl.h" #include "swift/AST/IRGenOptions.h" #include "swift/AST/Pattern.h" #include "swift/AST/Types.h" #include "swift/Basic/Assertions.h" #include "swift/SIL/SILModule.h" #include "swift/SIL/SILType.h" #include "llvm/IR/DerivedTypes.h" #include "Explosion.h" #include "GenHeap.h" #include "GenRecord.h" #include "GenType.h" #include "IRGenFunction.h" #include "IRGenModule.h" #include "IndirectTypeInfo.h" #include "NonFixedTypeInfo.h" #pragma clang diagnostic ignored "-Winconsistent-missing-override" using namespace swift; using namespace irgen; //----------------------------------------------------------------------------// // `@differentiable` (non-linear) function type info //----------------------------------------------------------------------------// namespace { class DifferentiableFuncFieldInfo final : public RecordField { public: DifferentiableFuncFieldInfo( NormalDifferentiableFunctionTypeComponent component, const TypeInfo &type, IndexSubset *parameterIndices, IndexSubset *resultIndices) : RecordField(type), component(component), parameterIndices(parameterIndices), resultIndices(resultIndices) {} /// The field index. const NormalDifferentiableFunctionTypeComponent component; /// The parameter indices. IndexSubset *parameterIndices; /// The result indices. IndexSubset *resultIndices; std::string getFieldName() const { switch (component) { case NormalDifferentiableFunctionTypeComponent::Original: return "original"; case NormalDifferentiableFunctionTypeComponent::JVP: return "jvp"; case NormalDifferentiableFunctionTypeComponent::VJP: return "vjp"; } llvm_unreachable("invalid component type"); } SILType getType(IRGenModule &IGM, SILType t) const { auto fnTy = t.castTo(); auto origFnTy = fnTy->getWithoutDifferentiability(); if (component == NormalDifferentiableFunctionTypeComponent::Original) return SILType::getPrimitiveObjectType(origFnTy); auto kind = *component.getAsDerivativeFunctionKind(); auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType( parameterIndices, resultIndices, kind, IGM.getSILTypes(), LookUpConformanceInModule()); return SILType::getPrimitiveObjectType(assocTy); } }; class DifferentiableFuncTypeInfo final : public RecordTypeInfo { using super = RecordTypeInfo; public: DifferentiableFuncTypeInfo(ArrayRef fields, unsigned explosionSize, llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align, IsTriviallyDestroyable_t isTriviallyDestroyable, IsFixedSize_t alwaysFixedSize) : super(fields, explosionSize, FieldsAreABIAccessible, ty, size, std::move(spareBits), align, isTriviallyDestroyable, IsCopyable, alwaysFixedSize, IsABIAccessible) {} Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, const DifferentiableFuncFieldInfo &field) const { return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); } void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src, SILType T, bool isOutlined) const override { llvm_unreachable("unexploded @differentiable function as argument?"); } void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering, Size offset) const override { for (auto &field : getFields()) { auto fieldOffset = offset + field.getFixedByteOffset(); cast(field.getTypeInfo()) .addToAggLowering(IGM, lowering, fieldOffset); } } TypeLayoutEntry *buildTypeLayoutEntry(IRGenModule &IGM, SILType T, bool useStructLayouts) const override { if (!useStructLayouts || !areFieldsABIAccessible()) { return IGM.typeLayoutCache.getOrCreateTypeInfoBasedEntry(*this, T); } if (getFields().empty()) { return IGM.typeLayoutCache.getEmptyEntry(); } std::vector fields; for (auto &field : getFields()) { auto fieldTy = field.getType(IGM, T); fields.push_back(field.getTypeInfo().buildTypeLayoutEntry(IGM, fieldTy, useStructLayouts)); } // if (fields.size() == 1) { // return fields[0]; // } return IGM.typeLayoutCache.getOrCreateAlignedGroupEntry( fields, T, getBestKnownAlignment().getValue(), *this); } std::nullopt_t getNonFixedOffsets(IRGenFunction &IGF) const { return std::nullopt; } std::nullopt_t getNonFixedOffsets(IRGenFunction &IGF, SILType T) const { return std::nullopt; } }; class DifferentiableFuncTypeBuilder : public RecordTypeBuilder { SILFunctionType *originalType; IndexSubset *parameterIndices; IndexSubset *resultIndices; public: DifferentiableFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) : RecordTypeBuilder(IGM), originalType(fnTy->getWithoutDifferentiability()), parameterIndices(fnTy->getDifferentiabilityParameterIndices()), resultIndices(fnTy->getDifferentiabilityResultIndices()) { // TODO: Ban 'Normal' and 'Forward'. assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Reverse || fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal || fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Forward); } TypeInfo *createFixed(ArrayRef fields, FieldsAreABIAccessible_t unused, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } DifferentiableFuncTypeInfo * createLoadable(ArrayRef fields, FieldsAreABIAccessible_t unused, StructLayout &&layout, unsigned explosionSize) { return DifferentiableFuncTypeInfo::create( fields, explosionSize, layout.getType(), layout.getSize(), std::move(layout.getSpareBits()), layout.getAlignment(), layout.isTriviallyDestroyable(), layout.isAlwaysFixedSize()); } TypeInfo *createNonFixed(ArrayRef fields, FieldsAreABIAccessible_t fieldsAccessible, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } DifferentiableFuncFieldInfo getFieldInfo(unsigned index, NormalDifferentiableFunctionTypeComponent component, const TypeInfo &fieldTI) { return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices, resultIndices); } SILType getType(NormalDifferentiableFunctionTypeComponent component) { if (component == NormalDifferentiableFunctionTypeComponent::Original) return SILType::getPrimitiveObjectType(originalType->getCanonicalType()); auto kind = *component.getAsDerivativeFunctionKind(); auto assocTy = originalType->getAutoDiffDerivativeFunctionType( parameterIndices, resultIndices, kind, IGM.getSILTypes(), LookUpConformanceInModule()); return SILType::getPrimitiveObjectType(assocTy); } StructLayout performLayout(ArrayRef fieldTypes) { return StructLayout(IGM, /*type=*/std::nullopt, LayoutKind::NonHeapObject, LayoutStrategy::Universal, fieldTypes); } }; } // end anonymous namespace //----------------------------------------------------------------------------// // `@differentiable(_linear)` function type info //----------------------------------------------------------------------------// namespace { class LinearFuncFieldInfo final : public RecordField { public: LinearFuncFieldInfo(LinearDifferentiableFunctionTypeComponent component, const TypeInfo &type, IndexSubset *parameterIndices) : RecordField(type), component(component), parameterIndices(parameterIndices) {} /// The field index. const LinearDifferentiableFunctionTypeComponent component; /// The parameter indices. IndexSubset *parameterIndices; std::string getFieldName() const { switch (component) { case LinearDifferentiableFunctionTypeComponent::Original: return "original"; case LinearDifferentiableFunctionTypeComponent::Transpose: return "transpose"; } llvm_unreachable("invalid component type"); } SILType getType(IRGenModule &IGM, SILType t) const { auto fnTy = t.castTo(); auto origFnTy = fnTy->getWithoutDifferentiability(); switch (component) { case LinearDifferentiableFunctionTypeComponent::Original: return SILType::getPrimitiveObjectType(origFnTy); case LinearDifferentiableFunctionTypeComponent::Transpose: auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType( parameterIndices, IGM.getSILTypes(), LookUpConformanceInModule()); return SILType::getPrimitiveObjectType(transposeTy); } llvm_unreachable("invalid component type"); } }; class LinearFuncTypeInfo final : public RecordTypeInfo { using super = RecordTypeInfo; public: LinearFuncTypeInfo(ArrayRef fields, unsigned explosionSize, llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align, IsTriviallyDestroyable_t isTriviallyDestroyable, IsFixedSize_t alwaysFixedSize) : super(fields, explosionSize, FieldsAreABIAccessible, ty, size, std::move(spareBits), align, isTriviallyDestroyable, IsCopyable, alwaysFixedSize, IsABIAccessible) {} Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, const LinearFuncFieldInfo &field) const { return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); } void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src, SILType T, bool isOutlined) const override { llvm_unreachable("unexploded @differentiable function as argument?"); } void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering, Size offset) const override { for (auto &field : getFields()) { auto fieldOffset = offset + field.getFixedByteOffset(); cast(field.getTypeInfo()) .addToAggLowering(IGM, lowering, fieldOffset); } } TypeLayoutEntry *buildTypeLayoutEntry(IRGenModule &IGM, SILType T, bool useStructLayouts) const override { if (!useStructLayouts || !areFieldsABIAccessible()) { return IGM.typeLayoutCache.getOrCreateTypeInfoBasedEntry(*this, T); } if (getFields().empty()) { return IGM.typeLayoutCache.getEmptyEntry(); } std::vector fields; for (auto &field : getFields()) { auto fieldTy = field.getType(IGM, T); fields.push_back(field.getTypeInfo().buildTypeLayoutEntry(IGM, fieldTy, useStructLayouts)); } // if (fields.size() == 1) { // return fields[0]; // } return IGM.typeLayoutCache.getOrCreateAlignedGroupEntry( fields, T, getBestKnownAlignment().getValue(), *this); } std::nullopt_t getNonFixedOffsets(IRGenFunction &IGF) const { return std::nullopt; } std::nullopt_t getNonFixedOffsets(IRGenFunction &IGF, SILType T) const { return std::nullopt; } }; class LinearFuncTypeBuilder : public RecordTypeBuilder { SILFunctionType *originalType; IndexSubset *parameterIndices; public: LinearFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) : RecordTypeBuilder(IGM), originalType(fnTy->getWithoutDifferentiability()), parameterIndices(fnTy->getDifferentiabilityParameterIndices()) { assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear); } TypeInfo *createFixed(ArrayRef fields, FieldsAreABIAccessible_t areFieldsABIAccessible, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } LinearFuncTypeInfo *createLoadable(ArrayRef fields, FieldsAreABIAccessible_t unused, StructLayout &&layout, unsigned explosionSize) { return LinearFuncTypeInfo::create( fields, explosionSize, layout.getType(), layout.getSize(), std::move(layout.getSpareBits()), layout.getAlignment(), layout.isTriviallyDestroyable(), layout.isAlwaysFixedSize()); } TypeInfo *createNonFixed(ArrayRef fields, FieldsAreABIAccessible_t fieldsAccessible, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } LinearFuncFieldInfo getFieldInfo(unsigned index, LinearDifferentiableFunctionTypeComponent field, const TypeInfo &fieldTI) { return LinearFuncFieldInfo(field, fieldTI, parameterIndices); } SILType getType(LinearDifferentiableFunctionTypeComponent component) { switch (component) { case LinearDifferentiableFunctionTypeComponent::Original: return SILType::getPrimitiveObjectType(originalType->getCanonicalType()); case LinearDifferentiableFunctionTypeComponent::Transpose: auto transposeTy = originalType->getAutoDiffTransposeFunctionType( parameterIndices, IGM.getSILTypes(), LookUpConformanceInModule()); return SILType::getPrimitiveObjectType(transposeTy); } llvm_unreachable("invalid component type"); } StructLayout performLayout(ArrayRef fieldTypes) { return StructLayout(IGM, /*type=*/std::nullopt, LayoutKind::NonHeapObject, LayoutStrategy::Universal, fieldTypes); } }; } // end anonymous namespace //----------------------------------------------------------------------------// // Type converter entry points //----------------------------------------------------------------------------// const TypeInfo * TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) { DifferentiableFuncTypeBuilder builder(IGM, type); return builder.layout({NormalDifferentiableFunctionTypeComponent::Original, NormalDifferentiableFunctionTypeComponent::JVP, NormalDifferentiableFunctionTypeComponent::VJP}); } const TypeInfo * TypeConverter::convertLinearDifferentiableFunctionType(SILFunctionType *type) { LinearFuncTypeBuilder builder(IGM, type); return builder.layout({LinearDifferentiableFunctionTypeComponent::Original, LinearDifferentiableFunctionTypeComponent::Transpose}); }