Files
swift-mirror/lib/SILOptimizer/Differentiation/ADContext.cpp
Evan Wilde f3ff561c6f [NFC] add llvm namespace to Optional and None
This is phase 1 of switching from `llvm::Optional` to `std::optional` in the
next rebranch. `llvm::Optional` was removed from upstream LLVM, so we need
to migrate off it soon. On Darwin, `std::optional` and `llvm::Optional`
have the same layout, so we don't need to be as concerned about ABI
beyond the name mangling. `llvm::Optional` is only returned from one
function in
```
getStandardTypeSubst(StringRef TypeName,
                     bool allowConcurrencyManglings);
```
It's the return value, so it should not impact the mangling of the
function, and the layout is the same as `std::optional`, so it should be
mostly okay. This function doesn't appear to have users, and the ABI was
already broken two years ago for concurrency and no one seemed to notice,
so this should be "okay".

I'm doing the migration incrementally so that folks working on main can
cherry-pick back to the release/5.9 branch. Once 5.9 is done and locked
away, we can go through and finish the replacement. Since `None`
and `Optional` show up in contexts where they are not `llvm::None` and
`llvm::Optional`, I'm preparing the work now by going through and
removing the namespace unwrapping and making the `llvm` namespace
explicit. This should make it fairly mechanical to go through and
replace llvm::Optional with std::optional, and llvm::None with
std::nullopt. It's also a change that can be brought onto the
release/5.9 branch with minimal impact. This should be an NFC change.
2023-06-27 09:03:52 -07:00

159 lines
5.8 KiB
C++

//===--- ADContext.cpp - Differentiation Context --------------*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Per-module contextual information for the differentiation transform.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "differentiation"
#include "swift/SILOptimizer/Differentiation/ADContext.h"
#include "swift/AST/DiagnosticsSIL.h"
#include "swift/AST/SourceFile.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"
using llvm::DenseMap;
using llvm::SmallPtrSet;
using llvm::SmallVector;
namespace swift {
namespace autodiff {
//===----------------------------------------------------------------------===//
// Local helpers
//===----------------------------------------------------------------------===//
/// Returns the operator function named `operatorName` (for example `+`) that
/// `protocol` declares as a requirement, or null if there is none.
///
/// Candidates come from `protocol->lookupDirect(operatorName)`. The first
/// candidate that is a protocol requirement and a `static` operator `FuncDecl`
/// is returned.
static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName,
                                            ProtocolDecl *protocol) {
  assert(operatorName.isOperator());
  for (auto *candidate : protocol->lookupDirect(operatorName)) {
    // Ignore members that are not requirements of the protocol.
    if (!candidate->isProtocolRequirement())
      continue;
    // Only a static operator function qualifies.
    auto *fn = dyn_cast<FuncDecl>(candidate);
    if (fn && fn->isStatic() && fn->isOperator())
      return fn;
  }
  // No matching requirement.
  return nullptr;
}
//===----------------------------------------------------------------------===//
// ADContext methods
//===----------------------------------------------------------------------===//
/// Creates the differentiation context for `transform`.
///
/// Initializes `transform`, `module` (from `*transform.getModule()`) and
/// `passManager` (from `*transform.getPassManager()`).
/// NOTE(review): both getters are dereferenced unconditionally; this assumes
/// the transform has already been attached to a module and pass manager.
ADContext::ADContext(SILModuleTransform &transform)
: transform(transform), module(*transform.getModule()),
passManager(*transform.getPassManager()) {}
/// Returns the `SourceFile` associated with `f`.
///
/// If `f` has a location whose decl context has a parent source file, that
/// source file is returned. Otherwise the first `SourceFile` among the files of
/// `f`'s Swift module is used. If neither exists, this is unreachable.
static SourceFile &getSourceFile(SILFunction *f) {
  // Preferred: the source file enclosing the function's declaration context.
  SourceFile *enclosingFile = nullptr;
  if (f->hasLocation()) {
    auto *dc = f->getLocation().getAsDeclContext();
    if (dc)
      enclosingFile = dc->getParentSourceFile();
  }
  if (enclosingFile)
    return *enclosingFile;

  // Fallback: the first source file in the owning Swift module.
  auto *swiftModule = f->getModule().getSwiftModule();
  for (auto *unit : swiftModule->getFiles()) {
    if (auto *sf = dyn_cast<SourceFile>(unit))
      return *sf;
  }
  llvm_unreachable("Could not resolve SourceFile from SILFunction");
}
/// Returns the synthesized file unit for the source file associated with
/// `original`. This delegates to `SourceFile::getOrCreateSynthesizedFile` on
/// the file found by `getSourceFile`.
SynthesizedFileUnit &
ADContext::getOrCreateSynthesizedFile(SILFunction *original) {
  return getSourceFile(original).getOrCreateSynthesizedFile();
}
/// Returns the `+` operator requirement of `AdditiveArithmetic`.
///
/// The lookup result is memoized in `cachedPlusFn`. A missing declaration
/// trips an assertion in builds with assertions enabled.
FuncDecl *ADContext::getPlusDecl() const {
  if (cachedPlusFn)
    return cachedPlusFn;
  cachedPlusFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+"),
                                            additiveArithmeticProtocol);
  assert(cachedPlusFn && "AdditiveArithmetic.+ not found");
  return cachedPlusFn;
}
/// Returns the `+=` operator requirement of `AdditiveArithmetic`.
///
/// The lookup result is memoized in `cachedPlusEqualFn`. A missing declaration
/// trips an assertion in builds with assertions enabled.
FuncDecl *ADContext::getPlusEqualDecl() const {
  if (cachedPlusEqualFn)
    return cachedPlusEqualFn;
  cachedPlusEqualFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+="),
                                                 additiveArithmeticProtocol);
  assert(cachedPlusEqualFn && "AdditiveArithmetic.+= not found");
  return cachedPlusEqualFn;
}
/// Returns the getter of the `zero` requirement of `AdditiveArithmetic`.
///
/// On first use, this does the following and memoizes the result in
/// `cachedZeroGetter`:
/// - looks up `zero` directly in the protocol;
/// - casts the first result to `VarDecl`;
/// - takes its opaque `get` accessor.
AccessorDecl *ADContext::getAdditiveArithmeticZeroGetter() const {
  if (!cachedZeroGetter) {
    auto *protocol = getAdditiveArithmeticProtocol();
    auto lookupResults = protocol->lookupDirect(getASTContext().Id_zero);
    auto *zeroVar = cast<VarDecl>(lookupResults.front());
    assert(zeroVar->isProtocolRequirement());
    cachedZeroGetter = zeroVar->getOpaqueAccessor(AccessorKind::Get);
  }
  return cachedZeroGetter;
}
void ADContext::cleanUp() {
// Delete all references to generated functions.
for (auto fnRef : generatedFunctionReferences) {
if (auto *fnRefInst =
peerThroughFunctionConversions<FunctionRefInst>(fnRef)) {
fnRefInst->replaceAllUsesWithUndef();
fnRefInst->eraseFromParent();
}
}
// Delete all generated functions.
for (auto *generatedFunction : generatedFunctions) {
LLVM_DEBUG(getADDebugStream() << "Deleting generated function "
<< generatedFunction->getName() << '\n');
generatedFunction->dropAllReferences();
transform.notifyWillDeleteFunction(generatedFunction);
module.eraseFunction(generatedFunction);
}
}
/// Builds a `DifferentiableFunctionInst` with `builder`.
///
/// After creation, the new instruction is erased from
/// `processedDifferentiableFunctionInsts`.
/// NOTE(review): presumably this keeps a freshly created instruction from being
/// treated as already processed (e.g. if it reuses the address of a deleted
/// one); confirm against the header.
DifferentiableFunctionInst *ADContext::createDifferentiableFunction(
    SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
    IndexSubset *resultIndices, SILValue original,
    llvm::Optional<std::pair<SILValue, SILValue>> derivativeFunctions) {
  DifferentiableFunctionInst *newInst = builder.createDifferentiableFunction(
      loc, parameterIndices, resultIndices, original, derivativeFunctions);
  processedDifferentiableFunctionInsts.erase(newInst);
  return newInst;
}
/// Builds a `LinearFunctionInst` with `builder`.
///
/// After creation, the new instruction is erased from
/// `processedLinearFunctionInsts`.
/// NOTE(review): presumably this keeps a freshly created instruction from being
/// treated as already processed; confirm against the header.
LinearFunctionInst *ADContext::createLinearFunction(
    SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
    SILValue original, llvm::Optional<SILValue> transposeFunction) {
  LinearFunctionInst *newInst = builder.createLinearFunction(
      loc, parameterIndices, original, transposeFunction);
  processedLinearFunctionInsts.erase(newInst);
  return newInst;
}
/// Returns the `DifferentiableFunctionExpr` that `inst`'s location refers to,
/// or null if the location's AST node is not such an expression.
DifferentiableFunctionExpr *
ADContext::findDifferentialOperator(DifferentiableFunctionInst *inst) {
  SILLocation instLoc = inst->getLoc();
  return instLoc.getAsASTNode<DifferentiableFunctionExpr>();
}
/// Returns the `LinearFunctionExpr` that `inst`'s location refers to, or null
/// if the location's AST node is not such an expression.
LinearFunctionExpr *
ADContext::findDifferentialOperator(LinearFunctionInst *inst) {
  SILLocation instLoc = inst->getLoc();
  return instLoc.getAsASTNode<LinearFunctionExpr>();
}
} // end namespace autodiff
} // end namespace swift