diff --git a/docs/SIL.rst b/docs/SIL.rst index 29387e8819c..fb2f3297e4e 100644 --- a/docs/SIL.rst +++ b/docs/SIL.rst @@ -5666,6 +5666,42 @@ destination (if it returns with ``throw``). The rules on generic substitutions are identical to those of ``apply``. +Differentiable Programming +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +differentiability_witness_function +`````````````````````````````````` +:: + + sil-instruction ::= + 'differentiability_witness_function' + '[' sil-differentiability-witness-function-kind ']' + '[' 'parameters' sil-differentiability-witness-function-index-list ']' + '[' 'results' sil-differentiability-witness-function-index-list ']' + generic-parameter-clause? + sil-function-name ':' sil-type + + sil-differentiability-witness-function-kind ::= 'jvp' | 'vjp' | 'transpose' + sil-differentiability-witness-function-index-list ::= [0-9]+ (' ' [0-9]+)* + + differentiability_witness_function [jvp] [parameters 0] [results 0] \ + @foo : $(T) -> T + +Looks up a differentiability witness function (JVP, VJP, or transpose) for +a referenced function via SIL differentiability witnesses. + +The differentiability witness function kind identifies the witness function to +look up: ``[jvp]``, ``[vjp]``, or ``[transpose]``. + +The remaining components identify the SIL differentiability witness: + +- Original function name. +- Parameter indices. +- Result indices. +- Witness generic parameter clause (optional). When parsing SIL, the parsed + witness generic parameter clause is combined with the original function's + generic signature to form the full witness generic signature. + Assertion configuration ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index b07722194f8..5ff36652e64 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -73,6 +73,27 @@ struct AutoDiffDerivativeFunctionKind { } }; +/// The kind of a differentiability witness function. +struct DifferentiabilityWitnessFunctionKind { + enum innerty : uint8_t { + // The Jacobian-vector products function. + JVP = 0, + // The vector-Jacobian products function. + VJP = 1, + // The transpose function. + Transpose = 2 + } rawValue; + + DifferentiabilityWitnessFunctionKind() = default; + DifferentiabilityWitnessFunctionKind(innerty rawValue) : rawValue(rawValue) {} + explicit DifferentiabilityWitnessFunctionKind(unsigned rawValue) + : rawValue(static_cast(rawValue)) {} + explicit DifferentiabilityWitnessFunctionKind(StringRef name); + operator innerty() const { return rawValue; } + + Optional getAsDerivativeFunctionKind() const; +}; + /// Identifies an autodiff derivative function configuration: /// - Parameter indices. /// - Result indices. diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index 37d2af115da..32439b2587b 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1609,6 +1609,13 @@ ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken, "expected the index of a parameter to differentiate with respect to", ()) ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken, "expected the index of a result to differentiate from", ()) +ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken, + "expected a differentiability witness kind, which can be one of '[jvp]', " + "'[vjp]', or '[transpose]'", ()) +ERROR(sil_inst_autodiff_invalid_witness_generic_signature,PointsToFirstBadToken, + "expected witness_generic signature '%0' does not have same generic " + "parameters as original function generic signature '%1'", + (StringRef, StringRef)) //------------------------------------------------------------------------------ // MARK: Generics parsing diagnostics diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index f279acb259b..6cd843ad093 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -2157,6 +2157,20 @@ public: SILValue emitThickToObjCMetatype(SILLocation Loc, SILValue Op, SILType Ty); SILValue emitObjCToThickMetatype(SILLocation Loc, SILValue Op, SILType Ty); + //===--------------------------------------------------------------------===// + // Differentiable programming instructions + //===--------------------------------------------------------------------===// + + /// Note: explicit function type may be specified only in lowered SIL. + DifferentiabilityWitnessFunctionInst *createDifferentiabilityWitnessFunction( + SILLocation Loc, DifferentiabilityWitnessFunctionKind WitnessKind, + SILDifferentiabilityWitness *Witness, + Optional FunctionType = None) { + return insert(new (getModule()) DifferentiabilityWitnessFunctionInst( + getModule(), getSILDebugLocation(Loc), WitnessKind, Witness, + FunctionType)); + } + //===--------------------------------------------------------------------===// // Private Helper Methods //===--------------------------------------------------------------------===// diff --git a/include/swift/SIL/SILCloner.h b/include/swift/SIL/SILCloner.h index 8563c2dcbe3..46ead94acde 100644 --- a/include/swift/SIL/SILCloner.h +++ b/include/swift/SIL/SILCloner.h @@ -2825,6 +2825,16 @@ void SILCloner::visitKeyPathInst(KeyPathInst *Inst) { opValues, getOpType(Inst->getType()))); } +template +void SILCloner::visitDifferentiabilityWitnessFunctionInst( + DifferentiabilityWitnessFunctionInst *Inst) { + getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope())); + recordClonedInstruction(Inst, + getBuilder().createDifferentiabilityWitnessFunction( + getOpLocation(Inst->getLoc()), + Inst->getWitnessKind(), Inst->getWitness())); +} + } // end namespace swift #endif diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 33a03dd7267..8b0cb64cbc4 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -17,6 +17,7 @@ #ifndef SWIFT_SIL_INSTRUCTION_H #define SWIFT_SIL_INSTRUCTION_H +#include "swift/AST/AutoDiff.h" #include "swift/AST/Builtins.h" #include "swift/AST/Decl.h" #include "swift/AST/GenericSignature.h" @@ -61,6 +62,7 @@ class SILBasicBlock; class SILBuilder; class SILDebugLocation; class SILDebugScope; +class SILDifferentiabilityWitness; class SILFunction; class SILGlobalVariable; class SILInstructionResultArray; @@ -7931,6 +7933,40 @@ class TryApplyInst final const GenericSpecializationInformation *SpecializationInfo); }; +class DifferentiabilityWitnessFunctionInst + : public InstructionBase< + SILInstructionKind::DifferentiabilityWitnessFunctionInst, + SingleValueInstruction> { +private: + friend SILBuilder; + /// The differentiability witness function kind. + DifferentiabilityWitnessFunctionKind witnessKind; + /// The referenced SIL differentiability witness. + SILDifferentiabilityWitness *witness; + /// Whether the instruction has an explicit function type. + bool hasExplicitFunctionType; + + static SILType getDifferentiabilityWitnessType( + SILModule &module, DifferentiabilityWitnessFunctionKind witnessKind, + SILDifferentiabilityWitness *witness); + +public: + /// Note: explicit function type may be specified only in lowered SIL. + DifferentiabilityWitnessFunctionInst( + SILModule &module, SILDebugLocation loc, + DifferentiabilityWitnessFunctionKind witnessKind, + SILDifferentiabilityWitness *witness, Optional FunctionType); + + DifferentiabilityWitnessFunctionKind getWitnessKind() const { + return witnessKind; + } + SILDifferentiabilityWitness *getWitness() const { return witness; } + bool getHasExplicitFunctionType() const { return hasExplicitFunctionType; } + + ArrayRef getAllOperands() const { return {}; } + MutableArrayRef getAllOperands() { return {}; } +}; + // This is defined out of line to work around the fact that this depends on // PartialApplyInst being defined, but PartialApplyInst is a subclass of // ApplyInstBase, so we can not place ApplyInstBase after it. diff --git a/include/swift/SIL/SILNodes.def b/include/swift/SIL/SILNodes.def index b49c8abea53..d00771420d2 100644 --- a/include/swift/SIL/SILNodes.def +++ b/include/swift/SIL/SILNodes.def @@ -688,6 +688,11 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction) SINGLE_VALUE_INST(InitBlockStorageHeaderInst, init_block_storage_header, SingleValueInstruction, None, DoesNotRelease) + // Differentiable programming + SINGLE_VALUE_INST(DifferentiabilityWitnessFunctionInst, + differentiability_witness_function, + SingleValueInstruction, None, DoesNotRelease) + // Key paths // TODO: The only "side effect" is potentially retaining the returned key path // object; is there a more specific effect? diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 52c1d7353c8..773faed41e9 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -2,7 +2,7 @@ // // This source file is part of the Swift.org open source project // -// Copyright (c) 2019 Apple Inc. and the Swift project authors +// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information @@ -12,11 +12,34 @@ #include "swift/AST/AutoDiff.h" #include "swift/AST/ASTContext.h" +#include "swift/AST/Module.h" #include "swift/AST/TypeCheckRequests.h" #include "swift/AST/Types.h" using namespace swift; +DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind( + StringRef string) { + Optional result = llvm::StringSwitch>(string) + .Case("jvp", JVP) + .Case("vjp", VJP) + .Case("transpose", Transpose); + assert(result && "Invalid string"); + rawValue = *result; +} + +Optional +DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const { + switch (rawValue) { + case JVP: + return {AutoDiffDerivativeFunctionKind::JVP}; + case VJP: + return {AutoDiffDerivativeFunctionKind::VJP}; + case Transpose: + return None; + } +} + void AutoDiffConfig::print(llvm::raw_ostream &s) const { s << "(parameters="; parameterIndices->print(s); diff --git a/lib/IRGen/IRGenSIL.cpp b/lib/IRGen/IRGenSIL.cpp index 6ee03d00f49..5b677fc2bc7 100644 --- a/lib/IRGen/IRGenSIL.cpp +++ b/lib/IRGen/IRGenSIL.cpp @@ -1042,6 +1042,9 @@ public: void visitKeyPathInst(KeyPathInst *I); + void visitDifferentiabilityWitnessFunctionInst( + DifferentiabilityWitnessFunctionInst *i); + #define LOADABLE_REF_STORAGE_HELPER(Name) \ void visitRefTo##Name##Inst(RefTo##Name##Inst *i); \ void visit##Name##ToRefInst(Name##ToRefInst *i); \ @@ -1809,6 +1812,33 @@ void IRGenSILFunction::visitSILBasicBlock(SILBasicBlock *BB) { assert(Builder.hasPostTerminatorIP() && "SIL bb did not terminate block?!"); } +void IRGenSILFunction::visitDifferentiabilityWitnessFunctionInst( + DifferentiabilityWitnessFunctionInst *i) { + llvm::Value *diffWitness = + IGM.getAddrOfDifferentiabilityWitness(i->getWitness()); + unsigned offset = 0; + switch (i->getWitnessKind()) { + case DifferentiabilityWitnessFunctionKind::JVP: + offset = 0; + break; + case DifferentiabilityWitnessFunctionKind::VJP: + offset = 1; + break; + case DifferentiabilityWitnessFunctionKind::Transpose: + llvm_unreachable("Not yet implemented"); + } + + diffWitness = Builder.CreateStructGEP(diffWitness, offset); + diffWitness = Builder.CreateLoad(diffWitness, IGM.getPointerAlignment()); + + auto fnType = cast(i->getType().getASTType()); + Signature signature = IGM.getSignature(fnType); + diffWitness = + Builder.CreateBitCast(diffWitness, signature.getType()->getPointerTo()); + + setLoweredFunctionPointer(i, FunctionPointer(diffWitness, signature)); +} + void IRGenSILFunction::visitFunctionRefBaseInst(FunctionRefBaseInst *i) { auto fn = i->getInitiallyReferencedFunction(); diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index ca3b5f3e86f..b95225dd4c5 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -5023,6 +5023,47 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B, blockType, subMap); break; } + case SILInstructionKind::DifferentiabilityWitnessFunctionInst: { + // e.g. differentiability_witness_function + // [jvp] [parameters 0 1] [results 0] + // @foo : $(T) -> T + DifferentiabilityWitnessFunctionKind witnessKind; + StringRef witnessKindNames[3] = {"jvp", "vjp", "transpose"}; + if (P.parseToken( + tok::l_square, + diag:: + sil_inst_autodiff_expected_differentiability_witness_kind) || + parseSILIdentifierSwitch( + witnessKind, witnessKindNames, + diag:: + sil_inst_autodiff_expected_differentiability_witness_kind) || + P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare, + "differentiability witness function kind")) + return true; + SourceLoc keyStartLoc = P.Tok.getLoc(); + auto configAndFn = + parseSILDifferentiabilityWitnessConfigAndFunction(P, *this, InstLoc); + if (!configAndFn) + return true; + auto config = configAndFn->first; + auto originalFn = configAndFn->second; + auto *witness = SILMod.lookUpDifferentiabilityWitness( + {originalFn->getName(), config}); + if (!witness) { + P.diagnose(keyStartLoc, diag::sil_diff_witness_undefined); + return true; + } + // Parse an optional explicit function type. + Optional functionType = None; + if (P.consumeIf(tok::kw_as)) { + functionType = SILType(); + if (parseSILType(*functionType)) + return true; + } + ResultVal = B.createDifferentiabilityWitnessFunction( + InstLoc, witnessKind, witness, functionType); + break; + } } return false; diff --git a/lib/SIL/OperandOwnership.cpp b/lib/SIL/OperandOwnership.cpp index 9be5f357ddd..86a4d2b3f9c 100644 --- a/lib/SIL/OperandOwnership.cpp +++ b/lib/SIL/OperandOwnership.cpp @@ -116,6 +116,7 @@ SHOULD_NEVER_VISIT_INST(AllocBox) SHOULD_NEVER_VISIT_INST(AllocExistentialBox) SHOULD_NEVER_VISIT_INST(AllocGlobal) SHOULD_NEVER_VISIT_INST(AllocStack) +SHOULD_NEVER_VISIT_INST(DifferentiabilityWitnessFunction) SHOULD_NEVER_VISIT_INST(FloatLiteral) SHOULD_NEVER_VISIT_INST(FunctionRef) SHOULD_NEVER_VISIT_INST(DynamicFunctionRef) diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 27439f486b2..c0f2e149df1 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -605,6 +605,50 @@ TryApplyInst *TryApplyInst::create( normalBB, errorBB, specializationInfo); } +SILType DifferentiabilityWitnessFunctionInst::getDifferentiabilityWitnessType( + SILModule &module, DifferentiabilityWitnessFunctionKind witnessKind, + SILDifferentiabilityWitness *witness) { + auto fnTy = witness->getOriginalFunction()->getLoweredFunctionType(); + CanGenericSignature witnessCanGenSig; + if (auto witnessGenSig = witness->getDerivativeGenericSignature()) + witnessCanGenSig = witnessGenSig->getCanonicalSignature(); + auto *parameterIndices = witness->getParameterIndices(); + auto *resultIndices = witness->getResultIndices(); + if (auto derivativeKind = witnessKind.getAsDerivativeFunctionKind()) { + bool isReabstractionThunk = + witness->getOriginalFunction()->isThunk() == IsReabstractionThunk; + auto diffFnTy = fnTy->getAutoDiffDerivativeFunctionType( + parameterIndices, *resultIndices->begin(), *derivativeKind, + module.Types, LookUpConformanceInModule(module.getSwiftModule()), + witnessCanGenSig, isReabstractionThunk); + return SILType::getPrimitiveObjectType(diffFnTy); + } + assert(witnessKind == DifferentiabilityWitnessFunctionKind::Transpose); + auto transposeFnTy = fnTy->getAutoDiffTransposeFunctionType( + parameterIndices, module.Types, + LookUpConformanceInModule(module.getSwiftModule()), witnessCanGenSig); + return SILType::getPrimitiveObjectType(transposeFnTy); +} + +DifferentiabilityWitnessFunctionInst::DifferentiabilityWitnessFunctionInst( + SILModule &module, SILDebugLocation debugLoc, + DifferentiabilityWitnessFunctionKind witnessKind, + SILDifferentiabilityWitness *witness, Optional functionType) + : InstructionBase(debugLoc, functionType + ? *functionType + : getDifferentiabilityWitnessType( + module, witnessKind, witness)), + witnessKind(witnessKind), witness(witness), + hasExplicitFunctionType(functionType) { + assert(witness && "Differentiability witness must not be null"); +#ifndef NDEBUG + if (functionType.hasValue()) { + assert(module.getStage() == SILStage::Lowered && + "Explicit type is valid only in lowered SIL"); + } +#endif +} + FunctionRefBaseInst::FunctionRefBaseInst(SILInstructionKind Kind, SILDebugLocation DebugLoc, SILFunction *F, diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 52303659e5d..921af1844b0 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -2255,6 +2255,40 @@ public: } } } + + void visitDifferentiabilityWitnessFunctionInst( + DifferentiabilityWitnessFunctionInst *dwfi) { + auto *witness = dwfi->getWitness(); + *this << '['; + switch (dwfi->getWitnessKind()) { + case DifferentiabilityWitnessFunctionKind::JVP: + *this << "jvp"; + break; + case DifferentiabilityWitnessFunctionKind::VJP: + *this << "vjp"; + break; + case DifferentiabilityWitnessFunctionKind::Transpose: + *this << "transpose"; + break; + } + *this << "] [parameters"; + for (auto i : witness->getParameterIndices()->getIndices()) + *this << ' ' << i; + *this << "] [results"; + for (auto i : witness->getResultIndices()->getIndices()) + *this << ' ' << i; + *this << "] "; + if (auto witnessGenSig = witness->getDerivativeGenericSignature()) { + auto subPrinter = PrintOptions::printSIL(); + witnessGenSig->print(PrintState.OS, subPrinter); + *this << " "; + } + printSILFunctionNameAndType(PrintState.OS, witness->getOriginalFunction()); + if (dwfi->getHasExplicitFunctionType()) { + *this << " as "; + *this << dwfi->getType(); + } + } }; } // end anonymous namespace diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index 2cfc03caf56..4f6436da271 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -4534,6 +4534,28 @@ public: "unknown verfication type"); } + void checkDifferentiabilityWitnessFunctionInst( + DifferentiabilityWitnessFunctionInst *dwfi) { + auto witnessFnTy = dwfi->getType().castTo(); + auto *witness = dwfi->getWitness(); + // `DifferentiabilityWitnessFunctionInst` constructor asserts that + // `witness` is non-null. + auto witnessKind = dwfi->getWitnessKind(); + // Return if not witnessing a derivative function. + auto derivKind = witnessKind.getAsDerivativeFunctionKind(); + if (!derivKind) + return; + // Return if witness does not define the referenced derivative. + auto *derivativeFn = witness->getDerivative(*derivKind); + if (!derivativeFn) + return; + auto derivativeFnTy = derivativeFn->getLoweredFunctionType(); + requireSameType(SILType::getPrimitiveObjectType(witnessFnTy), + SILType::getPrimitiveObjectType(derivativeFnTy), + "Type of witness instruction does not match actual type of " + "witnessed function"); + } + // This verifies that the entry block of a SIL function doesn't have // any predecessors and also verifies the entry point arguments. void verifyEntryBlock(SILBasicBlock *entry) { diff --git a/lib/SIL/ValueOwnership.cpp b/lib/SIL/ValueOwnership.cpp index 0cfea82a320..49fa081e8dd 100644 --- a/lib/SIL/ValueOwnership.cpp +++ b/lib/SIL/ValueOwnership.cpp @@ -148,6 +148,7 @@ CONSTANT_OWNERSHIP_INST(None, WitnessMethod) CONSTANT_OWNERSHIP_INST(None, StoreBorrow) CONSTANT_OWNERSHIP_INST(None, ConvertEscapeToNoEscape) CONSTANT_OWNERSHIP_INST(Unowned, InitBlockStorageHeader) +CONSTANT_OWNERSHIP_INST(None, DifferentiabilityWitnessFunction) // TODO: It would be great to get rid of these. CONSTANT_OWNERSHIP_INST(Unowned, RawPointerToRef) CONSTANT_OWNERSHIP_INST(Unowned, ObjCProtocol) diff --git a/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp b/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp index e226949cc47..3b6bd7da1d0 100644 --- a/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp +++ b/lib/SILOptimizer/UtilityPasses/SerializeSILPass.cpp @@ -327,6 +327,7 @@ static bool hasOpaqueArchetype(TypeExpansionContext context, case SILInstructionKind::CondFailInst: case SILInstructionKind::DestructureStructInst: case SILInstructionKind::DestructureTupleInst: + case SILInstructionKind::DifferentiabilityWitnessFunctionInst: // Handle by operand and result check. break; diff --git a/lib/SILOptimizer/Utils/SILInliner.cpp b/lib/SILOptimizer/Utils/SILInliner.cpp index 5483e099d4d..cf760936987 100644 --- a/lib/SILOptimizer/Utils/SILInliner.cpp +++ b/lib/SILOptimizer/Utils/SILInliner.cpp @@ -875,6 +875,7 @@ InlineCost swift::instructionInlineCost(SILInstruction &I) { case SILInstructionKind::SelectValueInst: case SILInstructionKind::KeyPathInst: case SILInstructionKind::GlobalValueInst: + case SILInstructionKind::DifferentiabilityWitnessFunctionInst: #define COMMON_ALWAYS_OR_SOMETIMES_LOADABLE_CHECKED_REF_STORAGE(Name) \ case SILInstructionKind::Name##ToRefInst: \ case SILInstructionKind::RefTo##Name##Inst: \ diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index b30ff8fe71e..bdd9caab225 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -2570,6 +2570,19 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, ResultVal = Builder.createKeyPath(Loc, pattern, subMap, operands, kpTy); break; } + case SILInstructionKind::DifferentiabilityWitnessFunctionInst: { + StringRef mangledKey = MF->getIdentifierText(ValID); + auto *witness = getSILDifferentiabilityWitnessForReference(mangledKey); + assert(witness && "SILDifferentiabilityWitness not found"); + DifferentiabilityWitnessFunctionKind witnessKind(Attr); + Optional explicitFnTy = None; + auto astTy = MF->getType(TyID); + if (TyID) + explicitFnTy = getSILType(astTy, SILValueCategory::Object, Fn); + ResultVal = Builder.createDifferentiabilityWitnessFunction( + Loc, witnessKind, witness, explicitFnTy); + break; + } } for (auto result : ResultVal->getResults()) { diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 0eee80debde..8b0ef1d18e1 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 539; // swift master-rebranch +const uint16_t SWIFTMODULE_VERSION_MINOR = 540; // differentiability_witness_function instruction /// A standard hash seed used for all string hashes in a serialized module. /// diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index a45b9392d36..d89a46e5040 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -2153,6 +2153,24 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { break; } + case SILInstructionKind::DifferentiabilityWitnessFunctionInst: { + auto *dwfi = cast(&SI); + auto *witness = dwfi->getWitness(); + DifferentiabilityWitnessesToEmit.insert(witness); + Mangle::ASTMangler mangler; + auto mangledKey = + mangler.mangleSILDifferentiabilityWitnessKey(witness->getKey()); + auto rawWitnessKind = (unsigned)dwfi->getWitnessKind(); + // We only store the type when the instruction has an explicit type. + bool hasExplicitFnTy = dwfi->getHasExplicitFunctionType(); + SILOneOperandLayout::emitRecord( + Out, ScratchRecord, SILAbbrCodes[SILOneOperandLayout::Code], + (unsigned)dwfi->getKind(), rawWitnessKind, + hasExplicitFnTy ? S.addTypeRef(dwfi->getType().getASTType()) : TypeID(), + hasExplicitFnTy ? (unsigned)dwfi->getType().getCategory() : 0, + S.addUniquedStringRef(mangledKey)); + break; + } } // Non-void values get registered in the value table. for (auto result : SI.getResults()) { diff --git a/test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst.sil b/test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst.sil new file mode 100644 index 00000000000..fd98b270a2f --- /dev/null +++ b/test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst.sil @@ -0,0 +1,116 @@ +// Round-trip parsing/printing test. + +// RUN: %target-sil-opt %s | %FileCheck %s + +// Round-trip serialization-deserialization test. + +// RUN: %empty-directory(%t) +// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name main +// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.sil -module-name main +// NOTE(SR-12090): Workaround because import declarations are not preserved in .sib files. +// RUN: sed -e 's/import Swift$/import Swift; import _Differentiation/' %t/tmp.sil > %t/tmp_fixed.sil +// RUN: %target-sil-opt %t/tmp_fixed.sil -module-name main -emit-sorted-sil | %FileCheck %s + +// IRGen test. + +// RUN: %target-swift-frontend -emit-ir %s | %FileCheck %s --check-prefix=IRGEN --check-prefix %target-cpu +// NOTE: `%target-cpu`-specific FileCheck lines exist because lowered function +// types in LLVM IR differ between architectures. + +// REQUIRES: differentiable_programming +// NOTE(SR-12090): `shell` is required only to run `sed` as a SR-12090 workaround. +// REQUIRES: shell + +sil_stage raw + +import Swift +import Builtin + +import _Differentiation + +sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float + +sil_differentiability_witness [parameters 0 1] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float + +sil_differentiability_witness [parameters 0] [results 0] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) + +sil_differentiability_witness [parameters 0 1] [results 0 1] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) + +sil_differentiability_witness [parameters 0] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil_differentiability_witness [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil_differentiability_witness [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil_differentiability_witness [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil @foo : $@convention(thin) (Float, Float, Float) -> Float + +sil @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) + +sil @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil @genericreq : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil @test_derivative_witnesses : $@convention(thin) () -> () { +bb0: + %foo_jvp_wrt_0 = differentiability_witness_function [jvp] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float + %foo_vjp_wrt_0_1 = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float + + // Test multiple results. + %bar_jvp_wrt_0_results_0 = differentiability_witness_function [jvp] [parameters 0] [results 0] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %bar_vjp_wrt_0_1_results_0_1 = differentiability_witness_function [vjp] [parameters 0 1] [results 0 1] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) + + // Test generic requirements. + %generic_jvp_wrt_0 = differentiability_witness_function [jvp] [parameters 0] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + %generic_vjp_wrt_0_1 = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + + // Test "dependent" generic requirements: `T == T.TangentVector` depends on `T: Differentiable`. + %generic_vjp_wrt_0_1_dependent_req = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + + return undef : $() +} + +// CHECK-LABEL: sil @test_derivative_witnesses : $@convention(thin) () -> () { +// CHECK: bb0: +// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float +// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float +// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) +// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0 1] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) +// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 +// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 +// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 +// CHECK: } + +// IRGEN: @AD__foo_PSUURS = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT:[0-9]+]] +// IRGEN: @AD__foo_PSSURS = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @AD__bar_PSUURSU = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @AD__bar_PSSURSS = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @AD__generic_PSURS16_Differentiation14DifferentiableRzl = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @AD__generic_PSSRSs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzl = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] +// IRGEN: @AD__generic_PSSRS16_Differentiation14DifferentiableRz13TangentVectorAaBPQzRszl = external global %swift.differentiability_witness, align [[PTR_ALIGNMENT]] + +// IRGEN-LABEL: define {{.*}} @test_derivative_witnesses() + +// IRGEN: [[PTR1:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__foo_PSUURS, i32 0, i32 0), align [[PTR_ALIGNMENT]] +// IRGEN: [[FNPTR1:%.*]] = bitcast i8* [[PTR1]] to { float, i8*, %swift.refcounted* } (float, float, float)* + +// IRGEN: [[PTR2:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__foo_PSSURS, i32 0, i32 1), align [[PTR_ALIGNMENT]] +// IRGEN: [[FNPTR2:%.*]] = bitcast i8* [[PTR2]] to { float, i8*, %swift.refcounted* } (float, float, float)* + +// IRGEN: [[PTR3:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__bar_PSUURSU, i32 0, i32 0), align [[PTR_ALIGNMENT]] +// x86_64: [[FNPTR3:%.*]] = bitcast i8* [[PTR3]] to { float, float, i8*, %swift.refcounted* } (float, float, float)* +// i386: [[FNPTR3:%.*]] = bitcast i8* [[PTR3]] to void (<{ %TSf, %TSf, %swift.function }>*, float, float, float)* + +// IRGEN: [[PTR4:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__bar_PSSURSS, i32 0, i32 1), align [[PTR_ALIGNMENT]] +// x86_64: [[FNPTR4:%.*]] = bitcast i8* [[PTR4]] to { float, float, i8*, %swift.refcounted* } (float, float, float)* +// i386: [[FNPTR4:%.*]] = bitcast i8* [[PTR4]] to void (<{ %TSf, %TSf, %swift.function }>*, float, float, float)* + +// IRGEN: [[PTR5:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__generic_PSURS16_Differentiation14DifferentiableRzl, i32 0, i32 0), align [[PTR_ALIGNMENT]] +// IRGEN: [[FNPTR5:%.*]] = bitcast i8* [[PTR5]] to { i8*, %swift.refcounted* } (%swift.opaque*, %swift.opaque*, float, %swift.type*, i8**)* + +// IRGEN: [[PTR6:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__generic_PSSRSs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzl, i32 0, i32 1), align [[PTR_ALIGNMENT]] +// IRGEN: [[FNPTR6:%.*]] = bitcast i8* [[PTR6]] to { i8*, %swift.refcounted* } (%swift.opaque*, %swift.opaque*, float, %swift.type*, i8**, i8**)* + +// IRGEN: [[PTR7:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__generic_PSSRS16_Differentiation14DifferentiableRz13TangentVectorAaBPQzRszl, i32 0, i32 1), align [[PTR_ALIGNMENT]] +// IRGEN: [[FNPTR7:%.*]] = bitcast i8* [[PTR7]] to { i8*, %swift.refcounted* } (%swift.opaque*, %swift.opaque*, float, %swift.type*, i8**)* diff --git a/test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst_transpose.sil b/test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst_transpose.sil new file mode 100644 index 00000000000..dc8289d8827 --- /dev/null +++ b/test/AutoDiff/SIL/Serialization/differentiability_witness_function_inst_transpose.sil @@ -0,0 +1,81 @@ +// Note: this test is separate from `differentiability_witness_function_inst.sil` +// because `differentiability_witness_function [transpose]` instructions do not +// have IRGen support yet. + +// Round-trip parsing/printing test. + +// RUN: %target-sil-opt %s | %FileCheck %s + +// Round-trip serialization-deserialization test. + +// RUN: %empty-directory(%t) +// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name main +// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.sil -module-name main +// NOTE(SR-12090): Workaround because import declarations are not preserved in .sib files. +// RUN: sed -e 's/import Swift$/import Swift; import _Differentiation/' %t/tmp.sil > %t/tmp_fixed.sil +// RUN: %target-sil-opt %t/tmp_fixed.sil -module-name main -emit-sorted-sil | %FileCheck %s + +// REQUIRES: differentiable_programming +// NOTE(SR-12090): `shell` is required only to run `sed` as a SR-12090 workaround. +// REQUIRES: shell + +sil_stage raw + +import Swift +import Builtin + +import _Differentiation + +sil_differentiability_witness [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float + +sil_differentiability_witness [parameters 0 1] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float + +sil_differentiability_witness [parameters 0] [results 0] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) + +sil_differentiability_witness [parameters 0 1] [results 0 1] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) + +sil_differentiability_witness [parameters 0] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil_differentiability_witness [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil_differentiability_witness [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil_differentiability_witness [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil @foo : $@convention(thin) (Float, Float, Float) -> Float + +sil @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) + +sil @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil @genericreq : $@convention(thin) (@in_guaranteed T, Float) -> @out T + +sil @test_transpose_witnesses : $@convention(thin) () -> () { +bb0: + %foo_t_wrt_0 = differentiability_witness_function [transpose] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float + %foo_t_wrt_0_1 = differentiability_witness_function [transpose] [parameters 0 1] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float + + // Test multiple results. + %bar_t_wrt_0_results_0 = differentiability_witness_function [transpose] [parameters 0] [results 0] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) + %bar_t_wrt_0_1_results_0_1 = differentiability_witness_function [transpose] [parameters 0 1] [results 0 1] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) + + // Test generic requirements. + %generic_t_wrt_0 = differentiability_witness_function [transpose] [parameters 0] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + %generic_t_wrt_0_1 = differentiability_witness_function [transpose] [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + + // Test "dependent" generic requirements: `T == T.TangentVector` depends on `T: Differentiable`. + %generic_t_wrt_0_1_dependent_req = differentiability_witness_function [transpose] [parameters 0 1] [results 0] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T + return undef : $() +} + +// CHECK-LABEL: sil @test_transpose_witnesses : $@convention(thin) () -> () { +// CHECK: bb0: +// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float +// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0 1] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float +// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0] [results 0] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) +// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0 1] [results 0 1] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float) +// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 +// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 +// CHECK: return undef : $() +// CHECK: } +