mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
82 lines
3.7 KiB
C++
82 lines
3.7 KiB
C++
//===--- SILDifferentiabilityWitness.cpp - Differentiability witnesses ----===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define DEBUG_TYPE "sil-differentiability-witness"
|
|
|
|
#include "swift/AST/ASTMangler.h"
|
|
#include "swift/Basic/Assertions.h"
|
|
#include "swift/SIL/SILDifferentiabilityWitness.h"
|
|
#include "swift/SIL/SILModule.h"
|
|
|
|
using namespace swift;
|
|
|
|
// Creates a differentiability witness *declaration* for `originalFunction`,
// allocated through `module`, and registers it with that module. The new
// witness is constructed with null JVP/VJP functions, is flagged as a
// declaration, and is not marked serialized.
//
// Returns the newly created witness.
SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDeclaration(
    SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
    DifferentiabilityKind kind, IndexSubset *parameterIndices,
    IndexSubset *resultIndices, GenericSignature derivativeGenSig,
    const DeclAttribute *attribute) {
  auto *witness = new (module) SILDifferentiabilityWitness(
      module, linkage, originalFunction, kind, parameterIndices, resultIndices,
      derivativeGenSig,
      /*jvp*/ nullptr,
      /*vjp*/ nullptr,
      /*isDeclaration*/ true,
      /*isSerialized*/ false,
      attribute);

  // Compute the mangled name under which the module tracks this witness.
  Mangle::ASTMangler mangler(module.getASTContext());
  auto originalName = witness->getOriginalFunction()->getName();
  auto mangledKey = mangler.mangleSILDifferentiabilityWitness(
      originalName, witness->getKind(), witness->getConfig());

  // A module holds at most one witness per mangled key.
  assert(!module.DifferentiabilityWitnessMap.count(mangledKey) &&
         "Cannot create duplicate differentiability witness in a module");

  // Record the witness in each of the module's lookup structures: by mangled
  // key, by original function name, and in the overall witness list.
  module.DifferentiabilityWitnessMap[mangledKey] = witness;
  auto &witnessesForFunction =
      module.DifferentiabilityWitnessesByFunction[originalFunction->getName()];
  witnessesForFunction.push_back(witness);
  module.getDifferentiabilityWitnessList().push_back(witness);

  return witness;
}
|
|
|
|
// Creates a differentiability witness *definition* for `originalFunction`,
// allocated through `module`, and registers it with that module. The witness
// is constructed with the supplied `jvp` and `vjp` functions and the given
// `isSerialized` flag, and is flagged as not being a declaration.
//
// Returns the newly created witness.
SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition(
    SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
    DifferentiabilityKind kind, IndexSubset *parameterIndices,
    IndexSubset *resultIndices, GenericSignature derivativeGenSig,
    SILFunction *jvp, SILFunction *vjp, bool isSerialized,
    const DeclAttribute *attribute) {
  auto *witness = new (module) SILDifferentiabilityWitness(
      module, linkage, originalFunction, kind, parameterIndices, resultIndices,
      derivativeGenSig, jvp, vjp,
      /*isDeclaration*/ false,
      isSerialized,
      attribute);

  // Compute the mangled name under which the module tracks this witness.
  Mangle::ASTMangler mangler(module.getASTContext());
  auto originalName = witness->getOriginalFunction()->getName();
  auto mangledKey = mangler.mangleSILDifferentiabilityWitness(
      originalName, witness->getKind(), witness->getConfig());

  // A module holds at most one witness per mangled key.
  assert(!module.DifferentiabilityWitnessMap.count(mangledKey) &&
         "Cannot create duplicate differentiability witness in a module");

  // Record the witness in each of the module's lookup structures: by mangled
  // key, by original function name, and in the overall witness list.
  module.DifferentiabilityWitnessMap[mangledKey] = witness;
  auto &witnessesForFunction =
      module.DifferentiabilityWitnessesByFunction[originalFunction->getName()];
  witnessesForFunction.push_back(witness);
  module.getDifferentiabilityWitnessList().push_back(witness);

  return witness;
}
|
|
|
|
// Turns this witness from a declaration into a definition by installing the
// given JVP/VJP functions and serialization flag. The witness must currently
// be a declaration; this is checked by assertion.
void SILDifferentiabilityWitness::convertToDefinition(SILFunction *jvp,
                                                      SILFunction *vjp,
                                                      bool isSerialized) {
  assert(IsDeclaration);

  // Install the derivative functions and serialization state first...
  JVP = jvp;
  VJP = vjp;
  IsSerialized = isSerialized;

  // ...then clear the declaration flag.
  IsDeclaration = false;
}
|
|
|
|
// Builds the key for this witness from, in order: the original function's
// name, the differentiability kind, and the witness configuration.
SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
  const auto originalName = getOriginalFunction()->getName();
  const auto witnessKind = getKind();
  const auto &config = getConfig();
  return SILDifferentiabilityWitnessKey{originalName, witnessKind, config};
}
|