[AutoDiff upstream] Add differentiability_witness_function instruction. (#29765)

The `differentiability_witness_function` instruction looks up a
differentiability witness function (JVP, VJP, or transpose) for a referenced
function via SIL differentiability witnesses.

Add round-trip parsing/serialization and IRGen tests.

Notes:
- Differentiability witnesses for linear functions require more support.
  `differentiability_witness_function [transpose]` instructions do not yet
  have IRGen.
- Nothing currently generates `differentiability_witness_function` instructions.
  The differentiation transform does, but it hasn't been upstreamed yet.

Resolves TF-1141.
This commit is contained in:
Dan Zheng
2020-02-13 16:55:46 -08:00
committed by GitHub
parent 94983ce43c
commit a49428ca7c
22 changed files with 557 additions and 2 deletions

View File

@@ -2,7 +2,7 @@
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 Apple Inc. and the Swift project authors
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
@@ -12,11 +12,34 @@
#include "swift/AST/AutoDiff.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/Module.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"
using namespace swift;
DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind(
StringRef string) {
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
.Case("jvp", JVP)
.Case("vjp", VJP)
.Case("transpose", Transpose);
assert(result && "Invalid string");
rawValue = *result;
}
Optional<AutoDiffDerivativeFunctionKind>
DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
switch (rawValue) {
case JVP:
return {AutoDiffDerivativeFunctionKind::JVP};
case VJP:
return {AutoDiffDerivativeFunctionKind::VJP};
case Transpose:
return None;
}
}
void AutoDiffConfig::print(llvm::raw_ostream &s) const {
s << "(parameters=";
parameterIndices->print(s);