mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
[AutoDiff upstream] [SIL] Add differentiable function instructions.
Add `differentiable_function` and `differentiable_function_extract` instructions. `differentiable_function` creates a `@differentiable` function-typed value from an original function operand and derivative function operands (optional). `differentiable_function_extract` extracts either the original or derivative function value from a `@differentiable` function. The differentiation transform canonicalizes `differentiable_function` instructions, filling in derivative function operands if missing. Resolves TF-1139 and TF-1140.
This commit is contained in:
@@ -2152,6 +2152,38 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) {
|
||||
|
||||
break;
|
||||
}
|
||||
case SILInstructionKind::DifferentiableFunctionInst: {
|
||||
auto *dfi = cast<DifferentiableFunctionInst>(&SI);
|
||||
SmallVector<ValueID, 4> trailingInfo;
|
||||
auto *paramIndices = dfi->getParameterIndices();
|
||||
for (unsigned idx : paramIndices->getIndices())
|
||||
trailingInfo.push_back(idx);
|
||||
for (auto &op : dfi->getAllOperands()) {
|
||||
auto val = op.get();
|
||||
trailingInfo.push_back(S.addTypeRef(val->getType().getASTType()));
|
||||
trailingInfo.push_back((unsigned)val->getType().getCategory());
|
||||
trailingInfo.push_back(addValueRef(val));
|
||||
}
|
||||
SILInstDifferentiableFunctionLayout::emitRecord(
|
||||
Out, ScratchRecord,
|
||||
SILAbbrCodes[SILInstDifferentiableFunctionLayout::Code],
|
||||
paramIndices->getCapacity(), dfi->hasDerivativeFunctions(),
|
||||
trailingInfo);
|
||||
break;
|
||||
}
|
||||
case SILInstructionKind::DifferentiableFunctionExtractInst: {
|
||||
auto *dfei = cast<DifferentiableFunctionExtractInst>(&SI);
|
||||
auto operandRef = addValueRef(dfei->getOperand());
|
||||
auto operandType = dfei->getOperand()->getType();
|
||||
auto operandTypeRef = S.addTypeRef(operandType.getASTType());
|
||||
auto rawExtractee = (unsigned)dfei->getExtractee();
|
||||
SILInstDifferentiableFunctionExtractLayout::emitRecord(
|
||||
Out, ScratchRecord,
|
||||
SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code],
|
||||
operandTypeRef, (unsigned)operandType.getCategory(), operandRef,
|
||||
rawExtractee, (unsigned)dfei->hasExplicitExtracteeType());
|
||||
break;
|
||||
}
|
||||
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
|
||||
auto *dwfi = cast<DifferentiabilityWitnessFunctionInst>(&SI);
|
||||
auto *witness = dwfi->getWitness();
|
||||
@@ -2541,6 +2573,10 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
|
||||
registerSILAbbr<SILInstNoOperandLayout>();
|
||||
registerSILAbbr<SILOneOperandLayout>();
|
||||
registerSILAbbr<SILTwoOperandsLayout>();
|
||||
registerSILAbbr<SILInstWitnessMethodLayout>();
|
||||
registerSILAbbr<SILSpecializeAttrLayout>();
|
||||
registerSILAbbr<SILInstDifferentiableFunctionLayout>();
|
||||
registerSILAbbr<SILInstDifferentiableFunctionExtractLayout>();
|
||||
|
||||
registerSILAbbr<VTableLayout>();
|
||||
registerSILAbbr<VTableEntryLayout>();
|
||||
@@ -2556,9 +2592,6 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
|
||||
registerSILAbbr<PropertyLayout>();
|
||||
registerSILAbbr<DifferentiabilityWitnessLayout>();
|
||||
|
||||
registerSILAbbr<SILInstWitnessMethodLayout>();
|
||||
registerSILAbbr<SILSpecializeAttrLayout>();
|
||||
|
||||
// Register the abbreviation codes so these layouts can exist in both
|
||||
// decl blocks and sil blocks.
|
||||
registerSILAbbr<decls_block::AbstractProtocolConformanceLayout>();
|
||||
|
||||
Reference in New Issue
Block a user