mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
204 lines
8.0 KiB
C++
204 lines
8.0 KiB
C++
//===----- DifferentiationMangler.cpp --------- differentiation -*- C++ -*-===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
|
|
#include "swift/AST/AutoDiff.h"
|
|
#include "swift/AST/GenericEnvironment.h"
|
|
#include "swift/AST/GenericSignature.h"
|
|
#include "swift/AST/SubstitutionMap.h"
|
|
#include "swift/Basic/Assertions.h"
|
|
#include "swift/Demangling/ManglingMacros.h"
|
|
#include "swift/SIL/SILGlobalVariable.h"
|
|
|
|
using namespace swift;
|
|
using namespace Mangle;
|
|
|
|
/// Mangles the generic signature and get its mangling tree. This is necessary
|
|
/// because the derivative generic signature's requirements may contain names
|
|
/// which repeat the contents of the original function name. To follow Swift's
|
|
/// mangling scheme, these repetitions must be mangled as substitutions.
|
|
/// Therefore, we build mangling trees in `DifferentiationMangler` and let the
|
|
/// remangler take care of substitutions.
|
|
static NodePointer mangleGenericSignatureAsNode(GenericSignature sig,
|
|
Demangler &demangler) {
|
|
if (!sig)
|
|
return nullptr;
|
|
ASTMangler sigMangler(sig->getASTContext());
|
|
auto mangledGenSig = sigMangler.mangleGenericSignature(sig);
|
|
auto demangledGenSig = demangler.demangleSymbol(mangledGenSig);
|
|
assert(demangledGenSig->getKind() == Node::Kind::Global);
|
|
assert(demangledGenSig->getNumChildren() == 1);
|
|
auto result = demangledGenSig->getFirstChild();
|
|
assert(result->getKind() == Node::Kind::DependentGenericSignature);
|
|
return result;
|
|
}
|
|
|
|
static NodePointer mangleAutoDiffFunctionAsNode(
|
|
StringRef originalName, Demangle::AutoDiffFunctionKind kind,
|
|
const AutoDiffConfig &config, Demangler &demangler) {
|
|
assert(isMangledName(originalName));
|
|
auto demangledOrig = demangler.demangleSymbol(originalName);
|
|
assert(demangledOrig && "Should only be called when the original "
|
|
"function has a mangled name");
|
|
assert(demangledOrig->getKind() == Node::Kind::Global);
|
|
auto derivativeGenericSignatureNode = mangleGenericSignatureAsNode(
|
|
config.derivativeGenericSignature, demangler);
|
|
auto *adFunc = demangler.createNode(Node::Kind::AutoDiffFunction);
|
|
for (auto *child : *demangledOrig)
|
|
adFunc->addChild(child, demangler);
|
|
if (derivativeGenericSignatureNode)
|
|
adFunc->addChild(derivativeGenericSignatureNode, demangler);
|
|
adFunc->addChild(
|
|
demangler.createNode(
|
|
Node::Kind::AutoDiffFunctionKind, (Node::IndexType)kind),
|
|
demangler);
|
|
adFunc->addChild(
|
|
demangler.createNode(
|
|
Node::Kind::IndexSubset, config.parameterIndices->getString()),
|
|
demangler);
|
|
adFunc->addChild(
|
|
demangler.createNode(
|
|
Node::Kind::IndexSubset, config.resultIndices->getString()),
|
|
demangler);
|
|
auto root = demangler.createNode(Node::Kind::Global);
|
|
root->addChild(adFunc, demangler);
|
|
return root;
|
|
}
|
|
|
|
std::string DifferentiationMangler::mangleAutoDiffFunction(
|
|
StringRef originalName, Demangle::AutoDiffFunctionKind kind,
|
|
const AutoDiffConfig &config) {
|
|
// If the original function is mangled, mangle the tree.
|
|
if (isMangledName(originalName)) {
|
|
Demangler demangler;
|
|
auto node = mangleAutoDiffFunctionAsNode(
|
|
originalName, kind, config, demangler);
|
|
auto mangling = Demangle::mangleNode(node, Flavor);
|
|
assert(mangling.isSuccess());
|
|
return mangling.result();
|
|
}
|
|
// Otherwise, treat the original function symbol as a black box and just
|
|
// mangle the other parts.
|
|
beginManglingWithoutPrefix();
|
|
appendOperator(originalName);
|
|
appendAutoDiffFunctionParts("TJ", kind, config);
|
|
return finalize();
|
|
}
|
|
|
|
// Returns the mangled name for a derivative function of the given kind.
|
|
std::string DifferentiationMangler::mangleDerivativeFunction(
|
|
StringRef originalName, AutoDiffDerivativeFunctionKind kind,
|
|
const AutoDiffConfig &config) {
|
|
return mangleAutoDiffFunction(
|
|
originalName, getAutoDiffFunctionKind(kind), config);
|
|
}
|
|
|
|
// Returns the mangled name for a derivative function of the given kind.
|
|
std::string DifferentiationMangler::mangleLinearMap(
|
|
StringRef originalName, AutoDiffLinearMapKind kind,
|
|
const AutoDiffConfig &config) {
|
|
return mangleAutoDiffFunction(
|
|
originalName, getAutoDiffFunctionKind(kind), config);
|
|
}
|
|
|
|
static NodePointer mangleDerivativeFunctionSubsetParametersThunkAsNode(
|
|
StringRef originalName, Type toType, Demangle::AutoDiffFunctionKind kind,
|
|
IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
|
|
IndexSubset *toParamIndices, Demangler &demangler) {
|
|
assert(isMangledName(originalName));
|
|
auto demangledOrig = demangler.demangleSymbol(originalName);
|
|
assert(demangledOrig && "Should only be called when the original "
|
|
"function has a mangled name");
|
|
assert(demangledOrig->getKind() == Node::Kind::Global);
|
|
auto *thunk = demangler.createNode(Node::Kind::AutoDiffSubsetParametersThunk);
|
|
for (auto *child : *demangledOrig)
|
|
thunk->addChild(child, demangler);
|
|
NodePointer toTypeNode = nullptr;
|
|
{
|
|
ASTMangler typeMangler(toType->getASTContext());
|
|
toTypeNode = demangler.demangleType(
|
|
typeMangler.mangleTypeWithoutPrefix(toType));
|
|
assert(toTypeNode && "Cannot demangle the to-type as node");
|
|
}
|
|
thunk->addChild(toTypeNode, demangler);
|
|
thunk->addChild(
|
|
demangler.createNode(
|
|
Node::Kind::AutoDiffFunctionKind, (Node::IndexType)kind),
|
|
demangler);
|
|
thunk->addChild(
|
|
demangler.createNode(
|
|
Node::Kind::IndexSubset, fromParamIndices->getString()),
|
|
demangler);
|
|
thunk->addChild(
|
|
demangler.createNode(
|
|
Node::Kind::IndexSubset, fromResultIndices->getString()),
|
|
demangler);
|
|
thunk->addChild(
|
|
demangler.createNode(
|
|
Node::Kind::IndexSubset, toParamIndices->getString()),
|
|
demangler);
|
|
auto root = demangler.createNode(Node::Kind::Global);
|
|
root->addChild(thunk, demangler);
|
|
return root;
|
|
}
|
|
|
|
std::string
|
|
DifferentiationMangler::mangleDerivativeFunctionSubsetParametersThunk(
|
|
StringRef originalName, CanType toType,
|
|
AutoDiffDerivativeFunctionKind linearMapKind,
|
|
IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
|
|
IndexSubset *toParamIndices) {
|
|
beginMangling();
|
|
auto kind = getAutoDiffFunctionKind(linearMapKind);
|
|
// If the original function is mangled, mangle the tree.
|
|
if (isMangledName(originalName)) {
|
|
Demangler demangler;
|
|
auto node = mangleDerivativeFunctionSubsetParametersThunkAsNode(
|
|
originalName, toType, kind, fromParamIndices, fromResultIndices,
|
|
toParamIndices, demangler);
|
|
auto mangling = Demangle::mangleNode(node, Flavor);
|
|
assert(mangling.isSuccess());
|
|
return mangling.result();
|
|
}
|
|
// Otherwise, treat the original function symbol as a black box and just
|
|
// mangle the other parts.
|
|
beginManglingWithoutPrefix();
|
|
appendOperator(originalName);
|
|
appendType(toType, nullptr);
|
|
auto kindCode = (char)kind;
|
|
appendOperator("TJS", StringRef(&kindCode, 1));
|
|
appendIndexSubset(fromParamIndices);
|
|
appendOperator("p");
|
|
appendIndexSubset(fromResultIndices);
|
|
appendOperator("r");
|
|
appendIndexSubset(toParamIndices);
|
|
appendOperator("P");
|
|
return finalize();
|
|
}
|
|
|
|
std::string DifferentiationMangler::mangleLinearMapSubsetParametersThunk(
|
|
CanType fromType, AutoDiffLinearMapKind linearMapKind,
|
|
IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
|
|
IndexSubset *toParamIndices) {
|
|
beginMangling();
|
|
appendType(fromType, nullptr);
|
|
auto functionKindCode = (char)getAutoDiffFunctionKind(linearMapKind);
|
|
appendOperator("TJS", StringRef(&functionKindCode, 1));
|
|
appendIndexSubset(fromParamIndices);
|
|
appendOperator("p");
|
|
appendIndexSubset(fromResultIndices);
|
|
appendOperator("r");
|
|
appendIndexSubset(toParamIndices);
|
|
appendOperator("P");
|
|
return finalize();
|
|
}
|