[AutoDiff] Enable variable debugging support for pullback functions.

- Properly clone and use debug scopes for all instructions in pullback functions.
- Emit `debug_value` instructions for adjoint values.
- Add debug locations and variable info to adjoint buffer allocations.
- Add `TangentBuilder` (a `SILBuilder` subclass) to unify and simplify special emitter utilities for tangent vector code generation. More simplifications to come.

Pullback variable inspection example:
```console
(lldb) n
Process 50984 stopped
* thread #1, queue = 'com.apple.main-thread', stop reason = step over
    frame #0: 0x0000000100003497 main`pullback of foo(x=0) at main.swift:12:11
   9   	import _Differentiation
   10
   11  	func foo(_ x: Float) -> Float {
-> 12  	  let y = sin(x)
   13  	  let z = cos(y)
   14  	  let k = tanh(z) + cos(y)
   15  	  return k
Target 0: (main) stopped.
(lldb) fr v
(Float) x = 0
(Float) k = 1
(Float) z = 0.495846391
(Float) y = -0.689988375
```

Resolves rdar://68616528 / SR-13535.
This commit is contained in:
Richard Wei
2020-12-19 22:00:28 -08:00
parent 1c2b80fae7
commit e3b480b0c9
14 changed files with 573 additions and 425 deletions

View File

@@ -115,6 +115,8 @@ private:
mutable FuncDecl *cachedPlusFn = nullptr;
/// `AdditiveArithmetic.+=` declaration.
mutable FuncDecl *cachedPlusEqualFn = nullptr;
/// `AdditiveArithmetic.zero` declaration.
mutable AccessorDecl *cachedZeroGetter = nullptr;
public:
/// Construct an ADContext for the given module.
@@ -201,6 +203,7 @@ public:
FuncDecl *getPlusDecl() const;
FuncDecl *getPlusEqualDecl() const;
AccessorDecl *getAdditiveArithmeticZeroGetter() const;
/// Cleans up all the internal state.
void cleanUp();
@@ -269,6 +272,10 @@ public:
Diag<T...> diag, U &&... args);
};
raw_ostream &getADDebugStream();
SILLocation getValidLocation(SILValue v);
SILLocation getValidLocation(SILInstruction *inst);
template <typename... T, typename... U>
InFlightDiagnostic
ADContext::emitNondifferentiabilityError(SILValue value,

View File

@@ -51,29 +51,45 @@ class AdjointValueBase {
/// The type of this value as if it were materialized as a SIL value.
SILType type;
using DebugInfo = std::pair<SILDebugLocation, SILDebugVariable>;
/// The debug location and variable info associated with the original value.
Optional<DebugInfo> debugInfo;
/// The underlying value.
union Value {
llvm::ArrayRef<AdjointValue> aggregate;
unsigned numAggregateElements;
SILValue concrete;
Value(llvm::ArrayRef<AdjointValue> v) : aggregate(v) {}
Value(unsigned numAggregateElements)
: numAggregateElements(numAggregateElements) {}
Value(SILValue v) : concrete(v) {}
Value() {}
} value;
// Begins tail-allocated aggregate elements, if
// `kind == AdjointValueKind::Aggregate`.
explicit AdjointValueBase(SILType type,
llvm::ArrayRef<AdjointValue> aggregate)
: kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {}
llvm::ArrayRef<AdjointValue> aggregate,
Optional<DebugInfo> debugInfo)
: kind(AdjointValueKind::Aggregate), type(type), debugInfo(debugInfo),
value(aggregate.size()) {
MutableArrayRef<AdjointValue> tailElements(
reinterpret_cast<AdjointValue *>(this + 1), aggregate.size());
std::uninitialized_copy(
aggregate.begin(), aggregate.end(), tailElements.begin());
}
explicit AdjointValueBase(SILValue v)
: kind(AdjointValueKind::Concrete), type(v->getType()), value(v) {}
explicit AdjointValueBase(SILValue v, Optional<DebugInfo> debugInfo)
: kind(AdjointValueKind::Concrete), type(v->getType()),
debugInfo(debugInfo), value(v) {}
explicit AdjointValueBase(SILType type)
: kind(AdjointValueKind::Zero), type(type) {}
explicit AdjointValueBase(SILType type, Optional<DebugInfo> debugInfo)
: kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
};
/// A symbolic adjoint value that is capable of representing zero value 0 and
/// 1, in addition to a materialized SILValue. This is expected to be passed
/// around by value in most cases, as it's two words long.
/// A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
/// thereof.
class AdjointValue final {
private:
@@ -85,26 +101,37 @@ public:
AdjointValueBase *operator->() const { return base; }
AdjointValueBase &operator*() const { return *base; }
static AdjointValue createConcrete(llvm::BumpPtrAllocator &allocator,
SILValue value) {
return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(value);
using DebugInfo = AdjointValueBase::DebugInfo;
static AdjointValue createConcrete(
llvm::BumpPtrAllocator &allocator, SILValue value,
Optional<DebugInfo> debugInfo = None) {
auto *buf = allocator.Allocate<AdjointValueBase>();
return new (buf) AdjointValueBase(value, debugInfo);
}
static AdjointValue createZero(llvm::BumpPtrAllocator &allocator,
SILType type) {
return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(type);
static AdjointValue createZero(
llvm::BumpPtrAllocator &allocator, SILType type,
Optional<DebugInfo> debugInfo = None) {
auto *buf = allocator.Allocate<AdjointValueBase>();
return new (buf) AdjointValueBase(type, debugInfo);
}
static AdjointValue createAggregate(llvm::BumpPtrAllocator &allocator,
SILType type,
llvm::ArrayRef<AdjointValue> aggregate) {
return new (allocator.Allocate<AdjointValueBase>())
AdjointValueBase(type, aggregate);
static AdjointValue createAggregate(
llvm::BumpPtrAllocator &allocator, SILType type,
ArrayRef<AdjointValue> elements,
Optional<DebugInfo> debugInfo = None) {
AdjointValue *buf = reinterpret_cast<AdjointValue *>(allocator.Allocate(
sizeof(AdjointValueBase) + elements.size() * sizeof(AdjointValue),
alignof(AdjointValueBase)));
return new (buf) AdjointValueBase(type, elements, debugInfo);
}
AdjointValueKind getKind() const { return base->kind; }
SILType getType() const { return base->type; }
CanType getSwiftType() const { return getType().getASTType(); }
Optional<DebugInfo> getDebugInfo() const { return base->debugInfo; }
void setDebugInfo(DebugInfo debugInfo) const { base->debugInfo = debugInfo; }
NominalTypeDecl *getAnyNominal() const {
return getSwiftType()->getAnyNominal();
@@ -116,16 +143,18 @@ public:
unsigned getNumAggregateElements() const {
assert(isAggregate());
return base->value.aggregate.size();
return base->value.numAggregateElements;
}
AdjointValue getAggregateElement(unsigned i) const {
assert(isAggregate());
return base->value.aggregate[i];
return getAggregateElements()[i];
}
llvm::ArrayRef<AdjointValue> getAggregateElements() const {
return base->value.aggregate;
assert(isAggregate());
return {
reinterpret_cast<const AdjointValue *>(base + 1),
getNumAggregateElements()};
}
SILValue getConcreteValue() const {
@@ -143,7 +172,7 @@ public:
if (auto *decl =
getType().getASTType()->getStructOrBoundGenericStruct()) {
interleave(
llvm::zip(decl->getStoredProperties(), base->value.aggregate),
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
s << std::get<0>(elt)->getName() << ": ";
std::get<1>(elt).print(s);
@@ -151,7 +180,7 @@ public:
[&s] { s << ", "; });
} else if (getType().is<TupleType>()) {
interleave(
base->value.aggregate,
getAggregateElements(),
[&s](const AdjointValue &elt) { elt.print(s); },
[&s] { s << ", "; });
} else {

View File

@@ -27,7 +27,9 @@
#include "swift/SIL/TypeSubstCloner.h"
#include "swift/SILOptimizer/Analysis/ArraySemantic.h"
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
#include "swift/SILOptimizer/Differentiation/ADContext.h"
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
#include "swift/SILOptimizer/Differentiation/TangentBuilder.h"
namespace swift {
@@ -142,6 +144,9 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
return nullptr;
}
Optional<std::pair<SILDebugLocation, SILDebugVariable>>
findDebugLocationAndVariable(SILValue originalValue);
//===----------------------------------------------------------------------===//
// Diagnostic utilities
//===----------------------------------------------------------------------===//
@@ -190,12 +195,6 @@ SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
void extractAllElements(SILValue value, SILBuilder &builder,
SmallVectorImpl<SILValue> &results);
/// Emit a zero value into the given buffer access by calling
/// `AdditiveArithmetic.zero`. The given type must conform to
/// `AdditiveArithmetic`.
void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
SILValue bufferAccess, SILLocation loc);
/// Emit a `Builtin.Word` value that represents the given type's memory layout
/// size.
SILValue emitMemoryLayoutSize(

View File

@@ -0,0 +1,78 @@
//===--- TangentBuilder.h - Tangent SIL builder --------------*- C++ -*----===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file defines a helper class for emitting tangent code for automatic
// differentiation.
//
//===----------------------------------------------------------------------===//
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_TANGENTBUILDER_H
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_TANGENTBUILDER_H
#include "swift/SIL/SILBuilder.h"
namespace swift {
namespace autodiff {
class ADContext;
class TangentBuilder: public SILBuilder {
private:
ADContext &adContext;
public:
TangentBuilder(SILFunction &fn, ADContext &adContext)
: SILBuilder(fn), adContext(adContext) {}
TangentBuilder(SILBasicBlock *bb, ADContext &adContext)
: SILBuilder(bb), adContext(adContext) {}
TangentBuilder(SILBasicBlock::iterator insertionPt, ADContext &adContext)
: SILBuilder(insertionPt), adContext(adContext) {}
TangentBuilder(SILBasicBlock *bb, SILBasicBlock::iterator insertionPt,
ADContext &adContext)
: SILBuilder(bb, insertionPt), adContext(adContext) {}
/// Emits an `AdditiveArithmetic.zero` into the given buffer. If it is not an
/// initialization (`isInit`), a `destroy_addr` will be emitted on the buffer
/// first. The buffer must have a type that conforms to `AdditiveArithmetic`
/// or be a tuple thereof.
void emitZeroIntoBuffer(SILLocation loc, SILValue buffer,
IsInitialization_t isInit);
/// Emits an `AdditiveArithmetic.zero` of the given type. The type must be a
/// loadable type, and must conform to `AddditiveArithmetic` or be a tuple
/// thereof.
SILValue emitZero(SILLocation loc, CanType type);
/// Emits an `AdditiveArithmetic.+=` for the given destination buffer and
/// operand. The type of the buffer and the operand must conform to
/// `AddditiveArithmetic` or be a tuple thereof. The operand will not be
/// consumed.
void emitInPlaceAdd(SILLocation loc, SILValue destinationBuffer,
SILValue operand);
/// Emits an `AdditiveArithmetic.+` for the given operands. The type of the
/// operands must conform to `AddditiveArithmetic` or be a tuple thereof. The
/// operands will not be consumed.
void emitAddIntoBuffer(SILLocation loc, SILValue destinationBuffer,
SILValue lhsAddress, SILValue rhsAddress);
/// Emits an `AdditiveArithmetic.+` for the given operands. The type of the
/// operands must be a loadable type, and must conform to
/// `AddditiveArithmetic` or be a tuple thereof. The operands will not be
/// consumed.
SILValue emitAdd(SILLocation loc, SILValue lhs, SILValue rhs);
};
} // end namespace autodiff
} // end namespace swift
#endif /* SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_TANGENTBUILDER_H */

View File

@@ -38,6 +38,8 @@ class ArchetypeType;
namespace autodiff {
class ADContext;
//===----------------------------------------------------------------------===//
// Thunk helpers
//===----------------------------------------------------------------------===//
@@ -107,7 +109,7 @@ std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForDerivativeFunction(
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
AutoDiffDerivativeFunctionKind kind, AutoDiffConfig desiredConfig,
AutoDiffConfig actualConfig);
AutoDiffConfig actualConfig, ADContext &adContext);
/// Get or create a derivative function parameter index subset thunk from
/// `actualIndices` to `desiredIndices` for the given associated function
@@ -119,7 +121,8 @@ getOrCreateSubsetParametersThunkForLinearMap(
SILOptFunctionBuilder &fb, SILFunction *assocFn,
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig);
AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig,
ADContext &adContext);
} // end namespace autodiff