[AutoDiff upstream] [SIL] Add differentiable function instructions.

Add `differentiable_function` and `differentiable_function_extract`
instructions.

`differentiable_function` creates a `@differentiable` function-typed
value from an original function operand and derivative function operands
(optional).

`differentiable_function_extract` extracts either the original or
derivative function value from a `@differentiable` function.

The differentiation transform canonicalizes `differentiable_function`
instructions, filling in derivative function operands if missing.

Resolves TF-1139 and TF-1140.
This commit is contained in:
Dan Zheng
2020-03-22 11:50:36 -07:00
parent 603db8c954
commit cc7e9fc39e
21 changed files with 762 additions and 5 deletions

View File

@@ -2152,6 +2152,38 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) {
break;
}
case SILInstructionKind::DifferentiableFunctionInst: {
auto *dfi = cast<DifferentiableFunctionInst>(&SI);
SmallVector<ValueID, 4> trailingInfo;
auto *paramIndices = dfi->getParameterIndices();
for (unsigned idx : paramIndices->getIndices())
trailingInfo.push_back(idx);
for (auto &op : dfi->getAllOperands()) {
auto val = op.get();
trailingInfo.push_back(S.addTypeRef(val->getType().getASTType()));
trailingInfo.push_back((unsigned)val->getType().getCategory());
trailingInfo.push_back(addValueRef(val));
}
SILInstDifferentiableFunctionLayout::emitRecord(
Out, ScratchRecord,
SILAbbrCodes[SILInstDifferentiableFunctionLayout::Code],
paramIndices->getCapacity(), dfi->hasDerivativeFunctions(),
trailingInfo);
break;
}
case SILInstructionKind::DifferentiableFunctionExtractInst: {
auto *dfei = cast<DifferentiableFunctionExtractInst>(&SI);
auto operandRef = addValueRef(dfei->getOperand());
auto operandType = dfei->getOperand()->getType();
auto operandTypeRef = S.addTypeRef(operandType.getASTType());
auto rawExtractee = (unsigned)dfei->getExtractee();
SILInstDifferentiableFunctionExtractLayout::emitRecord(
Out, ScratchRecord,
SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code],
operandTypeRef, (unsigned)operandType.getCategory(), operandRef,
rawExtractee, (unsigned)dfei->hasExplicitExtracteeType());
break;
}
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
auto *dwfi = cast<DifferentiabilityWitnessFunctionInst>(&SI);
auto *witness = dwfi->getWitness();
@@ -2541,6 +2573,10 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
registerSILAbbr<SILInstNoOperandLayout>();
registerSILAbbr<SILOneOperandLayout>();
registerSILAbbr<SILTwoOperandsLayout>();
registerSILAbbr<SILInstWitnessMethodLayout>();
registerSILAbbr<SILSpecializeAttrLayout>();
registerSILAbbr<SILInstDifferentiableFunctionLayout>();
registerSILAbbr<SILInstDifferentiableFunctionExtractLayout>();
registerSILAbbr<VTableLayout>();
registerSILAbbr<VTableEntryLayout>();
@@ -2556,9 +2592,6 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
registerSILAbbr<PropertyLayout>();
registerSILAbbr<DifferentiabilityWitnessLayout>();
registerSILAbbr<SILInstWitnessMethodLayout>();
registerSILAbbr<SILSpecializeAttrLayout>();
// Register the abbreviation codes so these layouts can exist in both
// decl blocks and sil blocks.
registerSILAbbr<decls_block::AbstractProtocolConformanceLayout>();