[AutoDiff] Handle materializing adjoints with non-differentiable fields (#67319)

This commit is contained in:
Kshitij Jain
2023-09-12 14:22:41 -07:00
committed by GitHub
parent c2ba21e334
commit d971f125d9
10 changed files with 921 additions and 98 deletions

View File

@@ -19,6 +19,8 @@
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H #define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H
#include "swift/AST/Decl.h" #include "swift/AST/Decl.h"
#include "swift/SIL/SILDebugVariable.h"
#include "swift/SIL/SILLocation.h"
#include "swift/SIL/SILValue.h" #include "swift/SIL/SILValue.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
@@ -38,10 +40,18 @@ enum AdjointValueKind {
/// A concrete SIL value. /// A concrete SIL value.
Concrete, Concrete,
/// A special adjoint, made up of 2 adjoints -- an aggregate base adjoint and
/// an element adjoint to add to one of its fields. This case exists to avoid
/// eager materialization of a base adjoint upon addition with one of its
/// fields.
AddElement,
}; };
class AdjointValue; class AdjointValue;
struct AddElementValue;
class AdjointValueBase { class AdjointValueBase {
friend class AdjointValue; friend class AdjointValue;
@@ -60,9 +70,13 @@ class AdjointValueBase {
union Value { union Value {
unsigned numAggregateElements; unsigned numAggregateElements;
SILValue concrete; SILValue concrete;
AddElementValue *addElementValue;
Value(unsigned numAggregateElements) Value(unsigned numAggregateElements)
: numAggregateElements(numAggregateElements) {} : numAggregateElements(numAggregateElements) {}
Value(SILValue v) : concrete(v) {} Value(SILValue v) : concrete(v) {}
Value(AddElementValue *addElementValue)
: addElementValue(addElementValue) {}
Value() {} Value() {}
} value; } value;
@@ -86,6 +100,11 @@ class AdjointValueBase {
explicit AdjointValueBase(SILType type, llvm::Optional<DebugInfo> debugInfo) explicit AdjointValueBase(SILType type, llvm::Optional<DebugInfo> debugInfo)
: kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {} : kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
explicit AdjointValueBase(SILType type, AddElementValue *addElementValue,
llvm::Optional<DebugInfo> debugInfo)
: kind(AdjointValueKind::AddElement), type(type), debugInfo(debugInfo),
value(addElementValue) {}
}; };
/// A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate /// A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
@@ -127,6 +146,14 @@ public:
return new (buf) AdjointValueBase(type, elements, debugInfo); return new (buf) AdjointValueBase(type, elements, debugInfo);
} }
static AdjointValue
createAddElement(llvm::BumpPtrAllocator &allocator, SILType type,
AddElementValue *addElementValue,
llvm::Optional<DebugInfo> debugInfo = llvm::None) {
auto *buf = allocator.Allocate<AdjointValueBase>();
return new (buf) AdjointValueBase(type, addElementValue, debugInfo);
}
AdjointValueKind getKind() const { return base->kind; } AdjointValueKind getKind() const { return base->kind; }
SILType getType() const { return base->type; } SILType getType() const { return base->type; }
CanType getSwiftType() const { return getType().getASTType(); } CanType getSwiftType() const { return getType().getASTType(); }
@@ -140,6 +167,9 @@ public:
bool isZero() const { return getKind() == AdjointValueKind::Zero; } bool isZero() const { return getKind() == AdjointValueKind::Zero; }
bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; } bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; }
bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; } bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; }
bool isAddElement() const {
return getKind() == AdjointValueKind::AddElement;
}
unsigned getNumAggregateElements() const { unsigned getNumAggregateElements() const {
assert(isAggregate()); assert(isAggregate());
@@ -162,41 +192,77 @@ public:
return base->value.concrete; return base->value.concrete;
} }
void print(llvm::raw_ostream &s) const { AddElementValue *getAddElementValue() const {
switch (getKind()) { assert(isAddElement());
case AdjointValueKind::Zero: return base->value.addElementValue;
s << "Zero[" << getType() << ']';
break;
case AdjointValueKind::Aggregate:
s << "Aggregate[" << getType() << "](";
if (auto *decl =
getType().getASTType()->getStructOrBoundGenericStruct()) {
interleave(
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
s << std::get<0>(elt)->getName() << ": ";
std::get<1>(elt).print(s);
},
[&s] { s << ", "; });
} else if (getType().is<TupleType>()) {
interleave(
getAggregateElements(),
[&s](const AdjointValue &elt) { elt.print(s); },
[&s] { s << ", "; });
} else {
llvm_unreachable("Invalid aggregate");
}
s << ')';
break;
case AdjointValueKind::Concrete:
s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
break;
}
} }
void print(llvm::raw_ostream &s) const;
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); }; SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
}; };
/// An abstraction that represents the field locator in
/// an `AddElement` adjoint kind. Depending on the aggregate
/// kind - tuple or struct, of the `baseAdjoint` in an
/// `AddElement` adjoint, the field locator may be an `unsigned int`
/// or a `VarDecl *`.
struct FieldLocator final {
FieldLocator(VarDecl *field) : inner(field) {}
FieldLocator(unsigned int index) : inner(index) {}
friend AddElementValue;
private:
bool isTupleFieldLocator() const {
return std::holds_alternative<unsigned int>(inner);
}
const static constexpr std::true_type TUPLE_FIELD_LOCATOR_TAG =
std::true_type{};
const static constexpr std::false_type STRUCT_FIELD_LOCATOR_TAG =
std::false_type{};
unsigned int getInner(std::true_type) const {
return std::get<unsigned int>(inner);
}
VarDecl *getInner(std::false_type) const {
return std::get<VarDecl *>(inner);
}
std::variant<unsigned int, VarDecl *> inner;
};
/// The underlying value for an `AddElement` adjoint.
struct AddElementValue final {
AdjointValue baseAdjoint;
AdjointValue eltToAdd;
FieldLocator fieldLocator;
AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
FieldLocator fieldLocator)
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd),
fieldLocator(fieldLocator) {
assert(baseAdjoint.getType().is<TupleType>() ||
baseAdjoint.getType().getStructOrBoundGenericStruct() != nullptr);
}
bool isTupleAdjoint() const { return fieldLocator.isTupleFieldLocator(); }
bool isStructAdjoint() const { return !isTupleAdjoint(); }
VarDecl *getFieldDecl() const {
assert(isStructAdjoint());
return this->fieldLocator.getInner(FieldLocator::STRUCT_FIELD_LOCATOR_TAG);
}
unsigned int getFieldIndex() const {
assert(isTupleAdjoint());
return this->fieldLocator.getInner(FieldLocator::TUPLE_FIELD_LOCATOR_TAG);
}
};
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const AdjointValue &adjVal) { const AdjointValue &adjVal) {
adjVal.print(os); adjVal.print(os);

View File

@@ -0,0 +1,70 @@
//===--- AdjointValue.h - Helper class for differentiation ----*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// AdjointValue - a symbolic representation for adjoint values enabling
// efficient differentiation by avoiding zero materialization.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "differentiation"
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
void swift::autodiff::AdjointValue::print(llvm::raw_ostream &s) const {
switch (getKind()) {
case AdjointValueKind::Zero:
s << "Zero[" << getType() << ']';
break;
case AdjointValueKind::Aggregate:
s << "Aggregate[" << getType() << "](";
if (auto *decl = getType().getASTType()->getStructOrBoundGenericStruct()) {
interleave(
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
s << std::get<0>(elt)->getName() << ": ";
std::get<1>(elt).print(s);
},
[&s] { s << ", "; });
} else if (getType().is<TupleType>()) {
interleave(
getAggregateElements(),
[&s](const AdjointValue &elt) { elt.print(s); }, [&s] { s << ", "; });
} else {
llvm_unreachable("Invalid aggregate");
}
s << ')';
break;
case AdjointValueKind::Concrete:
s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
break;
case AdjointValueKind::AddElement:
auto *addElementValue = getAddElementValue();
auto baseAdjoint = addElementValue->baseAdjoint;
auto eltToAdd = addElementValue->eltToAdd;
s << "AddElement[";
baseAdjoint.print(s);
s << ", Field(";
if (addElementValue->isTupleAdjoint()) {
s << addElementValue->getFieldIndex();
} else {
s << addElementValue->getFieldDecl()->getNameStr();
}
s << "), ";
eltToAdd.print(s);
s << "]";
break;
}
}

View File

@@ -1,5 +1,6 @@
target_sources(swiftSILOptimizer PRIVATE target_sources(swiftSILOptimizer PRIVATE
ADContext.cpp ADContext.cpp
AdjointValue.cpp
Common.cpp Common.cpp
DifferentiationInvoker.cpp DifferentiationInvoker.cpp
JVPCloner.cpp JVPCloner.cpp

View File

@@ -228,11 +228,12 @@ private:
auto zeroVal = emitZeroDirect(val.getSwiftType(), loc); auto zeroVal = emitZeroDirect(val.getSwiftType(), loc);
return zeroVal; return zeroVal;
} }
case AdjointValueKind::Aggregate:
llvm_unreachable(
"Tuples and structs are not supported in forward mode yet.");
case AdjointValueKind::Concrete: case AdjointValueKind::Concrete:
return val.getConcreteValue(); return val.getConcreteValue();
case AdjointValueKind::Aggregate:
case AdjointValueKind::AddElement:
llvm_unreachable(
"Tuples and structs are not supported in forward mode yet.");
} }
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715 llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
} }

View File

@@ -323,6 +323,15 @@ private:
return AdjointValue::createAggregate(allocator, remapType(type), elements); return AdjointValue::createAggregate(allocator, remapType(type), elements);
} }
AdjointValue makeAddElementAdjointValue(AdjointValue baseAdjoint,
AdjointValue eltToAdd,
FieldLocator fieldLocator) {
auto *addElementValue =
new AddElementValue(baseAdjoint, eltToAdd, fieldLocator);
return AdjointValue::createAddElement(allocator, baseAdjoint.getType(),
addElementValue);
}
//--------------------------------------------------------------------------// //--------------------------------------------------------------------------//
// Adjoint value materialization // Adjoint value materialization
//--------------------------------------------------------------------------// //--------------------------------------------------------------------------//
@@ -355,6 +364,19 @@ private:
case AdjointValueKind::Concrete: case AdjointValueKind::Concrete:
result = val.getConcreteValue(); result = val.getConcreteValue();
break; break;
case AdjointValueKind::AddElement: {
auto adjointSILType = val.getAddElementValue()->baseAdjoint.getType();
auto *baseAdjAlloc = builder.createAllocStack(loc, adjointSILType);
materializeAdjointIndirect(val, baseAdjAlloc, loc);
auto baseAdjConcrete = recordTemporary(builder.emitLoadValueOperation(
loc, baseAdjAlloc, LoadOwnershipQualifier::Take));
builder.createDeallocStack(loc, baseAdjAlloc);
result = baseAdjConcrete;
break;
}
} }
if (auto debugInfo = val.getDebugInfo()) if (auto debugInfo = val.getDebugInfo())
builder.createDebugValue( builder.createDebugValue(
@@ -400,12 +422,62 @@ private:
} }
/// If adjoint value is concrete, it is already materialized. Store it in /// If adjoint value is concrete, it is already materialized. Store it in
/// the destination address. /// the destination address.
case AdjointValueKind::Concrete: case AdjointValueKind::Concrete: {
auto concreteVal = val.getConcreteValue(); auto concreteVal = val.getConcreteValue();
builder.emitStoreValueOperation(loc, concreteVal, destAddress, auto copyOfConcreteVal = builder.emitCopyValueOperation(loc, concreteVal);
builder.emitStoreValueOperation(loc, copyOfConcreteVal, destAddress,
StoreOwnershipQualifier::Init); StoreOwnershipQualifier::Init);
break; break;
} }
case AdjointValueKind::AddElement: {
auto baseAdjoint = val;
auto baseAdjointType = baseAdjoint.getType();
// Current adjoint may be made up of layers of `AddElement` adjoints.
// We can iteratively gather the list of elements to add instead of making
// recursive calls to `materializeAdjointIndirect`.
SmallVector<AddElementValue *, 4> addEltAdjValues;
do {
auto addElementValue = baseAdjoint.getAddElementValue();
addEltAdjValues.push_back(addElementValue);
baseAdjoint = addElementValue->baseAdjoint;
assert(baseAdjointType == baseAdjoint.getType());
} while (baseAdjoint.getKind() == AdjointValueKind::AddElement);
materializeAdjointIndirect(baseAdjoint, destAddress, loc);
for (auto *addElementValue : addEltAdjValues) {
auto eltToAdd = addElementValue->eltToAdd;
SILValue baseAdjEltAddr;
if (baseAdjoint.getType().is<TupleType>()) {
baseAdjEltAddr = builder.createTupleElementAddr(
loc, destAddress, addElementValue->getFieldIndex());
} else {
baseAdjEltAddr = builder.createStructElementAddr(
loc, destAddress, addElementValue->getFieldDecl());
}
auto eltToAddMaterialized = materializeAdjointDirect(eltToAdd, loc);
// Copy `eltToAddMaterialized` so we have a value with owned ownership
// semantics, required for using `eltToAddMaterialized` in a `store`
// instruction.
auto eltToAddMaterializedCopy =
builder.emitCopyValueOperation(loc, eltToAddMaterialized);
auto *eltToAddAlloc = builder.createAllocStack(loc, eltToAdd.getType());
builder.emitStoreValueOperation(loc, eltToAddMaterializedCopy,
eltToAddAlloc,
StoreOwnershipQualifier::Init);
builder.emitInPlaceAdd(loc, baseAdjEltAddr, eltToAddAlloc);
builder.createDestroyAddr(loc, eltToAddAlloc);
builder.createDeallocStack(loc, eltToAddAlloc);
}
break;
}
}
} }
//--------------------------------------------------------------------------// //--------------------------------------------------------------------------//
@@ -1095,6 +1167,10 @@ public:
"Aggregate adjoint values should not occur for `struct` " "Aggregate adjoint values should not occur for `struct` "
"instructions"); "instructions");
} }
case AdjointValueKind::AddElement: {
llvm_unreachable(
"Adjoint of `StructInst` cannot be of kind `AddElement`");
}
} }
break; break;
} }
@@ -1150,41 +1226,29 @@ public:
auto structTy = remapType(sei->getOperand()->getType()).getASTType(); auto structTy = remapType(sei->getOperand()->getType()).getASTType();
auto *tanField = auto *tanField =
getTangentStoredProperty(getContext(), sei, structTy, getInvoker()); getTangentStoredProperty(getContext(), sei, structTy, getInvoker());
assert(tanField && "Invalid projections should have been diagnosed");
// Check the `struct_extract` operand's value tangent category. // Check the `struct_extract` operand's value tangent category.
switch (getTangentValueCategory(sei->getOperand())) { switch (getTangentValueCategory(sei->getOperand())) {
case SILValueCategory::Object: { case SILValueCategory::Object: {
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
auto *tangentVectorDecl =
tangentVectorTy->getStructOrBoundGenericStruct();
assert(tangentVectorDecl);
auto tangentVectorSILTy = auto tangentVectorSILTy =
SILType::getPrimitiveObjectType(tangentVectorTy); SILType::getPrimitiveObjectType(tangentVectorTy);
assert(tanField && "Invalid projections should have been diagnosed"); auto eltAdj = getAdjointValue(bb, sei);
// Accumulate adjoint for the `struct_extract` operand.
auto av = getAdjointValue(bb, sei); switch (eltAdj.getKind()) {
switch (av.getKind()) { case AdjointValueKind::Zero: {
case AdjointValueKind::Zero:
addAdjointValue(bb, sei->getOperand(), addAdjointValue(bb, sei->getOperand(),
makeZeroAdjointValue(tangentVectorSILTy), loc); makeZeroAdjointValue(tangentVectorSILTy), loc);
break; break;
}
case AdjointValueKind::Aggregate:
case AdjointValueKind::Concrete: case AdjointValueKind::Concrete:
case AdjointValueKind::Aggregate: { case AdjointValueKind::AddElement: {
SmallVector<AdjointValue, 8> eltVals; auto baseAdj = makeZeroAdjointValue(tangentVectorSILTy);
for (auto *field : tangentVectorDecl->getStoredProperties()) {
if (field == tanField) {
eltVals.push_back(av);
} else {
auto substMap = tangentVectorTy->getMemberSubstitutionMap(
field->getModuleContext(), field);
auto fieldTy = field->getInterfaceType().subst(substMap);
auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
assert(fieldSILTy.isObject());
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
}
}
addAdjointValue(bb, sei->getOperand(), addAdjointValue(bb, sei->getOperand(),
makeAggregateAdjointValue(tangentVectorSILTy, eltVals), makeAddElementAdjointValue(baseAdj, eltAdj, tanField),
loc); loc);
break;
} }
} }
break; break;
@@ -1320,7 +1384,7 @@ public:
} }
break; break;
} }
case AdjointValueKind::Aggregate: case AdjointValueKind::Aggregate: {
unsigned adjIndex = 0; unsigned adjIndex = 0;
for (auto i : range(ti->getElements().size())) { for (auto i : range(ti->getElements().size())) {
if (!getTangentSpace(ti->getElement(i)->getType().getASTType())) if (!getTangentSpace(ti->getElement(i)->getType().getASTType()))
@@ -1330,6 +1394,11 @@ public:
} }
break; break;
} }
case AdjointValueKind::AddElement: {
llvm_unreachable(
"Adjoint of `TupleInst` cannot be of kind `AddElement`");
}
}
break; break;
} }
case SILValueCategory::Address: { case SILValueCategory::Address: {
@@ -1358,42 +1427,42 @@ public:
/// index corresponding to n /// index corresponding to n
void visitTupleExtractInst(TupleExtractInst *tei) { void visitTupleExtractInst(TupleExtractInst *tei) {
auto *bb = tei->getParent(); auto *bb = tei->getParent();
auto loc = tei->getLoc();
auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType()); auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
auto av = getAdjointValue(bb, tei); auto eltAdj = getAdjointValue(bb, tei);
switch (av.getKind()) { switch (eltAdj.getKind()) {
case AdjointValueKind::Zero: case AdjointValueKind::Zero: {
addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy), addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy),
tei->getLoc()); loc);
break; break;
}
case AdjointValueKind::Aggregate: case AdjointValueKind::Aggregate:
case AdjointValueKind::Concrete: { case AdjointValueKind::Concrete:
case AdjointValueKind::AddElement: {
auto tupleTy = tei->getTupleType(); auto tupleTy = tei->getTupleType();
auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>(); auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>();
if (!tupleTanTupleTy) { if (!tupleTanTupleTy) {
addAdjointValue(bb, tei->getOperand(), av, tei->getLoc()); addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
break; break;
} }
SmallVector<AdjointValue, 8> elements;
unsigned adjIdx = 0; unsigned elements = 0;
for (unsigned i : range(tupleTy->getNumElements())) { for (unsigned i : range(tupleTy->getNumElements())) {
if (!getTangentSpace( if (!getTangentSpace(
tupleTy->getElement(i).getType()->getCanonicalType())) tupleTy->getElement(i).getType()->getCanonicalType()))
continue; continue;
if (tei->getFieldIndex() == i) elements++;
elements.push_back(av);
else
elements.push_back(makeZeroAdjointValue(
getRemappedTangentType(SILType::getPrimitiveObjectType(
tupleTanTupleTy->getElementType(adjIdx++)
->getCanonicalType()))));
} }
if (elements.size() == 1) {
addAdjointValue(bb, tei->getOperand(), elements.front(), tei->getLoc()); if (elements == 1) {
break; addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
} else {
auto baseAdj = makeZeroAdjointValue(tupleTanTy);
addAdjointValue(
bb, tei->getOperand(),
makeAddElementAdjointValue(baseAdj, eltAdj, tei->getFieldIndex()),
loc);
} }
addAdjointValue(bb, tei->getOperand(),
makeAggregateAdjointValue(tupleTanTy, elements),
tei->getLoc());
break; break;
} }
} }
@@ -2719,7 +2788,12 @@ bool PullbackCloner::Implementation::runForSemanticMemberGetter() {
addAdjointValue(origEntry, origSelf, addAdjointValue(origEntry, origSelf,
makeAggregateAdjointValue(tangentVectorSILTy, eltVals), makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
pbLoc); pbLoc);
break;
} }
case AdjointValueKind::AddElement:
llvm_unreachable("Adjoint of an aggregate type's field cannot be of kind "
"`AddElement`");
} }
break; break;
} }
@@ -2959,7 +3033,7 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect( AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
AdjointValue lhs, AdjointValue rhs, SILLocation loc) { AdjointValue lhs, AdjointValue rhs, SILLocation loc) {
LLVM_DEBUG(getADDebugStream() << "Materializing adjoint directly.\nLHS: " LLVM_DEBUG(getADDebugStream() << "Accumulating adjoint directly.\nLHS: "
<< lhs << "\nRHS: " << rhs << '\n'); << lhs << "\nRHS: " << rhs << '\n');
switch (lhs.getKind()) { switch (lhs.getKind()) {
// x // x
@@ -2976,7 +3050,7 @@ AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
case AdjointValueKind::Zero: case AdjointValueKind::Zero:
return lhs; return lhs;
// x + (y, z) => (x.0 + y, x.1 + z) // x + (y, z) => (x.0 + y, x.1 + z)
case AdjointValueKind::Aggregate: case AdjointValueKind::Aggregate: {
SmallVector<AdjointValue, 8> newElements; SmallVector<AdjointValue, 8> newElements;
auto lhsTy = lhsVal->getType().getASTType(); auto lhsTy = lhsVal->getType().getASTType();
auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal); auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal);
@@ -3004,13 +3078,24 @@ AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
} }
return makeAggregateAdjointValue(lhsVal->getType(), newElements); return makeAggregateAdjointValue(lhsVal->getType(), newElements);
} }
// x + (baseAdjoint, index, eltToAdd) => (x+baseAdjoint, index, eltToAdd)
case AdjointValueKind::AddElement: {
auto *addElementValue = rhs.getAddElementValue();
auto baseAdjoint = addElementValue->baseAdjoint;
auto eltToAdd = addElementValue->eltToAdd;
auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
addElementValue->fieldLocator);
}
}
} }
// 0 // 0
case AdjointValueKind::Zero: case AdjointValueKind::Zero:
// 0 + x => x // 0 + x => x
return rhs; return rhs;
// (x, y) // (x, y)
case AdjointValueKind::Aggregate: case AdjointValueKind::Aggregate: {
switch (rhs.getKind()) { switch (rhs.getKind()) {
// (x, y) + z => (z.0 + x, z.1 + y) // (x, y) + z => (z.0 + x, z.1 + y)
case AdjointValueKind::Concrete: case AdjointValueKind::Concrete:
@@ -3026,7 +3111,51 @@ AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
lhs.getAggregateElement(i), rhs.getAggregateElement(i), loc)); lhs.getAggregateElement(i), rhs.getAggregateElement(i), loc));
return makeAggregateAdjointValue(lhs.getType(), newElements); return makeAggregateAdjointValue(lhs.getType(), newElements);
} }
// (x.0, ..., x.n) + (baseAdjoint, index, eltToAdd) => (x + baseAdjoint,
// index, eltToAdd)
case AdjointValueKind::AddElement: {
auto *addElementValue = rhs.getAddElementValue();
auto baseAdjoint = addElementValue->baseAdjoint;
auto eltToAdd = addElementValue->eltToAdd;
auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
addElementValue->fieldLocator);
} }
}
}
// (baseAdjoint, index, eltToAdd)
case AdjointValueKind::AddElement: {
switch (rhs.getKind()) {
case AdjointValueKind::Zero:
return lhs;
// (baseAdjoint, index, eltToAdd) + x => (x + baseAdjoint, index, eltToAdd)
case AdjointValueKind::Concrete:
// (baseAdjoint, index, eltToAdd) + (x.0, ..., x.n) => (x + baseAdjoint,
// index, eltToAdd)
case AdjointValueKind::Aggregate:
return accumulateAdjointsDirect(rhs, lhs, loc);
// (baseAdjoint1, index1, eltToAdd1) + (baseAdjoint2, index2, eltToAdd2)
// => ((baseAdjoint1 + baseAdjoint2, index1, eltToAdd1), index2, eltToAdd2)
case AdjointValueKind::AddElement: {
auto *addElementValueLhs = lhs.getAddElementValue();
auto baseAdjointLhs = addElementValueLhs->baseAdjoint;
auto eltToAddLhs = addElementValueLhs->eltToAdd;
auto *addElementValueRhs = rhs.getAddElementValue();
auto baseAdjointRhs = addElementValueRhs->baseAdjoint;
auto eltToAddRhs = addElementValueRhs->eltToAdd;
auto sumOfBaseAdjoints =
accumulateAdjointsDirect(baseAdjointLhs, baseAdjointRhs, loc);
auto newBaseAdjoint = makeAddElementAdjointValue(
sumOfBaseAdjoints, eltToAddLhs, addElementValueLhs->fieldLocator);
return makeAddElementAdjointValue(newBaseAdjoint, eltToAddRhs,
addElementValueRhs->fieldLocator);
}
}
}
} }
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715 llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
} }

View File

@@ -0,0 +1,186 @@
// Pullback generation tests written in SIL for features
// that may not be directly supported by the Swift frontend
// RUN: %target-sil-opt --differentiation -emit-sorted-sil %s 2>&1 | %FileCheck %s
//===----------------------------------------------------------------------===//
// Pullback generation - `struct_extract`
// - Input to pullback has non-owned ownership semantics which requires copying
// this value to stack before lifetime-ending uses.
//===----------------------------------------------------------------------===//
sil_stage raw
import Builtin
import Swift
import SwiftShims
import _Differentiation
struct X {
@_hasStorage var a: Float { get set }
@_hasStorage var b: String { get set }
init(a: Float, b: String)
}
extension X : Differentiable, Equatable, AdditiveArithmetic {
public typealias TangentVector = X
mutating func move(by offset: X)
public static var zero: X { get }
public static func + (lhs: X, rhs: X) -> X
public static func - (lhs: X, rhs: X) -> X
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X, _ b: X) -> Bool
}
struct Y {
@_hasStorage var a: X { get set }
@_hasStorage var b: String { get set }
init(a: X, b: String)
}
extension Y : Differentiable, Equatable, AdditiveArithmetic {
public typealias TangentVector = Y
mutating func move(by offset: Y)
public static var zero: Y { get }
public static func + (lhs: Y, rhs: Y) -> Y
public static func - (lhs: Y, rhs: Y) -> Y
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Y, _ b: Y) -> Bool
}
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
}
sil hidden [ossa] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
bb0(%0 : @guaranteed $Y):
%1 = struct_extract %0 : $Y, #Y.a
%2 = copy_value %1 : $X
return %2 : $X
}
// CHECK-LABEL: sil private [ossa] @$function_with_struct_extract_1TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned Y {
// CHECK: bb0(%0 : @guaranteed $X):
// CHECK: %1 = alloc_stack $Y
// CHECK: %2 = witness_method $Y, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %3 = metatype $@thick Y.Type
// CHECK: %4 = apply %2<Y>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %5 = struct_element_addr %1 : $*Y, #Y.a
// Since input parameter $0 has non-owned ownership semantics, it
// needs to be copied before a lifetime-ending use.
// CHECK: %6 = copy_value %0 : $X
// CHECK: %7 = alloc_stack $X
// CHECK: store %6 to [init] %7 : $*X
// CHECK: %9 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %10 = metatype $@thick X.Type
// CHECK: %11 = apply %9<X>(%5, %7, %10) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %7 : $*X
// CHECK: dealloc_stack %7 : $*X
// CHECK: %14 = load [take] %1 : $*Y
// CHECK: dealloc_stack %1 : $*Y
// CHECK: %16 = copy_value %14 : $Y
// CHECK: destroy_value %14 : $Y
// CHECK: return %16 : $Y
// CHECK: } // end sil function '$function_with_struct_extract_1TJpSpSr'
//===----------------------------------------------------------------------===//
// Pullback generation - `tuple_extract`
// - Tuples as differentiable input arguments are not supported yet, so creating
// a basic test in SIL instead.
//===----------------------------------------------------------------------===//
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
}
sil hidden [ossa] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
bb0(%0 : $(Float, Float)):
%1 = tuple_extract %0 : $(Float, Float), 0
return %1 : $Float
}
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_1TJpSpSr : $@convention(thin) (Float) -> (Float, Float) {
// CHECK: bb0(%0 : $Float):
// CHECK: %1 = alloc_stack $(Float, Float)
// CHECK: %2 = tuple_element_addr %1 : $*(Float, Float), 0
// CHECK: %3 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %4 = metatype $@thick Float.Type
// CHECK: %5 = apply %3<Float>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %6 = tuple_element_addr %1 : $*(Float, Float), 1
// CHECK: %7 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %8 = metatype $@thick Float.Type
// CHECK: %9 = apply %7<Float>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %10 = tuple_element_addr %1 : $*(Float, Float), 0
// CHECK: %11 = alloc_stack $Float
// CHECK: store %0 to [trivial] %11 : $*Float
// CHECK: %13 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %14 = metatype $@thick Float.Type
// CHECK: %15 = apply %13<Float>(%10, %11, %14) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %11 : $*Float
// CHECK: dealloc_stack %11 : $*Float
// CHECK: %18 = load [trivial] %1 : $*(Float, Float)
// CHECK: dealloc_stack %1 : $*(Float, Float)
// CHECK: return %18 : $(Float, Float)
// CHECK: } // end sil function 'function_with_tuple_extract_1TJpSpSr'
//===----------------------------------------------------------------------===//
// Pullback generation - Inner values of concrete adjoints must be copied
// during direct materialization.
// - If the input to pullback BB has non-owned ownership semantics we cannot
// perform a lifetime-ending operation on it.
// - If the input to the pullback BB is an owned, non-trivial value we must
// copy it or there will be a double consume when all owned parameters are
// destroyed at the end of the basic block.
//===----------------------------------------------------------------------===//
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
}
sil hidden [ossa] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
bb0(%0 : @guaranteed $(X, X)):
%1 = tuple_extract %0 : $(X, X), 0
%2 = copy_value %1: $X
return %2 : $X
}
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_2TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned (X, X) {
// CHECK: bb0(%0 : @guaranteed $X):
// CHECK: %1 = alloc_stack $(X, X)
// CHECK: %2 = tuple_element_addr %1 : $*(X, X), 0
// CHECK: %3 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %4 = metatype $@thick X.Type
// CHECK: %5 = apply %3<X>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %6 = tuple_element_addr %1 : $*(X, X), 1
// CHECK: %7 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %8 = metatype $@thick X.Type
// CHECK: %9 = apply %7<X>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %10 = tuple_element_addr %1 : $*(X, X), 0
// CHECK: %11 = copy_value %0 : $X
// CHECK: %12 = alloc_stack $X
// CHECK: store %11 to [init] %12 : $*X
// CHECK: %14 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %15 = metatype $@thick X.Type
// CHECK: %16 = apply %14<X>(%10, %12, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %12 : $*X
// CHECK: dealloc_stack %12 : $*X
// CHECK: %19 = load [take] %1 : $*(X, X)
// CHECK: dealloc_stack %1 : $*(X, X)
// CHECK: %21 = copy_value %19 : $(X, X)
// CHECK: destroy_value %19 : $(X, X)
// CHECK: return %21 : $(X, X)
// CHECK: } // end sil function 'function_with_tuple_extract_2TJpSpSr'
//===----------------------------------------------------------------------===//
// Pullback generation - `tuple_extract`
// - Adjoint of extracted element can be `AddElement`
// - Just need to make sure that we are able to generate a pullback
//===----------------------------------------------------------------------===//
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
}
sil hidden [ossa] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
bb0(%0 : $((Float, Float), Float)):
%1 = tuple_extract %0 : $((Float, Float), Float), 0
%2 = tuple_extract %1 : $(Float, Float), 0
return %2 : $Float
}
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_3TJpSpSr : $@convention(thin) (Float) -> ((Float, Float), Float) {

View File

@@ -0,0 +1,200 @@
// Pullback generation tests written in Swift
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm --sil-print-after=differentiation %s 2>&1 | %FileCheck %s
import _Differentiation
//===----------------------------------------------------------------------===//
// Pullback generation - `struct_extract`
// - Operand is piecewise materializable
//===----------------------------------------------------------------------===//
struct PiecewiseMaterializable: Differentiable, Equatable, AdditiveArithmetic {
public typealias TangentVector = Self
var a: Float
var b: Double
}
@differentiable(reverse)
func f1(v: PiecewiseMaterializable) -> Float {
v.a
}
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f11vSfAA23PiecewiseMaterializableV_tFTJpSpSr : $@convention(thin) (Float) -> PiecewiseMaterializable {
// CHECK: bb0(%0 : $Float):
// CHECK: %1 = alloc_stack $PiecewiseMaterializable
// CHECK: %2 = witness_method $PiecewiseMaterializable, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %3 = metatype $@thick PiecewiseMaterializable.Type
// CHECK: %4 = apply %2<PiecewiseMaterializable>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %5 = struct_element_addr %1 : $*PiecewiseMaterializable, #PiecewiseMaterializable.a
// CHECK: %6 = alloc_stack $Float
// CHECK: store %0 to [trivial] %6 : $*Float
// CHECK: %8 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %9 = metatype $@thick Float.Type
// CHECK: %10 = apply %8<Float>(%5, %6, %9) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %6 : $*Float
// CHECK: dealloc_stack %6 : $*Float
// CHECK: %13 = load [trivial] %1 : $*PiecewiseMaterializable
// CHECK: dealloc_stack %1 : $*PiecewiseMaterializable
// CHECK: debug_value %13 : $PiecewiseMaterializable, let, name "v", argno 1
// CHECK: return %13 : $PiecewiseMaterializable
// CHECK: } // end sil function '$s19pullback_generation2f11vSfAA23PiecewiseMaterializableV_tFTJpSpSr'
//===----------------------------------------------------------------------===//
// Pullback generation - `struct_extract`
// - Operand is non-piecewise materializable
//===----------------------------------------------------------------------===//
struct NonPiecewiseMaterializable {
var a: Float
var b: String
}
extension NonPiecewiseMaterializable: Differentiable, Equatable, AdditiveArithmetic {
public typealias TangentVector = Self
mutating func move(by offset: TangentVector) {fatalError()}
public static var zero: Self {fatalError()}
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
}
@differentiable(reverse)
func f2(v: NonPiecewiseMaterializable) -> Float {
v.a
}
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f21vSfAA26NonPiecewiseMaterializableV_tFTJpSpSr : $@convention(thin) (Float) -> @owned NonPiecewiseMaterializable {
// CHECK: bb0(%0 : $Float):
// CHECK: %1 = alloc_stack $NonPiecewiseMaterializable
// CHECK: %2 = witness_method $NonPiecewiseMaterializable, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %3 = metatype $@thick NonPiecewiseMaterializable.Type
// CHECK: %4 = apply %2<NonPiecewiseMaterializable>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %5 = struct_element_addr %1 : $*NonPiecewiseMaterializable, #NonPiecewiseMaterializable.a
// CHECK: %6 = alloc_stack $Float
// CHECK: store %0 to [trivial] %6 : $*Float
// CHECK: %8 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %9 = metatype $@thick Float.Type
// CHECK: %10 = apply %8<Float>(%5, %6, %9) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %6 : $*Float
// CHECK: dealloc_stack %6 : $*Float
// CHECK: %13 = load [take] %1 : $*NonPiecewiseMaterializable
// CHECK: dealloc_stack %1 : $*NonPiecewiseMaterializable
// CHECK: debug_value %13 : $NonPiecewiseMaterializable, let, name "v", argno 1
// CHECK: %16 = copy_value %13 : $NonPiecewiseMaterializable
// CHECK: destroy_value %13 : $NonPiecewiseMaterializable
// CHECK: return %16 : $NonPiecewiseMaterializable
// CHECK: } // end sil function '$s19pullback_generation2f21vSfAA26NonPiecewiseMaterializableV_tFTJpSpSr'
//===----------------------------------------------------------------------===//
// Pullback generation - `struct_extract`
// - Operand is non-piecewise materializable with an aggregate, piecewise
// materializable, differentiable field
//===----------------------------------------------------------------------===//
struct NonPiecewiseMaterializableWithAggDifferentiableField {
var a: PiecewiseMaterializable
var b: String
}
extension NonPiecewiseMaterializableWithAggDifferentiableField: Differentiable, Equatable, AdditiveArithmetic {
public typealias TangentVector = Self
mutating func move(by offset: TangentVector) {fatalError()}
public static var zero: Self {fatalError()}
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
}
@differentiable(reverse)
func f3(v: NonPiecewiseMaterializableWithAggDifferentiableField) -> PiecewiseMaterializable {
v.a
}
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f31vAA23PiecewiseMaterializableVAA03NondE26WithAggDifferentiableFieldV_tFTJpSpSr : $@convention(thin) (PiecewiseMaterializable) -> @owned NonPiecewiseMaterializableWithAggDifferentiableField {
// CHECK: bb0(%0 : $PiecewiseMaterializable):
// CHECK: %1 = alloc_stack $NonPiecewiseMaterializableWithAggDifferentiableField
// CHECK: %2 = witness_method $NonPiecewiseMaterializableWithAggDifferentiableField, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %3 = metatype $@thick NonPiecewiseMaterializableWithAggDifferentiableField.Type
// CHECK: %4 = apply %2<NonPiecewiseMaterializableWithAggDifferentiableField>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %5 = struct_element_addr %1 : $*NonPiecewiseMaterializableWithAggDifferentiableField, #NonPiecewiseMaterializableWithAggDifferentiableField.a
// CHECK: %6 = alloc_stack $PiecewiseMaterializable
// CHECK: store %0 to [trivial] %6 : $*PiecewiseMaterializable
// CHECK: %8 = witness_method $PiecewiseMaterializable, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %9 = metatype $@thick PiecewiseMaterializable.Type
// CHECK: %10 = apply %8<PiecewiseMaterializable>(%5, %6, %9) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %6 : $*PiecewiseMaterializable
// CHECK: dealloc_stack %6 : $*PiecewiseMaterializable
// CHECK: %13 = load [take] %1 : $*NonPiecewiseMaterializableWithAggDifferentiableField
// CHECK: dealloc_stack %1 : $*NonPiecewiseMaterializableWithAggDifferentiableField
// CHECK: debug_value %13 : $NonPiecewiseMaterializableWithAggDifferentiableField, let, name "v", argno 1
// CHECK: %16 = copy_value %13 : $NonPiecewiseMaterializableWithAggDifferentiableField
// CHECK: destroy_value %13 : $NonPiecewiseMaterializableWithAggDifferentiableField
// CHECK: return %16 : $NonPiecewiseMaterializableWithAggDifferentiableField
// CHECK: } // end sil function '$s19pullback_generation2f31vAA23PiecewiseMaterializableVAA03NondE26WithAggDifferentiableFieldV_tFTJpSpSr'
//===----------------------------------------------------------------------===//
// Pullback generation - `struct_extract`
// - Adjoint of extracted element can be `AddElement`
// - Just need to make sure that we are able to generate a pullback for B.x's
// getter
//===----------------------------------------------------------------------===//
struct A: Differentiable {
public var x: Float
}
struct B: Differentiable {
var y: A
public init(a: A) {
self.y = a
}
@differentiable(reverse)
public var x: Float {
get { return self.y.x }
}
}
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation1BV1xSfvgTJpSpSr : $@convention(thin) (Float) -> B.TangentVector {
//===----------------------------------------------------------------------===//
// Pullback generation - Inner values of concrete adjoints must be copied
// during indirect materialization
//===----------------------------------------------------------------------===//
struct NonTrivial {
var x: Float
var y: String
}
extension NonTrivial: Differentiable, Equatable, AdditiveArithmetic {
public typealias TangentVector = Self
mutating func move(by offset: TangentVector) {fatalError()}
public static var zero: Self {fatalError()}
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
}
@differentiable(reverse)
func f4(a: NonTrivial) -> Float {
var sum: Float = 0
for _ in 0..<1 {
sum += a.x
}
return sum
}
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial {
// CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
// CHECK: %88 = alloc_stack $NonTrivial
// Non-trivial value must be copied or there will be a
// double consume when all owned parameters are destroyed
// at the end of the basic block.
// CHECK: %89 = copy_value %67 : $NonTrivial
// CHECK: store %89 to [init] %88 : $*NonTrivial
// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x
// CHECK: %92 = alloc_stack $Float
// CHECK: store %86 to [trivial] %92 : $*Float
// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %95 = metatype $@thick Float.Type
// CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_value %67 : $NonTrivial

View File

@@ -0,0 +1,141 @@
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm --sil-print-after=differentiation -Xllvm --debug-only=differentiation %s 2>&1 | %FileCheck %s
import _Differentiation
//===----------------------------------------------------------------------===//
// Pullback generation - `struct_extract`
// - Nested AddElement adjoint kind is created due to the extraction of same
// field multiple times
//===----------------------------------------------------------------------===//
struct SmallTestModel : Differentiable {
public var stored1: Float = 3.0
public var stored2: Float = 3.0
public var stored3: Float = 3.0
}
@differentiable(reverse)
func multipleExtractionOfSameField(_ model: SmallTestModel) -> Float{
return model.stored1 + model.stored1 + model.stored1
}
// CHECK-LABEL: [AD] Accumulating adjoint directly.
// CHECK: LHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored1), Concrete[$Float]((%5, **%6**) = destructure_tuple %3 : $(Float, Float)
// CHECK: )]
// CHECK: RHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored1), Concrete[$Float]((%9, **%10**) = destructure_tuple %7 : $(Float, Float)
// CHECK: )]
// CHECK-LABEL: [AD] Accumulating adjoint directly.
// CHECK: LHS: AddElement[AddElement[Zero[$SmallTestModel.TangentVector], Field(stored1), Concrete[$Float]((%5, **%6**) = destructure_tuple %3 : $(Float, Float)
// CHECK: )], Field(stored1), Concrete[$Float]((%9, **%10**) = destructure_tuple %7 : $(Float, Float)
// CHECK: )]
// CHECK: RHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored1), Concrete[$Float]((**%9**, %10) = destructure_tuple %7 : $(Float, Float)
// CHECK: )]
// CHECK-LABEL: sil private [ossa] @$s46pullback_generation_nested_addelement_adjoints29multipleExtractionOfSameFieldySfAA14SmallTestModelVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> SmallTestModel.TangentVector {
// CHECK: bb0(%0 : $Float, %1 : @owned $@callee_guaranteed (Float) -> (Float, Float), %2 : @owned $@callee_guaranteed (Float) -> (Float, Float)):
// CHECK: %3 = apply %2(%0) : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: destroy_value %2 : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: (%5, %6) = destructure_tuple %3 : $(Float, Float)
// CHECK: %7 = apply %1(%5) : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: destroy_value %1 : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: (%9, %10) = destructure_tuple %7 : $(Float, Float)
// CHECK: %11 = alloc_stack $SmallTestModel.TangentVector
// CHECK: %12 = witness_method $SmallTestModel.TangentVector, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %13 = metatype $@thick SmallTestModel.TangentVector.Type
// CHECK: %14 = apply %12<SmallTestModel.TangentVector>(%11, %13) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %15 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored1
// CHECK: %16 = alloc_stack $Float
// CHECK: store %9 to [trivial] %16 : $*Float
// CHECK: %18 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %19 = metatype $@thick Float.Type
// CHECK: %20 = apply %18<Float>(%15, %16, %19) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %16 : $*Float
// CHECK: dealloc_stack %16 : $*Float
// CHECK: %23 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored1
// CHECK: %24 = alloc_stack $Float
// CHECK: store %10 to [trivial] %24 : $*Float
// CHECK: %26 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %27 = metatype $@thick Float.Type
// CHECK: %28 = apply %26<Float>(%23, %24, %27) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %24 : $*Float
// CHECK: dealloc_stack %24 : $*Float
// CHECK: %31 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored1
// CHECK: %32 = alloc_stack $Float
// CHECK: store %6 to [trivial] %32 : $*Float
// CHECK: %34 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %35 = metatype $@thick Float.Type
// CHECK: %36 = apply %34<Float>(%31, %32, %35) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %32 : $*Float
// CHECK: dealloc_stack %32 : $*Float
// CHECK: %39 = load [trivial] %11 : $*SmallTestModel.TangentVector
// CHECK: dealloc_stack %11 : $*SmallTestModel.TangentVector
// CHECK: debug_value %39 : $SmallTestModel.TangentVector, let, name "model", argno 1
// CHECK: return %39 : $SmallTestModel.TangentVector
// CHECK: }
//===----------------------------------------------------------------------===//
// Pullback generation - `struct_extract`
// - Nested AddElement adjoint kind is created due to the extraction of multiple
// fields from the same struct
//===----------------------------------------------------------------------===//
@differentiable(reverse)
func multipleExtractionsFromSameStruct(_ model: SmallTestModel) -> Float{
return model.stored1 + model.stored2 + model.stored3
}
// CHECK-LABEL: [AD] Accumulating adjoint directly.
// CHECK: LHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored3), Concrete[$Float]((%5, **%6**) = destructure_tuple %3 : $(Float, Float)
// CHECK: )]
// CHECK: RHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored2), Concrete[$Float]((%9, **%10**) = destructure_tuple %7 : $(Float, Float)
// CHECK: )]
// CHECK-LABEL: [AD] Accumulating adjoint directly.
// CHECK: LHS: AddElement[AddElement[Zero[$SmallTestModel.TangentVector], Field(stored3), Concrete[$Float]((%5, **%6**) = destructure_tuple %3 : $(Float, Float)
// CHECK: )], Field(stored2), Concrete[$Float]((%9, **%10**) = destructure_tuple %7 : $(Float, Float)
// CHECK: )]
// CHECK: RHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored1), Concrete[$Float]((**%9**, %10) = destructure_tuple %7 : $(Float, Float)
// CHECK: )]
// CHECK-LABEL: sil private [ossa] @$s46pullback_generation_nested_addelement_adjoints33multipleExtractionsFromSameStructySfAA14SmallTestModelVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> SmallTestModel.TangentVector {
// CHECK: bb0(%0 : $Float, %1 : @owned $@callee_guaranteed (Float) -> (Float, Float), %2 : @owned $@callee_guaranteed (Float) -> (Float, Float)):
// CHECK: %3 = apply %2(%0) : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: destroy_value %2 : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: (%5, %6) = destructure_tuple %3 : $(Float, Float)
// CHECK: %7 = apply %1(%5) : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: destroy_value %1 : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: (%9, %10) = destructure_tuple %7 : $(Float, Float)
// CHECK: %11 = alloc_stack $SmallTestModel.TangentVector
// CHECK: %12 = witness_method $SmallTestModel.TangentVector, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %13 = metatype $@thick SmallTestModel.TangentVector.Type
// CHECK: %14 = apply %12<SmallTestModel.TangentVector>(%11, %13) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %15 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored1
// CHECK: %16 = alloc_stack $Float
// CHECK: store %9 to [trivial] %16 : $*Float
// CHECK: %18 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %19 = metatype $@thick Float.Type
// CHECK: %20 = apply %18<Float>(%15, %16, %19) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %16 : $*Float
// CHECK: dealloc_stack %16 : $*Float
// CHECK: %23 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored2
// CHECK: %24 = alloc_stack $Float
// CHECK: store %10 to [trivial] %24 : $*Float
// CHECK: %26 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %27 = metatype $@thick Float.Type
// CHECK: %28 = apply %26<Float>(%23, %24, %27) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %24 : $*Float
// CHECK: dealloc_stack %24 : $*Float
// CHECK: %31 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored3
// CHECK: %32 = alloc_stack $Float
// CHECK: store %6 to [trivial] %32 : $*Float
// CHECK: %34 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %35 = metatype $@thick Float.Type
// CHECK: %36 = apply %34<Float>(%31, %32, %35) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_addr %32 : $*Float
// CHECK: dealloc_stack %32 : $*Float
// CHECK: %39 = load [trivial] %11 : $*SmallTestModel.TangentVector
// CHECK: dealloc_stack %11 : $*SmallTestModel.TangentVector
// CHECK: debug_value %39 : $SmallTestModel.TangentVector, let, name "model", argno 1
// CHECK: return %39 : $SmallTestModel.TangentVector
// CHECK: }

View File

@@ -0,0 +1,32 @@
// RUN: %target-swift-frontend -emit-sil -O %s
import _Differentiation
// Issue #66522:
// Pullback generation for a function mapping a differentiable input type to
// one of its differentiable fields fails when the input's tangent vector contains
// non-differentiable fields.
public struct P<Value>: Differentiable
where
Value: Differentiable,
Value.TangentVector == Value,
Value: AdditiveArithmetic {
// `P` is its own `TangentVector`
public typealias TangentVector = Self
// Non-differentiable field in `P`'s `TangentVector`.
public let name: String = ""
var value: Value
}
extension P: Equatable, AdditiveArithmetic
where Value: AdditiveArithmetic {
public static var zero: Self {fatalError()}
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
}
@differentiable(reverse)
internal func testFunction(data: P<Double>) -> Double {
data.value
}

View File

@@ -501,20 +501,17 @@ SimpleMathTests.test("Adjoint value accumulation for aggregate lhs and concrete
// CHECK-LABEL: sil private [ossa] @${{.*}}doubled{{.*}}TJp{{.*}} : $@convention(thin) (Float, @owned {{.*}}) -> SmallTestModel.TangentVector { // CHECK-LABEL: sil private [ossa] @${{.*}}doubled{{.*}}TJp{{.*}} : $@convention(thin) (Float, @owned {{.*}}) -> SmallTestModel.TangentVector {
// CHECK: bb0([[DX:%.*]] : $Float, [[PB0:%.*]] : {{.*}}, [[PB1:%.*]] : {{.*}}): // CHECK: bb0([[DX:%.*]] : $Float, [[PB0:%.*]] : {{.*}}, [[PB1:%.*]] : {{.*}}):
// CHECK: [[ADJ_TUPLE:%.*]] = apply [[PB1]]([[DX]]) : $@callee_guaranteed (Float) -> (Float, Float) // CHECK: [[ADJ_TUPLE:%.*]] = apply [[PB1]]([[DX]]) : $@callee_guaranteed (Float) -> (Float, Float)
// CHECK: ([[TMP0:%.*]], [[ADJ_CONCRETE:%.*]]) = destructure_tuple [[ADJ_TUPLE]] : $(Float, Float) // CHECK: ([[TMP0:%.*]], [[ADJ_CONCRETE:%.*]]) = destructure_tuple [[ADJ_TUPLE]] : $(Float, Float)
// CHECK: [[TMP1:%.*]] = apply [[PB0]]([[TMP0]]) : $@callee_guaranteed (Float) -> SmallTestModel.TangentVector // CHECK: [[TMP1:%.*]] = apply [[PB0]]([[TMP0]]) : $@callee_guaranteed (Float) -> SmallTestModel.TangentVector
// CHECK: [[ADJ_STRUCT_FIELD:%.*]] = destructure_struct [[TMP1]] : $SmallTestModel.TangentVector // CHECK: [[TMP_RES_ADJ_STRUCT:%.*]] = alloc_stack $SmallTestModel.TangentVector
// CHECK: [[TMP_RES:%.*]] = alloc_stack $Float // CHECK: store [[TMP1]] to [trivial] [[TMP_RES_ADJ_STRUCT]] : $*SmallTestModel.TangentVector
// CHECK: [[TMP_ADJ_STRUCT_FIELD:%.*]] = alloc_stack $Float // CHECK: [[TMP_RES_ADJ_STRUCT_FIELD:%.*]] = struct_element_addr [[TMP_RES_ADJ_STRUCT]] : $*SmallTestModel.TangentVector, #{{.*}}SmallTestModel.TangentVector.stored
// CHECK: [[TMP_ADJ_CONCRETE:%.*]] = alloc_stack $Float // CHECK: [[TMP_RES_ADJ_STRUCT_ADD_ELT:%.*]] = alloc_stack $Float
// CHECK: store [[ADJ_STRUCT_FIELD]] to [trivial] [[TMP_ADJ_STRUCT_FIELD]] : $*Float // CHECK: store [[ADJ_CONCRETE]] to [trivial] [[TMP_RES_ADJ_STRUCT_ADD_ELT]] : $*Float
// CHECK: store [[ADJ_CONCRETE]] to [trivial] [[TMP_ADJ_CONCRETE]] : $*Float // CHECK: [[PLUS_EQUAL:%.*]] = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: [[PLUS_EQUAL:%.*]] = witness_method $Float, #AdditiveArithmetic."+" // CHECK: {{.*}} = apply [[PLUS_EQUAL]]<Float>([[TMP_RES_ADJ_STRUCT_FIELD]], [[TMP_RES_ADJ_STRUCT_ADD_ELT]], {{.*}}) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %{{.*}} = apply [[PLUS_EQUAL]]<Float>([[TMP_RES]], [[TMP_ADJ_CONCRETE]], [[TMP_ADJ_STRUCT_FIELD]], {{.*}}) // CHECK: [[RES_STRUCT:%.*]] = load [trivial] [[TMP_RES_ADJ_STRUCT]] : $*SmallTestModel.TangentVector
// CHECK: [[RES:%.*]] = load [trivial] [[TMP_RES]] : $*Float // CHECK: return [[RES_STRUCT]] : $SmallTestModel.TangentVector
// CHECK: [[RES_STRUCT:%.*]] = struct $SmallTestModel.TangentVector ([[RES]] : $Float)
// CHECK: return [[RES_STRUCT]] : $SmallTestModel.TangentVector
// CHECK: }
runAllTests() runAllTests()