//===- GenDiffFunc.cpp - Swift IR Generation For @differentiable Functions ===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// // // This file implements IR generation for `@differentiable` function types in // Swift. // //===----------------------------------------------------------------------===// #include "swift/AST/Decl.h" #include "swift/AST/Pattern.h" #include "swift/AST/Types.h" #include "swift/SIL/SILModule.h" #include "swift/SIL/SILType.h" #include "llvm/IR/DerivedTypes.h" #include "Explosion.h" #include "GenHeap.h" #include "GenRecord.h" #include "GenType.h" #include "IRGenFunction.h" #include "IRGenModule.h" #include "IndirectTypeInfo.h" #include "NonFixedTypeInfo.h" #pragma clang diagnostic ignored "-Winconsistent-missing-override" using namespace swift; using namespace irgen; //----------------------------------------------------------------------------// // `@differentiable` (non-linear) function type info //----------------------------------------------------------------------------// namespace { class DifferentiableFuncFieldInfo final : public RecordField { public: DifferentiableFuncFieldInfo( NormalDifferentiableFunctionTypeComponent component, const TypeInfo &type, IndexSubset *parameterIndices) : RecordField(type), component(component), parameterIndices(parameterIndices) {} /// The field index. const NormalDifferentiableFunctionTypeComponent component; /// The parameter indices. IndexSubset *parameterIndices; std::string getFieldName() const { switch (component) { case NormalDifferentiableFunctionTypeComponent::Original: return "original"; case NormalDifferentiableFunctionTypeComponent::JVP: return "jvp"; case NormalDifferentiableFunctionTypeComponent::VJP: return "vjp"; } } SILType getType(IRGenModule &IGM, SILType t) const { auto fnTy = t.castTo(); auto origFnTy = fnTy->getWithoutDifferentiability(); if (component == NormalDifferentiableFunctionTypeComponent::Original) return SILType::getPrimitiveObjectType(origFnTy); auto kind = *component.getAsDerivativeFunctionKind(); auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType( parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); return SILType::getPrimitiveObjectType(assocTy); } }; class DifferentiableFuncTypeInfo final : public RecordTypeInfo { using super = RecordTypeInfo; public: DifferentiableFuncTypeInfo(ArrayRef fields, unsigned explosionSize, llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align, IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize) : super(fields, explosionSize, ty, size, std::move(spareBits), align, isPOD, alwaysFixedSize) {} Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, const DifferentiableFuncFieldInfo &field) const { return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); } void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src, SILType T, bool isOutlined) const override { llvm_unreachable("unexploded @differentiable function as argument?"); } void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering, Size offset) const override { for (auto &field : getFields()) { auto fieldOffset = offset + field.getFixedByteOffset(); cast(field.getTypeInfo()) .addToAggLowering(IGM, lowering, fieldOffset); } } TypeLayoutEntry *buildTypeLayoutEntry(IRGenModule &IGM, SILType T) const override { return IGM.typeLayoutCache.getOrCreateScalarEntry(*this, T); } llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; } llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const { return None; } }; class DifferentiableFuncTypeBuilder : public RecordTypeBuilder { SILFunctionType *originalType; IndexSubset *parameterIndices; public: DifferentiableFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) : RecordTypeBuilder(IGM), originalType(fnTy->getWithoutDifferentiability()), parameterIndices(fnTy->getDifferentiabilityParameterIndices()) { assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal); } TypeInfo *createFixed(ArrayRef fields, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } DifferentiableFuncTypeInfo * createLoadable(ArrayRef fields, StructLayout &&layout, unsigned explosionSize) { return DifferentiableFuncTypeInfo::create( fields, explosionSize, layout.getType(), layout.getSize(), std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(), layout.isAlwaysFixedSize()); } TypeInfo *createNonFixed(ArrayRef fields, FieldsAreABIAccessible_t fieldsAccessible, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } DifferentiableFuncFieldInfo getFieldInfo(unsigned index, NormalDifferentiableFunctionTypeComponent component, const TypeInfo &fieldTI) { return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices); } SILType getType(NormalDifferentiableFunctionTypeComponent component) { if (component == NormalDifferentiableFunctionTypeComponent::Original) return SILType::getPrimitiveObjectType(originalType->getCanonicalType()); auto kind = *component.getAsDerivativeFunctionKind(); auto assocTy = originalType->getAutoDiffDerivativeFunctionType( parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); return SILType::getPrimitiveObjectType(assocTy); } StructLayout performLayout(ArrayRef fieldTypes) { return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject, LayoutStrategy::Universal, fieldTypes); } }; } // end anonymous namespace //----------------------------------------------------------------------------// // `@differentiable(linear)` function type info //----------------------------------------------------------------------------// namespace { class LinearFuncFieldInfo final : public RecordField { public: LinearFuncFieldInfo(LinearDifferentiableFunctionTypeComponent component, const TypeInfo &type, IndexSubset *parameterIndices) : RecordField(type), component(component), parameterIndices(parameterIndices) {} /// The field index. const LinearDifferentiableFunctionTypeComponent component; /// The parameter indices. IndexSubset *parameterIndices; std::string getFieldName() const { switch (component) { case LinearDifferentiableFunctionTypeComponent::Original: return "original"; case LinearDifferentiableFunctionTypeComponent::Transpose: return "transpose"; } } SILType getType(IRGenModule &IGM, SILType t) const { auto fnTy = t.castTo(); auto origFnTy = fnTy->getWithoutDifferentiability(); switch (component) { case LinearDifferentiableFunctionTypeComponent::Original: return SILType::getPrimitiveObjectType(origFnTy); case LinearDifferentiableFunctionTypeComponent::Transpose: auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType( parameterIndices, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); return SILType::getPrimitiveObjectType(transposeTy); } } }; class LinearFuncTypeInfo final : public RecordTypeInfo { using super = RecordTypeInfo; public: LinearFuncTypeInfo(ArrayRef fields, unsigned explosionSize, llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align, IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize) : super(fields, explosionSize, ty, size, std::move(spareBits), align, isPOD, alwaysFixedSize) {} Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, const LinearFuncFieldInfo &field) const { return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); } void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src, SILType T, bool isOutlined) const override { llvm_unreachable("unexploded @differentiable function as argument?"); } void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering, Size offset) const override { for (auto &field : getFields()) { auto fieldOffset = offset + field.getFixedByteOffset(); cast(field.getTypeInfo()) .addToAggLowering(IGM, lowering, fieldOffset); } } TypeLayoutEntry *buildTypeLayoutEntry(IRGenModule &IGM, SILType T) const override { return IGM.typeLayoutCache.getOrCreateScalarEntry(*this, T); } llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; } llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const { return None; } }; class LinearFuncTypeBuilder : public RecordTypeBuilder { SILFunctionType *originalType; IndexSubset *parameterIndices; public: LinearFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) : RecordTypeBuilder(IGM), originalType(fnTy->getWithoutDifferentiability()), parameterIndices(fnTy->getDifferentiabilityParameterIndices()) { assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear); } TypeInfo *createFixed(ArrayRef fields, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } LinearFuncTypeInfo *createLoadable(ArrayRef fields, StructLayout &&layout, unsigned explosionSize) { return LinearFuncTypeInfo::create( fields, explosionSize, layout.getType(), layout.getSize(), std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(), layout.isAlwaysFixedSize()); } TypeInfo *createNonFixed(ArrayRef fields, FieldsAreABIAccessible_t fieldsAccessible, StructLayout &&layout) { llvm_unreachable("@differentiable functions are always loadable"); } LinearFuncFieldInfo getFieldInfo(unsigned index, LinearDifferentiableFunctionTypeComponent field, const TypeInfo &fieldTI) { return LinearFuncFieldInfo(field, fieldTI, parameterIndices); } SILType getType(LinearDifferentiableFunctionTypeComponent component) { switch (component) { case LinearDifferentiableFunctionTypeComponent::Original: return SILType::getPrimitiveObjectType(originalType->getCanonicalType()); case LinearDifferentiableFunctionTypeComponent::Transpose: auto transposeTy = originalType->getAutoDiffTransposeFunctionType( parameterIndices, IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule())); return SILType::getPrimitiveObjectType(transposeTy); } } StructLayout performLayout(ArrayRef fieldTypes) { return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject, LayoutStrategy::Universal, fieldTypes); } }; } // end anonymous namespace //----------------------------------------------------------------------------// // Type converter entry points //----------------------------------------------------------------------------// const TypeInfo * TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) { DifferentiableFuncTypeBuilder builder(IGM, type); return builder.layout({NormalDifferentiableFunctionTypeComponent::Original, NormalDifferentiableFunctionTypeComponent::JVP, NormalDifferentiableFunctionTypeComponent::VJP}); } const TypeInfo * TypeConverter::convertLinearDifferentiableFunctionType(SILFunctionType *type) { LinearFuncTypeBuilder builder(IGM, type); return builder.layout({LinearDifferentiableFunctionTypeComponent::Original, LinearDifferentiableFunctionTypeComponent::Transpose}); }