mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
[AutoDiff] Enable variable debugging support for pullback functions.
- Properly clone and use debug scopes for all instructions in pullback functions. - Emit `debug_value` instructions for adjoint values. - Add debug locations and variable info to adjoint buffer allocations. - Add `TangentBuilder` (a `SILBuilder` subclass) to unify and simplify special emitter utilities for tangent vector code generation. More simplifications to come. Pullback variable inspection example: ```console (lldb) n Process 50984 stopped * thread #1, queue = 'com.apple.main-thread', stop reason = step over frame #0: 0x0000000100003497 main`pullback of foo(x=0) at main.swift:12:11 9 import _Differentiation 10 11 func foo(_ x: Float) -> Float { -> 12 let y = sin(x) 13 let z = cos(y) 14 let k = tanh(z) + cos(y) 15 return k Target 0: (main) stopped. (lldb) fr v (Float) x = 0 (Float) k = 1 (Float) z = 0.495846391 (Float) y = -0.689988375 ``` Resolves rdar://68616528 / SR-13535.
This commit is contained in:
@@ -115,6 +115,8 @@ private:
|
||||
mutable FuncDecl *cachedPlusFn = nullptr;
|
||||
/// `AdditiveArithmetic.+=` declaration.
|
||||
mutable FuncDecl *cachedPlusEqualFn = nullptr;
|
||||
/// `AdditiveArithmetic.zero` declaration.
|
||||
mutable AccessorDecl *cachedZeroGetter = nullptr;
|
||||
|
||||
public:
|
||||
/// Construct an ADContext for the given module.
|
||||
@@ -201,6 +203,7 @@ public:
|
||||
|
||||
FuncDecl *getPlusDecl() const;
|
||||
FuncDecl *getPlusEqualDecl() const;
|
||||
AccessorDecl *getAdditiveArithmeticZeroGetter() const;
|
||||
|
||||
/// Cleans up all the internal state.
|
||||
void cleanUp();
|
||||
@@ -269,6 +272,10 @@ public:
|
||||
Diag<T...> diag, U &&... args);
|
||||
};
|
||||
|
||||
raw_ostream &getADDebugStream();
|
||||
SILLocation getValidLocation(SILValue v);
|
||||
SILLocation getValidLocation(SILInstruction *inst);
|
||||
|
||||
template <typename... T, typename... U>
|
||||
InFlightDiagnostic
|
||||
ADContext::emitNondifferentiabilityError(SILValue value,
|
||||
|
||||
@@ -51,29 +51,45 @@ class AdjointValueBase {
|
||||
/// The type of this value as if it were materialized as a SIL value.
|
||||
SILType type;
|
||||
|
||||
using DebugInfo = std::pair<SILDebugLocation, SILDebugVariable>;
|
||||
|
||||
/// The debug location and variable info associated with the original value.
|
||||
Optional<DebugInfo> debugInfo;
|
||||
|
||||
/// The underlying value.
|
||||
union Value {
|
||||
llvm::ArrayRef<AdjointValue> aggregate;
|
||||
unsigned numAggregateElements;
|
||||
SILValue concrete;
|
||||
Value(llvm::ArrayRef<AdjointValue> v) : aggregate(v) {}
|
||||
Value(unsigned numAggregateElements)
|
||||
: numAggregateElements(numAggregateElements) {}
|
||||
Value(SILValue v) : concrete(v) {}
|
||||
Value() {}
|
||||
} value;
|
||||
|
||||
// Begins tail-allocated aggregate elements, if
|
||||
// `kind == AdjointValueKind::Aggregate`.
|
||||
|
||||
explicit AdjointValueBase(SILType type,
|
||||
llvm::ArrayRef<AdjointValue> aggregate)
|
||||
: kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {}
|
||||
llvm::ArrayRef<AdjointValue> aggregate,
|
||||
Optional<DebugInfo> debugInfo)
|
||||
: kind(AdjointValueKind::Aggregate), type(type), debugInfo(debugInfo),
|
||||
value(aggregate.size()) {
|
||||
MutableArrayRef<AdjointValue> tailElements(
|
||||
reinterpret_cast<AdjointValue *>(this + 1), aggregate.size());
|
||||
std::uninitialized_copy(
|
||||
aggregate.begin(), aggregate.end(), tailElements.begin());
|
||||
}
|
||||
|
||||
explicit AdjointValueBase(SILValue v)
|
||||
: kind(AdjointValueKind::Concrete), type(v->getType()), value(v) {}
|
||||
explicit AdjointValueBase(SILValue v, Optional<DebugInfo> debugInfo)
|
||||
: kind(AdjointValueKind::Concrete), type(v->getType()),
|
||||
debugInfo(debugInfo), value(v) {}
|
||||
|
||||
explicit AdjointValueBase(SILType type)
|
||||
: kind(AdjointValueKind::Zero), type(type) {}
|
||||
explicit AdjointValueBase(SILType type, Optional<DebugInfo> debugInfo)
|
||||
: kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
|
||||
};
|
||||
|
||||
/// A symbolic adjoint value that is capable of representing zero value 0 and
|
||||
/// 1, in addition to a materialized SILValue. This is expected to be passed
|
||||
/// around by value in most cases, as it's two words long.
|
||||
/// A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
|
||||
/// thereof.
|
||||
class AdjointValue final {
|
||||
|
||||
private:
|
||||
@@ -85,26 +101,37 @@ public:
|
||||
AdjointValueBase *operator->() const { return base; }
|
||||
AdjointValueBase &operator*() const { return *base; }
|
||||
|
||||
static AdjointValue createConcrete(llvm::BumpPtrAllocator &allocator,
|
||||
SILValue value) {
|
||||
return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(value);
|
||||
using DebugInfo = AdjointValueBase::DebugInfo;
|
||||
|
||||
static AdjointValue createConcrete(
|
||||
llvm::BumpPtrAllocator &allocator, SILValue value,
|
||||
Optional<DebugInfo> debugInfo = None) {
|
||||
auto *buf = allocator.Allocate<AdjointValueBase>();
|
||||
return new (buf) AdjointValueBase(value, debugInfo);
|
||||
}
|
||||
|
||||
static AdjointValue createZero(llvm::BumpPtrAllocator &allocator,
|
||||
SILType type) {
|
||||
return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(type);
|
||||
static AdjointValue createZero(
|
||||
llvm::BumpPtrAllocator &allocator, SILType type,
|
||||
Optional<DebugInfo> debugInfo = None) {
|
||||
auto *buf = allocator.Allocate<AdjointValueBase>();
|
||||
return new (buf) AdjointValueBase(type, debugInfo);
|
||||
}
|
||||
|
||||
static AdjointValue createAggregate(llvm::BumpPtrAllocator &allocator,
|
||||
SILType type,
|
||||
llvm::ArrayRef<AdjointValue> aggregate) {
|
||||
return new (allocator.Allocate<AdjointValueBase>())
|
||||
AdjointValueBase(type, aggregate);
|
||||
static AdjointValue createAggregate(
|
||||
llvm::BumpPtrAllocator &allocator, SILType type,
|
||||
ArrayRef<AdjointValue> elements,
|
||||
Optional<DebugInfo> debugInfo = None) {
|
||||
AdjointValue *buf = reinterpret_cast<AdjointValue *>(allocator.Allocate(
|
||||
sizeof(AdjointValueBase) + elements.size() * sizeof(AdjointValue),
|
||||
alignof(AdjointValueBase)));
|
||||
return new (buf) AdjointValueBase(type, elements, debugInfo);
|
||||
}
|
||||
|
||||
AdjointValueKind getKind() const { return base->kind; }
|
||||
SILType getType() const { return base->type; }
|
||||
CanType getSwiftType() const { return getType().getASTType(); }
|
||||
Optional<DebugInfo> getDebugInfo() const { return base->debugInfo; }
|
||||
void setDebugInfo(DebugInfo debugInfo) const { base->debugInfo = debugInfo; }
|
||||
|
||||
NominalTypeDecl *getAnyNominal() const {
|
||||
return getSwiftType()->getAnyNominal();
|
||||
@@ -116,16 +143,18 @@ public:
|
||||
|
||||
unsigned getNumAggregateElements() const {
|
||||
assert(isAggregate());
|
||||
return base->value.aggregate.size();
|
||||
return base->value.numAggregateElements;
|
||||
}
|
||||
|
||||
AdjointValue getAggregateElement(unsigned i) const {
|
||||
assert(isAggregate());
|
||||
return base->value.aggregate[i];
|
||||
return getAggregateElements()[i];
|
||||
}
|
||||
|
||||
llvm::ArrayRef<AdjointValue> getAggregateElements() const {
|
||||
return base->value.aggregate;
|
||||
assert(isAggregate());
|
||||
return {
|
||||
reinterpret_cast<const AdjointValue *>(base + 1),
|
||||
getNumAggregateElements()};
|
||||
}
|
||||
|
||||
SILValue getConcreteValue() const {
|
||||
@@ -143,7 +172,7 @@ public:
|
||||
if (auto *decl =
|
||||
getType().getASTType()->getStructOrBoundGenericStruct()) {
|
||||
interleave(
|
||||
llvm::zip(decl->getStoredProperties(), base->value.aggregate),
|
||||
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
|
||||
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
|
||||
s << std::get<0>(elt)->getName() << ": ";
|
||||
std::get<1>(elt).print(s);
|
||||
@@ -151,7 +180,7 @@ public:
|
||||
[&s] { s << ", "; });
|
||||
} else if (getType().is<TupleType>()) {
|
||||
interleave(
|
||||
base->value.aggregate,
|
||||
getAggregateElements(),
|
||||
[&s](const AdjointValue &elt) { elt.print(s); },
|
||||
[&s] { s << ", "; });
|
||||
} else {
|
||||
|
||||
@@ -27,7 +27,9 @@
|
||||
#include "swift/SIL/TypeSubstCloner.h"
|
||||
#include "swift/SILOptimizer/Analysis/ArraySemantic.h"
|
||||
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
|
||||
#include "swift/SILOptimizer/Differentiation/ADContext.h"
|
||||
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
|
||||
#include "swift/SILOptimizer/Differentiation/TangentBuilder.h"
|
||||
|
||||
namespace swift {
|
||||
|
||||
@@ -142,6 +144,9 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Optional<std::pair<SILDebugLocation, SILDebugVariable>>
|
||||
findDebugLocationAndVariable(SILValue originalValue);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Diagnostic utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -190,12 +195,6 @@ SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
|
||||
void extractAllElements(SILValue value, SILBuilder &builder,
|
||||
SmallVectorImpl<SILValue> &results);
|
||||
|
||||
/// Emit a zero value into the given buffer access by calling
|
||||
/// `AdditiveArithmetic.zero`. The given type must conform to
|
||||
/// `AdditiveArithmetic`.
|
||||
void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
|
||||
SILValue bufferAccess, SILLocation loc);
|
||||
|
||||
/// Emit a `Builtin.Word` value that represents the given type's memory layout
|
||||
/// size.
|
||||
SILValue emitMemoryLayoutSize(
|
||||
|
||||
78
include/swift/SILOptimizer/Differentiation/TangentBuilder.h
Normal file
78
include/swift/SILOptimizer/Differentiation/TangentBuilder.h
Normal file
@@ -0,0 +1,78 @@
|
||||
//===--- TangentBuilder.h - Tangent SIL builder --------------*- C++ -*----===//
|
||||
//
|
||||
// This source file is part of the Swift.org open source project
|
||||
//
|
||||
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
||||
// Licensed under Apache License v2.0 with Runtime Library Exception
|
||||
//
|
||||
// See https://swift.org/LICENSE.txt for license information
|
||||
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines a helper class for emitting tangent code for automatic
|
||||
// differentiation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_TANGENTBUILDER_H
|
||||
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_TANGENTBUILDER_H
|
||||
|
||||
#include "swift/SIL/SILBuilder.h"
|
||||
|
||||
namespace swift {
|
||||
namespace autodiff {
|
||||
|
||||
class ADContext;
|
||||
|
||||
class TangentBuilder: public SILBuilder {
|
||||
private:
|
||||
ADContext &adContext;
|
||||
|
||||
public:
|
||||
TangentBuilder(SILFunction &fn, ADContext &adContext)
|
||||
: SILBuilder(fn), adContext(adContext) {}
|
||||
TangentBuilder(SILBasicBlock *bb, ADContext &adContext)
|
||||
: SILBuilder(bb), adContext(adContext) {}
|
||||
TangentBuilder(SILBasicBlock::iterator insertionPt, ADContext &adContext)
|
||||
: SILBuilder(insertionPt), adContext(adContext) {}
|
||||
TangentBuilder(SILBasicBlock *bb, SILBasicBlock::iterator insertionPt,
|
||||
ADContext &adContext)
|
||||
: SILBuilder(bb, insertionPt), adContext(adContext) {}
|
||||
|
||||
/// Emits an `AdditiveArithmetic.zero` into the given buffer. If it is not an
|
||||
/// initialization (`isInit`), a `destroy_addr` will be emitted on the buffer
|
||||
/// first. The buffer must have a type that conforms to `AdditiveArithmetic`
|
||||
/// or be a tuple thereof.
|
||||
void emitZeroIntoBuffer(SILLocation loc, SILValue buffer,
|
||||
IsInitialization_t isInit);
|
||||
|
||||
/// Emits an `AdditiveArithmetic.zero` of the given type. The type must be a
|
||||
/// loadable type, and must conform to `AddditiveArithmetic` or be a tuple
|
||||
/// thereof.
|
||||
SILValue emitZero(SILLocation loc, CanType type);
|
||||
|
||||
/// Emits an `AdditiveArithmetic.+=` for the given destination buffer and
|
||||
/// operand. The type of the buffer and the operand must conform to
|
||||
/// `AddditiveArithmetic` or be a tuple thereof. The operand will not be
|
||||
/// consumed.
|
||||
void emitInPlaceAdd(SILLocation loc, SILValue destinationBuffer,
|
||||
SILValue operand);
|
||||
|
||||
/// Emits an `AdditiveArithmetic.+` for the given operands. The type of the
|
||||
/// operands must conform to `AddditiveArithmetic` or be a tuple thereof. The
|
||||
/// operands will not be consumed.
|
||||
void emitAddIntoBuffer(SILLocation loc, SILValue destinationBuffer,
|
||||
SILValue lhsAddress, SILValue rhsAddress);
|
||||
|
||||
/// Emits an `AdditiveArithmetic.+` for the given operands. The type of the
|
||||
/// operands must be a loadable type, and must conform to
|
||||
/// `AddditiveArithmetic` or be a tuple thereof. The operands will not be
|
||||
/// consumed.
|
||||
SILValue emitAdd(SILLocation loc, SILValue lhs, SILValue rhs);
|
||||
};
|
||||
|
||||
} // end namespace autodiff
|
||||
} // end namespace swift
|
||||
|
||||
#endif /* SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_TANGENTBUILDER_H */
|
||||
@@ -38,6 +38,8 @@ class ArchetypeType;
|
||||
|
||||
namespace autodiff {
|
||||
|
||||
class ADContext;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Thunk helpers
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -107,7 +109,7 @@ std::pair<SILFunction *, SubstitutionMap>
|
||||
getOrCreateSubsetParametersThunkForDerivativeFunction(
|
||||
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
|
||||
AutoDiffDerivativeFunctionKind kind, AutoDiffConfig desiredConfig,
|
||||
AutoDiffConfig actualConfig);
|
||||
AutoDiffConfig actualConfig, ADContext &adContext);
|
||||
|
||||
/// Get or create a derivative function parameter index subset thunk from
|
||||
/// `actualIndices` to `desiredIndices` for the given associated function
|
||||
@@ -119,7 +121,8 @@ getOrCreateSubsetParametersThunkForLinearMap(
|
||||
SILOptFunctionBuilder &fb, SILFunction *assocFn,
|
||||
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
|
||||
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
|
||||
AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig);
|
||||
AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig,
|
||||
ADContext &adContext);
|
||||
|
||||
} // end namespace autodiff
|
||||
|
||||
|
||||
Reference in New Issue
Block a user