mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
[AutoDiff upstream] Add differentiability_witness_function instruction. (#29765)
The `differentiability_witness_function` instruction looks up a differentiability witness function (JVP, VJP, or transpose) for a referenced function via SIL differentiability witnesses. Add round-trip parsing/serialization and IRGen tests. Notes: - Differentiability witnesses for linear functions require more support. `differentiability_witness_function [transpose]` instructions do not yet have IRGen. - Nothing currently generates `differentiability_witness_function` instructions. The differentiation transform does, but it hasn't been upstreamed yet. Resolves TF-1141.
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
//
|
||||
// This source file is part of the Swift.org open source project
|
||||
//
|
||||
// Copyright (c) 2019 Apple Inc. and the Swift project authors
|
||||
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
||||
// Licensed under Apache License v2.0 with Runtime Library Exception
|
||||
//
|
||||
// See https://swift.org/LICENSE.txt for license information
|
||||
@@ -12,11 +12,34 @@
|
||||
|
||||
#include "swift/AST/AutoDiff.h"
|
||||
#include "swift/AST/ASTContext.h"
|
||||
#include "swift/AST/Module.h"
|
||||
#include "swift/AST/TypeCheckRequests.h"
|
||||
#include "swift/AST/Types.h"
|
||||
|
||||
using namespace swift;
|
||||
|
||||
DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind(
|
||||
StringRef string) {
|
||||
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
|
||||
.Case("jvp", JVP)
|
||||
.Case("vjp", VJP)
|
||||
.Case("transpose", Transpose);
|
||||
assert(result && "Invalid string");
|
||||
rawValue = *result;
|
||||
}
|
||||
|
||||
Optional<AutoDiffDerivativeFunctionKind>
|
||||
DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
|
||||
switch (rawValue) {
|
||||
case JVP:
|
||||
return {AutoDiffDerivativeFunctionKind::JVP};
|
||||
case VJP:
|
||||
return {AutoDiffDerivativeFunctionKind::VJP};
|
||||
case Transpose:
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
void AutoDiffConfig::print(llvm::raw_ostream &s) const {
|
||||
s << "(parameters=";
|
||||
parameterIndices->print(s);
|
||||
|
||||
Reference in New Issue
Block a user