Files
swift-mirror/lib/AST/AutoDiff.cpp
Dan Zheng a174243159 [AutoDiff upstream] Add SIL differentiability witness IRGen. (#29704)
SIL differentiability witnesses are a new top-level SIL construct mapping
an "original" SIL function and derivative configuration to derivative SIL
functions.

This patch adds `SILDifferentiabilityWitness` IRGen.

`SILDifferentiabilityWitness` has a fixed `{ i8*, i8* }` layout:
JVP and VJP derivative function pointers.

Resolves TF-1146.
2020-02-07 14:10:34 -08:00

126 lines
4.6 KiB
C++

//===--- AutoDiff.cpp - Swift automatic differentiation utilities ---------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
#include "swift/AST/AutoDiff.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"
#include "llvm/Support/ErrorHandling.h"
using namespace swift;
/// Prints this configuration to `s` as
/// `(parameters=<P> results=<R>)`, or as
/// `(parameters=<P> results=<R> where=<G>)` when a derivative generic
/// signature is present.
void AutoDiffConfig::print(llvm::raw_ostream &s) const {
  s << "(parameters=";
  parameterIndices->print(s);
  s << " results=";
  resultIndices->print(s);
  // Without a derivative generic signature there is no `where=` clause.
  if (!derivativeGenericSignature) {
    s << ')';
    return;
  }
  s << " where=";
  derivativeGenericSignature->print(s);
  s << ')';
}
// TODO(TF-874): This helper is inefficient and should be removed. Unwrapping at
// most once (for curried method types) is sufficient.
/// Appends `fnTy` and every function type reachable by repeatedly taking the
/// result type to `results`, outermost first. The walk stops at the first
/// result that is not an `AnyFunctionType`. A null `fnTy` appends nothing.
static void unwrapCurryLevels(AnyFunctionType *fnTy,
                              SmallVectorImpl<AnyFunctionType *> &results) {
  for (AnyFunctionType *level = fnTy; level != nullptr;
       level = level->getResult()->getAs<AnyFunctionType>())
    results.push_back(level);
}
/// Returns the number of leaf (non-tuple) element types in `type`, recursively
/// flattening nested tuples. A non-tuple type counts as 1; an empty tuple
/// counts as 0.
// NOTE(review): this helper is not referenced in the code visible here;
// confirm whether it is still needed.
static unsigned countNumFlattenedElementTypes(Type type) {
  auto *tuple = type->getCanonicalType()->getAs<TupleType>();
  if (!tuple)
    return 1;
  unsigned total = 0;
  for (Type element : tuple->getElementTypes())
    total += countNumFlattenedElementTypes(element);
  return total;
}
// TODO(TF-874): Simplify this helper and remove the `reverseCurryLevels` flag.
// See TF-874 for WIP.
/// Appends to `results` the types of the parameters of `type` whose flattened
/// parameter indices are contained in `subset`.
///
/// Curry levels are unwrapped outermost-first. Flattened indices are assigned
/// starting from the innermost curry level: its parameters start at index 0,
/// and each enclosing level's parameters follow. Types are appended level by
/// level, outermost level first, or innermost level first when
/// `reverseCurryLevels` is true. Within a level, parameter order is kept.
void autodiff::getSubsetParameterTypes(IndexSubset *subset,
                                       AnyFunctionType *type,
                                       SmallVectorImpl<Type> &results,
                                       bool reverseCurryLevels) {
  SmallVector<AnyFunctionType *, 2> levels;
  unwrapCurryLevels(type, levels);
  const unsigned numLevels = levels.size();

  // offsets[i] is the flattened index of the first parameter of levels[i];
  // accumulate from the innermost (last) level outwards.
  SmallVector<unsigned, 2> offsets(numLevels);
  unsigned nextOffset = 0;
  for (unsigned i = numLevels; i > 0; --i) {
    offsets[i - 1] = nextOffset;
    nextOffset += levels[i - 1]->getNumParams();
  }

  // Appends the selected parameter types of one curry level.
  auto appendSelectedParams = [&](unsigned levelIndex) {
    AnyFunctionType *level = levels[levelIndex];
    unsigned base = offsets[levelIndex];
    auto params = level->getParams();
    for (unsigned p = 0, e = level->getNumParams(); p < e; ++p) {
      if (subset->contains(base + p))
        results.push_back(params[p].getOldType());
    }
  };

  // Visiting levels back-to-front is equivalent to reversing both the level
  // list and the offset list and then visiting front-to-back.
  if (reverseCurryLevels) {
    for (unsigned i = numLevels; i > 0; --i)
      appendSelectedParams(i - 1);
  } else {
    for (unsigned i = 0; i < numLevels; ++i)
      appendSelectedParams(i);
  }
}
/// Returns `derivativeGenSig` extended with a `Differentiable` conformance
/// requirement for each differentiability parameter of `originalFnTy` selected
/// by `diffParamIndices`.
///
/// If `derivativeGenSig` is null, `originalFnTy`'s substituted generic
/// signature is used instead. If that is also null, returns null; there is
/// nothing to constrain. The combined signature is built with
/// `AbstractGenericSignatureRequest`, and null is returned if that request
/// fails.
GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
    SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
    GenericSignature derivativeGenSig) {
  if (!derivativeGenSig)
    derivativeGenSig = originalFnTy->getSubstGenericSignature();
  if (!derivativeGenSig)
    return nullptr;
  auto &ctx = originalFnTy->getASTContext();
  // `getProtocol` returns null when the protocol cannot be found; it is
  // dereferenced below, so catch that case explicitly in asserts builds.
  auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
  assert(diffableProto && "`Differentiable` protocol must be available");
  auto params = originalFnTy->getParameters();
  // Constrain all differentiability parameters to `Differentiable`.
  SmallVector<Requirement, 4> requirements;
  for (unsigned paramIdx : diffParamIndices->getIndices()) {
    assert(paramIdx < params.size() &&
           "Differentiability parameter index out of range");
    auto paramType = params[paramIdx].getInterfaceType();
    requirements.push_back(Requirement(RequirementKind::Conformance, paramType,
                                       diffableProto->getDeclaredType()));
  }
  return evaluateOrDefault(
      ctx.evaluator,
      AbstractGenericSignatureRequest{derivativeGenSig.getPointer(),
                                      /*addedGenericParams*/ {},
                                      std::move(requirements)},
      nullptr);
}
/// Returns the underlying type of this tangent space: the tangent vector type
/// for `Kind::TangentVector`, or the tuple type for `Kind::Tuple`.
Type TangentSpace::getType() const {
  switch (kind) {
  case Kind::TangentVector:
    return value.tangentVectorType;
  case Kind::Tuple:
    return value.tupleType;
  }
  // Every enumerator is handled above; reaching here means `kind` holds an
  // invalid value. Without this, control could fall off the end of a non-void
  // function (undefined behavior, and a -Wreturn-type warning on GCC).
  llvm_unreachable("invalid tangent space kind");
}
/// Returns the canonical form of `getType()`.
CanType TangentSpace::getCanonicalType() const {
  Type underlying = getType();
  return underlying->getCanonicalType();
}
/// Returns the nominal (or bound generic nominal) declaration of the tangent
/// vector type. Only valid when this tangent space is a tangent vector
/// (asserted).
NominalTypeDecl *TangentSpace::getNominal() const {
  assert(isTangentVector());
  auto tangentVectorTy = getTangentVector();
  return tangentVectorTy->getNominalOrBoundGenericNominal();
}