Files
swift-mirror/lib/SILOptimizer/Utils/DifferentiationMangler.cpp

204 lines
8.0 KiB
C++

//===----- DifferentiationMangler.cpp --------- differentiation -*- C++ -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
#include "swift/SILOptimizer/Utils/DifferentiationMangler.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/SubstitutionMap.h"
#include "swift/Basic/Assertions.h"
#include "swift/Demangling/ManglingMacros.h"
#include "swift/SIL/SILGlobalVariable.h"
using namespace swift;
using namespace Mangle;
/// Mangles the generic signature and get its mangling tree. This is necessary
/// because the derivative generic signature's requirements may contain names
/// which repeat the contents of the original function name. To follow Swift's
/// mangling scheme, these repetitions must be mangled as substitutions.
/// Therefore, we build mangling trees in `DifferentiationMangler` and let the
/// remangler take care of substitutions.
static NodePointer mangleGenericSignatureAsNode(GenericSignature sig,
Demangler &demangler) {
if (!sig)
return nullptr;
ASTMangler sigMangler(sig->getASTContext());
auto mangledGenSig = sigMangler.mangleGenericSignature(sig);
auto demangledGenSig = demangler.demangleSymbol(mangledGenSig);
assert(demangledGenSig->getKind() == Node::Kind::Global);
assert(demangledGenSig->getNumChildren() == 1);
auto result = demangledGenSig->getFirstChild();
assert(result->getKind() == Node::Kind::DependentGenericSignature);
return result;
}
static NodePointer mangleAutoDiffFunctionAsNode(
StringRef originalName, Demangle::AutoDiffFunctionKind kind,
const AutoDiffConfig &config, Demangler &demangler) {
assert(isMangledName(originalName));
auto demangledOrig = demangler.demangleSymbol(originalName);
assert(demangledOrig && "Should only be called when the original "
"function has a mangled name");
assert(demangledOrig->getKind() == Node::Kind::Global);
auto derivativeGenericSignatureNode = mangleGenericSignatureAsNode(
config.derivativeGenericSignature, demangler);
auto *adFunc = demangler.createNode(Node::Kind::AutoDiffFunction);
for (auto *child : *demangledOrig)
adFunc->addChild(child, demangler);
if (derivativeGenericSignatureNode)
adFunc->addChild(derivativeGenericSignatureNode, demangler);
adFunc->addChild(
demangler.createNode(
Node::Kind::AutoDiffFunctionKind, (Node::IndexType)kind),
demangler);
adFunc->addChild(
demangler.createNode(
Node::Kind::IndexSubset, config.parameterIndices->getString()),
demangler);
adFunc->addChild(
demangler.createNode(
Node::Kind::IndexSubset, config.resultIndices->getString()),
demangler);
auto root = demangler.createNode(Node::Kind::Global);
root->addChild(adFunc, demangler);
return root;
}
std::string DifferentiationMangler::mangleAutoDiffFunction(
StringRef originalName, Demangle::AutoDiffFunctionKind kind,
const AutoDiffConfig &config) {
// If the original function is mangled, mangle the tree.
if (isMangledName(originalName)) {
Demangler demangler;
auto node = mangleAutoDiffFunctionAsNode(
originalName, kind, config, demangler);
auto mangling = Demangle::mangleNode(node, Flavor);
assert(mangling.isSuccess());
return mangling.result();
}
// Otherwise, treat the original function symbol as a black box and just
// mangle the other parts.
beginManglingWithoutPrefix();
appendOperator(originalName);
appendAutoDiffFunctionParts("TJ", kind, config);
return finalize();
}
// Returns the mangled name for a derivative function of the given kind.
std::string DifferentiationMangler::mangleDerivativeFunction(
StringRef originalName, AutoDiffDerivativeFunctionKind kind,
const AutoDiffConfig &config) {
return mangleAutoDiffFunction(
originalName, getAutoDiffFunctionKind(kind), config);
}
// Returns the mangled name for a derivative function of the given kind.
std::string DifferentiationMangler::mangleLinearMap(
StringRef originalName, AutoDiffLinearMapKind kind,
const AutoDiffConfig &config) {
return mangleAutoDiffFunction(
originalName, getAutoDiffFunctionKind(kind), config);
}
static NodePointer mangleDerivativeFunctionSubsetParametersThunkAsNode(
StringRef originalName, Type toType, Demangle::AutoDiffFunctionKind kind,
IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
IndexSubset *toParamIndices, Demangler &demangler) {
assert(isMangledName(originalName));
auto demangledOrig = demangler.demangleSymbol(originalName);
assert(demangledOrig && "Should only be called when the original "
"function has a mangled name");
assert(demangledOrig->getKind() == Node::Kind::Global);
auto *thunk = demangler.createNode(Node::Kind::AutoDiffSubsetParametersThunk);
for (auto *child : *demangledOrig)
thunk->addChild(child, demangler);
NodePointer toTypeNode = nullptr;
{
ASTMangler typeMangler(toType->getASTContext());
toTypeNode = demangler.demangleType(
typeMangler.mangleTypeWithoutPrefix(toType));
assert(toTypeNode && "Cannot demangle the to-type as node");
}
thunk->addChild(toTypeNode, demangler);
thunk->addChild(
demangler.createNode(
Node::Kind::AutoDiffFunctionKind, (Node::IndexType)kind),
demangler);
thunk->addChild(
demangler.createNode(
Node::Kind::IndexSubset, fromParamIndices->getString()),
demangler);
thunk->addChild(
demangler.createNode(
Node::Kind::IndexSubset, fromResultIndices->getString()),
demangler);
thunk->addChild(
demangler.createNode(
Node::Kind::IndexSubset, toParamIndices->getString()),
demangler);
auto root = demangler.createNode(Node::Kind::Global);
root->addChild(thunk, demangler);
return root;
}
std::string
DifferentiationMangler::mangleDerivativeFunctionSubsetParametersThunk(
StringRef originalName, CanType toType,
AutoDiffDerivativeFunctionKind linearMapKind,
IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
IndexSubset *toParamIndices) {
beginMangling();
auto kind = getAutoDiffFunctionKind(linearMapKind);
// If the original function is mangled, mangle the tree.
if (isMangledName(originalName)) {
Demangler demangler;
auto node = mangleDerivativeFunctionSubsetParametersThunkAsNode(
originalName, toType, kind, fromParamIndices, fromResultIndices,
toParamIndices, demangler);
auto mangling = Demangle::mangleNode(node, Flavor);
assert(mangling.isSuccess());
return mangling.result();
}
// Otherwise, treat the original function symbol as a black box and just
// mangle the other parts.
beginManglingWithoutPrefix();
appendOperator(originalName);
appendType(toType, nullptr);
auto kindCode = (char)kind;
appendOperator("TJS", StringRef(&kindCode, 1));
appendIndexSubset(fromParamIndices);
appendOperator("p");
appendIndexSubset(fromResultIndices);
appendOperator("r");
appendIndexSubset(toParamIndices);
appendOperator("P");
return finalize();
}
std::string DifferentiationMangler::mangleLinearMapSubsetParametersThunk(
CanType fromType, AutoDiffLinearMapKind linearMapKind,
IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
IndexSubset *toParamIndices) {
beginMangling();
appendType(fromType, nullptr);
auto functionKindCode = (char)getAutoDiffFunctionKind(linearMapKind);
appendOperator("TJS", StringRef(&functionKindCode, 1));
appendIndexSubset(fromParamIndices);
appendOperator("p");
appendIndexSubset(fromResultIndices);
appendOperator("r");
appendIndexSubset(toParamIndices);
appendOperator("P");
return finalize();
}