mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Although I don't plan to bring over new assertions wholesale into the current qualification branch, it's entirely possible that various minor changes in main will use the new assertions; having this basic support in the release branch will simplify that. (This is why I'm adding the includes as a separate pass from rewriting the individual assertions)
160 lines
5.8 KiB
C++
160 lines
5.8 KiB
C++
//===--- ADContext.cpp - Differentiation Context --------------*- C++ -*---===//
|
|
//
|
|
// This source file is part of the Swift.org open source project
|
|
//
|
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
|
//
|
|
// See https://swift.org/LICENSE.txt for license information
|
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Per-module contextual information for the differentiation transform.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define DEBUG_TYPE "differentiation"
|
|
|
|
#include "swift/SILOptimizer/Differentiation/ADContext.h"
|
|
#include "swift/AST/DiagnosticsSIL.h"
|
|
#include "swift/AST/SourceFile.h"
|
|
#include "swift/Basic/Assertions.h"
|
|
#include "swift/SILOptimizer/PassManager/Transforms.h"
|
|
|
|
using llvm::DenseMap;
|
|
using llvm::SmallPtrSet;
|
|
using llvm::SmallVector;
|
|
|
|
namespace swift {
|
|
namespace autodiff {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Local helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Given an operator name, such as '+', and a protocol, returns the '+'
|
|
/// operator. If the operator does not exist in the protocol, returns null.
|
|
static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName,
|
|
ProtocolDecl *protocol) {
|
|
assert(operatorName.isOperator());
|
|
// Find the operator requirement in the given protocol declaration.
|
|
auto opLookup = protocol->lookupDirect(operatorName);
|
|
for (auto *decl : opLookup) {
|
|
if (!decl->isProtocolRequirement())
|
|
continue;
|
|
auto *fd = dyn_cast<FuncDecl>(decl);
|
|
if (!fd || !fd->isStatic() || !fd->isOperator())
|
|
continue;
|
|
return fd;
|
|
}
|
|
// Not found.
|
|
return nullptr;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ADContext methods
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
ADContext::ADContext(SILModuleTransform &transform)
|
|
: transform(transform), module(*transform.getModule()),
|
|
passManager(*transform.getPassManager()) {}
|
|
|
|
/// Get the source file for the given `SILFunction`.
|
|
static SourceFile &getSourceFile(SILFunction *f) {
|
|
if (f->hasLocation())
|
|
if (auto *declContext = f->getLocation().getAsDeclContext())
|
|
if (auto *parentSourceFile = declContext->getParentSourceFile())
|
|
return *parentSourceFile;
|
|
for (auto *file : f->getModule().getSwiftModule()->getFiles())
|
|
if (auto *sourceFile = dyn_cast<SourceFile>(file))
|
|
return *sourceFile;
|
|
llvm_unreachable("Could not resolve SourceFile from SILFunction");
|
|
}
|
|
|
|
SynthesizedFileUnit &
|
|
ADContext::getOrCreateSynthesizedFile(SILFunction *original) {
|
|
auto &SF = getSourceFile(original);
|
|
return SF.getOrCreateSynthesizedFile();
|
|
}
|
|
|
|
FuncDecl *ADContext::getPlusDecl() const {
|
|
if (!cachedPlusFn) {
|
|
cachedPlusFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+"),
|
|
additiveArithmeticProtocol);
|
|
assert(cachedPlusFn && "AdditiveArithmetic.+ not found");
|
|
}
|
|
return cachedPlusFn;
|
|
}
|
|
|
|
FuncDecl *ADContext::getPlusEqualDecl() const {
|
|
if (!cachedPlusEqualFn) {
|
|
cachedPlusEqualFn = findOperatorDeclInProtocol(astCtx.getIdentifier("+="),
|
|
additiveArithmeticProtocol);
|
|
assert(cachedPlusEqualFn && "AdditiveArithmetic.+= not found");
|
|
}
|
|
return cachedPlusEqualFn;
|
|
}
|
|
|
|
AccessorDecl *ADContext::getAdditiveArithmeticZeroGetter() const {
|
|
if (cachedZeroGetter)
|
|
return cachedZeroGetter;
|
|
auto zeroDeclLookup = getAdditiveArithmeticProtocol()
|
|
->lookupDirect(getASTContext().Id_zero);
|
|
auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front());
|
|
assert(zeroDecl->isProtocolRequirement());
|
|
cachedZeroGetter = zeroDecl->getOpaqueAccessor(AccessorKind::Get);
|
|
return cachedZeroGetter;
|
|
}
|
|
|
|
void ADContext::cleanUp() {
|
|
// Delete all references to generated functions.
|
|
for (auto fnRef : generatedFunctionReferences) {
|
|
if (auto *fnRefInst =
|
|
peerThroughFunctionConversions<FunctionRefInst>(fnRef)) {
|
|
fnRefInst->replaceAllUsesWithUndef();
|
|
fnRefInst->eraseFromParent();
|
|
}
|
|
}
|
|
// Delete all generated functions.
|
|
for (auto *generatedFunction : generatedFunctions) {
|
|
LLVM_DEBUG(getADDebugStream() << "Deleting generated function "
|
|
<< generatedFunction->getName() << '\n');
|
|
generatedFunction->dropAllReferences();
|
|
transform.notifyWillDeleteFunction(generatedFunction);
|
|
module.eraseFunction(generatedFunction);
|
|
}
|
|
}
|
|
|
|
DifferentiableFunctionInst *ADContext::createDifferentiableFunction(
|
|
SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
|
|
IndexSubset *resultIndices, SILValue original,
|
|
std::optional<std::pair<SILValue, SILValue>> derivativeFunctions) {
|
|
auto *dfi = builder.createDifferentiableFunction(
|
|
loc, parameterIndices, resultIndices, original, derivativeFunctions);
|
|
processedDifferentiableFunctionInsts.erase(dfi);
|
|
return dfi;
|
|
}
|
|
|
|
LinearFunctionInst *ADContext::createLinearFunction(
|
|
SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
|
|
SILValue original, std::optional<SILValue> transposeFunction) {
|
|
auto *lfi = builder.createLinearFunction(loc, parameterIndices, original,
|
|
transposeFunction);
|
|
processedLinearFunctionInsts.erase(lfi);
|
|
return lfi;
|
|
}
|
|
|
|
DifferentiableFunctionExpr *
|
|
ADContext::findDifferentialOperator(DifferentiableFunctionInst *inst) {
|
|
return inst->getLoc().getAsASTNode<DifferentiableFunctionExpr>();
|
|
}
|
|
|
|
LinearFunctionExpr *
|
|
ADContext::findDifferentialOperator(LinearFunctionInst *inst) {
|
|
return inst->getLoc().getAsASTNode<LinearFunctionExpr>();
|
|
}
|
|
|
|
} // end namespace autodiff
|
|
} // end namespace swift
|