mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
[AutoDiff] Handle materializing adjoints with non-differentiable fields (#67319)
This commit is contained in:
@@ -19,6 +19,8 @@
|
|||||||
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H
|
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H
|
||||||
|
|
||||||
#include "swift/AST/Decl.h"
|
#include "swift/AST/Decl.h"
|
||||||
|
#include "swift/SIL/SILDebugVariable.h"
|
||||||
|
#include "swift/SIL/SILLocation.h"
|
||||||
#include "swift/SIL/SILValue.h"
|
#include "swift/SIL/SILValue.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
@@ -38,10 +40,18 @@ enum AdjointValueKind {
|
|||||||
|
|
||||||
/// A concrete SIL value.
|
/// A concrete SIL value.
|
||||||
Concrete,
|
Concrete,
|
||||||
|
|
||||||
|
/// A special adjoint, made up of 2 adjoints -- an aggregate base adjoint and
|
||||||
|
/// an element adjoint to add to one of its fields. This case exists to avoid
|
||||||
|
/// eager materialization of a base adjoint upon addition with one of its
|
||||||
|
/// fields.
|
||||||
|
AddElement,
|
||||||
};
|
};
|
||||||
|
|
||||||
class AdjointValue;
|
class AdjointValue;
|
||||||
|
|
||||||
|
struct AddElementValue;
|
||||||
|
|
||||||
class AdjointValueBase {
|
class AdjointValueBase {
|
||||||
friend class AdjointValue;
|
friend class AdjointValue;
|
||||||
|
|
||||||
@@ -60,9 +70,13 @@ class AdjointValueBase {
|
|||||||
union Value {
|
union Value {
|
||||||
unsigned numAggregateElements;
|
unsigned numAggregateElements;
|
||||||
SILValue concrete;
|
SILValue concrete;
|
||||||
|
AddElementValue *addElementValue;
|
||||||
|
|
||||||
Value(unsigned numAggregateElements)
|
Value(unsigned numAggregateElements)
|
||||||
: numAggregateElements(numAggregateElements) {}
|
: numAggregateElements(numAggregateElements) {}
|
||||||
Value(SILValue v) : concrete(v) {}
|
Value(SILValue v) : concrete(v) {}
|
||||||
|
Value(AddElementValue *addElementValue)
|
||||||
|
: addElementValue(addElementValue) {}
|
||||||
Value() {}
|
Value() {}
|
||||||
} value;
|
} value;
|
||||||
|
|
||||||
@@ -86,6 +100,11 @@ class AdjointValueBase {
|
|||||||
|
|
||||||
explicit AdjointValueBase(SILType type, llvm::Optional<DebugInfo> debugInfo)
|
explicit AdjointValueBase(SILType type, llvm::Optional<DebugInfo> debugInfo)
|
||||||
: kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
|
: kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
|
||||||
|
|
||||||
|
explicit AdjointValueBase(SILType type, AddElementValue *addElementValue,
|
||||||
|
llvm::Optional<DebugInfo> debugInfo)
|
||||||
|
: kind(AdjointValueKind::AddElement), type(type), debugInfo(debugInfo),
|
||||||
|
value(addElementValue) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
|
/// A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
|
||||||
@@ -127,6 +146,14 @@ public:
|
|||||||
return new (buf) AdjointValueBase(type, elements, debugInfo);
|
return new (buf) AdjointValueBase(type, elements, debugInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static AdjointValue
|
||||||
|
createAddElement(llvm::BumpPtrAllocator &allocator, SILType type,
|
||||||
|
AddElementValue *addElementValue,
|
||||||
|
llvm::Optional<DebugInfo> debugInfo = llvm::None) {
|
||||||
|
auto *buf = allocator.Allocate<AdjointValueBase>();
|
||||||
|
return new (buf) AdjointValueBase(type, addElementValue, debugInfo);
|
||||||
|
}
|
||||||
|
|
||||||
AdjointValueKind getKind() const { return base->kind; }
|
AdjointValueKind getKind() const { return base->kind; }
|
||||||
SILType getType() const { return base->type; }
|
SILType getType() const { return base->type; }
|
||||||
CanType getSwiftType() const { return getType().getASTType(); }
|
CanType getSwiftType() const { return getType().getASTType(); }
|
||||||
@@ -140,6 +167,9 @@ public:
|
|||||||
bool isZero() const { return getKind() == AdjointValueKind::Zero; }
|
bool isZero() const { return getKind() == AdjointValueKind::Zero; }
|
||||||
bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; }
|
bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; }
|
||||||
bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; }
|
bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; }
|
||||||
|
bool isAddElement() const {
|
||||||
|
return getKind() == AdjointValueKind::AddElement;
|
||||||
|
}
|
||||||
|
|
||||||
unsigned getNumAggregateElements() const {
|
unsigned getNumAggregateElements() const {
|
||||||
assert(isAggregate());
|
assert(isAggregate());
|
||||||
@@ -162,41 +192,77 @@ public:
|
|||||||
return base->value.concrete;
|
return base->value.concrete;
|
||||||
}
|
}
|
||||||
|
|
||||||
void print(llvm::raw_ostream &s) const {
|
AddElementValue *getAddElementValue() const {
|
||||||
switch (getKind()) {
|
assert(isAddElement());
|
||||||
case AdjointValueKind::Zero:
|
return base->value.addElementValue;
|
||||||
s << "Zero[" << getType() << ']';
|
|
||||||
break;
|
|
||||||
case AdjointValueKind::Aggregate:
|
|
||||||
s << "Aggregate[" << getType() << "](";
|
|
||||||
if (auto *decl =
|
|
||||||
getType().getASTType()->getStructOrBoundGenericStruct()) {
|
|
||||||
interleave(
|
|
||||||
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
|
|
||||||
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
|
|
||||||
s << std::get<0>(elt)->getName() << ": ";
|
|
||||||
std::get<1>(elt).print(s);
|
|
||||||
},
|
|
||||||
[&s] { s << ", "; });
|
|
||||||
} else if (getType().is<TupleType>()) {
|
|
||||||
interleave(
|
|
||||||
getAggregateElements(),
|
|
||||||
[&s](const AdjointValue &elt) { elt.print(s); },
|
|
||||||
[&s] { s << ", "; });
|
|
||||||
} else {
|
|
||||||
llvm_unreachable("Invalid aggregate");
|
|
||||||
}
|
|
||||||
s << ')';
|
|
||||||
break;
|
|
||||||
case AdjointValueKind::Concrete:
|
|
||||||
s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void print(llvm::raw_ostream &s) const;
|
||||||
|
|
||||||
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
|
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// An abstraction that represents the field locator in
|
||||||
|
/// an `AddElement` adjoint kind. Depending on the aggregate
|
||||||
|
/// kind - tuple or struct, of the `baseAdjoint` in an
|
||||||
|
/// `AddElement` adjoint, the field locator may be an `unsigned int`
|
||||||
|
/// or a `VarDecl *`.
|
||||||
|
struct FieldLocator final {
|
||||||
|
FieldLocator(VarDecl *field) : inner(field) {}
|
||||||
|
FieldLocator(unsigned int index) : inner(index) {}
|
||||||
|
|
||||||
|
friend AddElementValue;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool isTupleFieldLocator() const {
|
||||||
|
return std::holds_alternative<unsigned int>(inner);
|
||||||
|
}
|
||||||
|
|
||||||
|
const static constexpr std::true_type TUPLE_FIELD_LOCATOR_TAG =
|
||||||
|
std::true_type{};
|
||||||
|
const static constexpr std::false_type STRUCT_FIELD_LOCATOR_TAG =
|
||||||
|
std::false_type{};
|
||||||
|
|
||||||
|
unsigned int getInner(std::true_type) const {
|
||||||
|
return std::get<unsigned int>(inner);
|
||||||
|
}
|
||||||
|
|
||||||
|
VarDecl *getInner(std::false_type) const {
|
||||||
|
return std::get<VarDecl *>(inner);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::variant<unsigned int, VarDecl *> inner;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// The underlying value for an `AddElement` adjoint.
|
||||||
|
struct AddElementValue final {
|
||||||
|
AdjointValue baseAdjoint;
|
||||||
|
AdjointValue eltToAdd;
|
||||||
|
FieldLocator fieldLocator;
|
||||||
|
|
||||||
|
AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
|
||||||
|
FieldLocator fieldLocator)
|
||||||
|
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd),
|
||||||
|
fieldLocator(fieldLocator) {
|
||||||
|
assert(baseAdjoint.getType().is<TupleType>() ||
|
||||||
|
baseAdjoint.getType().getStructOrBoundGenericStruct() != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isTupleAdjoint() const { return fieldLocator.isTupleFieldLocator(); }
|
||||||
|
|
||||||
|
bool isStructAdjoint() const { return !isTupleAdjoint(); }
|
||||||
|
|
||||||
|
VarDecl *getFieldDecl() const {
|
||||||
|
assert(isStructAdjoint());
|
||||||
|
return this->fieldLocator.getInner(FieldLocator::STRUCT_FIELD_LOCATOR_TAG);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned int getFieldIndex() const {
|
||||||
|
assert(isTupleAdjoint());
|
||||||
|
return this->fieldLocator.getInner(FieldLocator::TUPLE_FIELD_LOCATOR_TAG);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
||||||
const AdjointValue &adjVal) {
|
const AdjointValue &adjVal) {
|
||||||
adjVal.print(os);
|
adjVal.print(os);
|
||||||
|
|||||||
70
lib/SILOptimizer/Differentiation/AdjointValue.cpp
Normal file
70
lib/SILOptimizer/Differentiation/AdjointValue.cpp
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
//===--- AdjointValue.h - Helper class for differentiation ----*- C++ -*---===//
|
||||||
|
//
|
||||||
|
// This source file is part of the Swift.org open source project
|
||||||
|
//
|
||||||
|
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
|
||||||
|
// Licensed under Apache License v2.0 with Runtime Library Exception
|
||||||
|
//
|
||||||
|
// See https://swift.org/LICENSE.txt for license information
|
||||||
|
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// AdjointValue - a symbolic representation for adjoint values enabling
|
||||||
|
// efficient differentiation by avoiding zero materialization.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "differentiation"
|
||||||
|
|
||||||
|
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
|
||||||
|
|
||||||
|
void swift::autodiff::AdjointValue::print(llvm::raw_ostream &s) const {
|
||||||
|
switch (getKind()) {
|
||||||
|
case AdjointValueKind::Zero:
|
||||||
|
s << "Zero[" << getType() << ']';
|
||||||
|
break;
|
||||||
|
case AdjointValueKind::Aggregate:
|
||||||
|
s << "Aggregate[" << getType() << "](";
|
||||||
|
if (auto *decl = getType().getASTType()->getStructOrBoundGenericStruct()) {
|
||||||
|
interleave(
|
||||||
|
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
|
||||||
|
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
|
||||||
|
s << std::get<0>(elt)->getName() << ": ";
|
||||||
|
std::get<1>(elt).print(s);
|
||||||
|
},
|
||||||
|
[&s] { s << ", "; });
|
||||||
|
} else if (getType().is<TupleType>()) {
|
||||||
|
interleave(
|
||||||
|
getAggregateElements(),
|
||||||
|
[&s](const AdjointValue &elt) { elt.print(s); }, [&s] { s << ", "; });
|
||||||
|
} else {
|
||||||
|
llvm_unreachable("Invalid aggregate");
|
||||||
|
}
|
||||||
|
s << ')';
|
||||||
|
break;
|
||||||
|
case AdjointValueKind::Concrete:
|
||||||
|
s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
|
||||||
|
break;
|
||||||
|
case AdjointValueKind::AddElement:
|
||||||
|
auto *addElementValue = getAddElementValue();
|
||||||
|
auto baseAdjoint = addElementValue->baseAdjoint;
|
||||||
|
auto eltToAdd = addElementValue->eltToAdd;
|
||||||
|
|
||||||
|
s << "AddElement[";
|
||||||
|
baseAdjoint.print(s);
|
||||||
|
|
||||||
|
s << ", Field(";
|
||||||
|
if (addElementValue->isTupleAdjoint()) {
|
||||||
|
s << addElementValue->getFieldIndex();
|
||||||
|
} else {
|
||||||
|
s << addElementValue->getFieldDecl()->getNameStr();
|
||||||
|
}
|
||||||
|
s << "), ";
|
||||||
|
|
||||||
|
eltToAdd.print(s);
|
||||||
|
|
||||||
|
s << "]";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
target_sources(swiftSILOptimizer PRIVATE
|
target_sources(swiftSILOptimizer PRIVATE
|
||||||
ADContext.cpp
|
ADContext.cpp
|
||||||
|
AdjointValue.cpp
|
||||||
Common.cpp
|
Common.cpp
|
||||||
DifferentiationInvoker.cpp
|
DifferentiationInvoker.cpp
|
||||||
JVPCloner.cpp
|
JVPCloner.cpp
|
||||||
|
|||||||
@@ -228,11 +228,12 @@ private:
|
|||||||
auto zeroVal = emitZeroDirect(val.getSwiftType(), loc);
|
auto zeroVal = emitZeroDirect(val.getSwiftType(), loc);
|
||||||
return zeroVal;
|
return zeroVal;
|
||||||
}
|
}
|
||||||
case AdjointValueKind::Aggregate:
|
|
||||||
llvm_unreachable(
|
|
||||||
"Tuples and structs are not supported in forward mode yet.");
|
|
||||||
case AdjointValueKind::Concrete:
|
case AdjointValueKind::Concrete:
|
||||||
return val.getConcreteValue();
|
return val.getConcreteValue();
|
||||||
|
case AdjointValueKind::Aggregate:
|
||||||
|
case AdjointValueKind::AddElement:
|
||||||
|
llvm_unreachable(
|
||||||
|
"Tuples and structs are not supported in forward mode yet.");
|
||||||
}
|
}
|
||||||
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
|
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -323,6 +323,15 @@ private:
|
|||||||
return AdjointValue::createAggregate(allocator, remapType(type), elements);
|
return AdjointValue::createAggregate(allocator, remapType(type), elements);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AdjointValue makeAddElementAdjointValue(AdjointValue baseAdjoint,
|
||||||
|
AdjointValue eltToAdd,
|
||||||
|
FieldLocator fieldLocator) {
|
||||||
|
auto *addElementValue =
|
||||||
|
new AddElementValue(baseAdjoint, eltToAdd, fieldLocator);
|
||||||
|
return AdjointValue::createAddElement(allocator, baseAdjoint.getType(),
|
||||||
|
addElementValue);
|
||||||
|
}
|
||||||
|
|
||||||
//--------------------------------------------------------------------------//
|
//--------------------------------------------------------------------------//
|
||||||
// Adjoint value materialization
|
// Adjoint value materialization
|
||||||
//--------------------------------------------------------------------------//
|
//--------------------------------------------------------------------------//
|
||||||
@@ -355,6 +364,19 @@ private:
|
|||||||
case AdjointValueKind::Concrete:
|
case AdjointValueKind::Concrete:
|
||||||
result = val.getConcreteValue();
|
result = val.getConcreteValue();
|
||||||
break;
|
break;
|
||||||
|
case AdjointValueKind::AddElement: {
|
||||||
|
auto adjointSILType = val.getAddElementValue()->baseAdjoint.getType();
|
||||||
|
auto *baseAdjAlloc = builder.createAllocStack(loc, adjointSILType);
|
||||||
|
materializeAdjointIndirect(val, baseAdjAlloc, loc);
|
||||||
|
|
||||||
|
auto baseAdjConcrete = recordTemporary(builder.emitLoadValueOperation(
|
||||||
|
loc, baseAdjAlloc, LoadOwnershipQualifier::Take));
|
||||||
|
|
||||||
|
builder.createDeallocStack(loc, baseAdjAlloc);
|
||||||
|
|
||||||
|
result = baseAdjConcrete;
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (auto debugInfo = val.getDebugInfo())
|
if (auto debugInfo = val.getDebugInfo())
|
||||||
builder.createDebugValue(
|
builder.createDebugValue(
|
||||||
@@ -400,12 +422,62 @@ private:
|
|||||||
}
|
}
|
||||||
/// If adjoint value is concrete, it is already materialized. Store it in
|
/// If adjoint value is concrete, it is already materialized. Store it in
|
||||||
/// the destination address.
|
/// the destination address.
|
||||||
case AdjointValueKind::Concrete:
|
case AdjointValueKind::Concrete: {
|
||||||
auto concreteVal = val.getConcreteValue();
|
auto concreteVal = val.getConcreteValue();
|
||||||
builder.emitStoreValueOperation(loc, concreteVal, destAddress,
|
auto copyOfConcreteVal = builder.emitCopyValueOperation(loc, concreteVal);
|
||||||
|
builder.emitStoreValueOperation(loc, copyOfConcreteVal, destAddress,
|
||||||
StoreOwnershipQualifier::Init);
|
StoreOwnershipQualifier::Init);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case AdjointValueKind::AddElement: {
|
||||||
|
auto baseAdjoint = val;
|
||||||
|
auto baseAdjointType = baseAdjoint.getType();
|
||||||
|
|
||||||
|
// Current adjoint may be made up of layers of `AddElement` adjoints.
|
||||||
|
// We can iteratively gather the list of elements to add instead of making
|
||||||
|
// recursive calls to `materializeAdjointIndirect`.
|
||||||
|
SmallVector<AddElementValue *, 4> addEltAdjValues;
|
||||||
|
|
||||||
|
do {
|
||||||
|
auto addElementValue = baseAdjoint.getAddElementValue();
|
||||||
|
addEltAdjValues.push_back(addElementValue);
|
||||||
|
baseAdjoint = addElementValue->baseAdjoint;
|
||||||
|
assert(baseAdjointType == baseAdjoint.getType());
|
||||||
|
} while (baseAdjoint.getKind() == AdjointValueKind::AddElement);
|
||||||
|
|
||||||
|
materializeAdjointIndirect(baseAdjoint, destAddress, loc);
|
||||||
|
|
||||||
|
for (auto *addElementValue : addEltAdjValues) {
|
||||||
|
auto eltToAdd = addElementValue->eltToAdd;
|
||||||
|
|
||||||
|
SILValue baseAdjEltAddr;
|
||||||
|
if (baseAdjoint.getType().is<TupleType>()) {
|
||||||
|
baseAdjEltAddr = builder.createTupleElementAddr(
|
||||||
|
loc, destAddress, addElementValue->getFieldIndex());
|
||||||
|
} else {
|
||||||
|
baseAdjEltAddr = builder.createStructElementAddr(
|
||||||
|
loc, destAddress, addElementValue->getFieldDecl());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto eltToAddMaterialized = materializeAdjointDirect(eltToAdd, loc);
|
||||||
|
// Copy `eltToAddMaterialized` so we have a value with owned ownership
|
||||||
|
// semantics, required for using `eltToAddMaterialized` in a `store`
|
||||||
|
// instruction.
|
||||||
|
auto eltToAddMaterializedCopy =
|
||||||
|
builder.emitCopyValueOperation(loc, eltToAddMaterialized);
|
||||||
|
auto *eltToAddAlloc = builder.createAllocStack(loc, eltToAdd.getType());
|
||||||
|
builder.emitStoreValueOperation(loc, eltToAddMaterializedCopy,
|
||||||
|
eltToAddAlloc,
|
||||||
|
StoreOwnershipQualifier::Init);
|
||||||
|
|
||||||
|
builder.emitInPlaceAdd(loc, baseAdjEltAddr, eltToAddAlloc);
|
||||||
|
builder.createDestroyAddr(loc, eltToAddAlloc);
|
||||||
|
builder.createDeallocStack(loc, eltToAddAlloc);
|
||||||
|
}
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//--------------------------------------------------------------------------//
|
//--------------------------------------------------------------------------//
|
||||||
@@ -1095,6 +1167,10 @@ public:
|
|||||||
"Aggregate adjoint values should not occur for `struct` "
|
"Aggregate adjoint values should not occur for `struct` "
|
||||||
"instructions");
|
"instructions");
|
||||||
}
|
}
|
||||||
|
case AdjointValueKind::AddElement: {
|
||||||
|
llvm_unreachable(
|
||||||
|
"Adjoint of `StructInst` cannot be of kind `AddElement`");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -1150,41 +1226,29 @@ public:
|
|||||||
auto structTy = remapType(sei->getOperand()->getType()).getASTType();
|
auto structTy = remapType(sei->getOperand()->getType()).getASTType();
|
||||||
auto *tanField =
|
auto *tanField =
|
||||||
getTangentStoredProperty(getContext(), sei, structTy, getInvoker());
|
getTangentStoredProperty(getContext(), sei, structTy, getInvoker());
|
||||||
|
assert(tanField && "Invalid projections should have been diagnosed");
|
||||||
// Check the `struct_extract` operand's value tangent category.
|
// Check the `struct_extract` operand's value tangent category.
|
||||||
switch (getTangentValueCategory(sei->getOperand())) {
|
switch (getTangentValueCategory(sei->getOperand())) {
|
||||||
case SILValueCategory::Object: {
|
case SILValueCategory::Object: {
|
||||||
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
|
auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType();
|
||||||
auto *tangentVectorDecl =
|
|
||||||
tangentVectorTy->getStructOrBoundGenericStruct();
|
|
||||||
assert(tangentVectorDecl);
|
|
||||||
auto tangentVectorSILTy =
|
auto tangentVectorSILTy =
|
||||||
SILType::getPrimitiveObjectType(tangentVectorTy);
|
SILType::getPrimitiveObjectType(tangentVectorTy);
|
||||||
assert(tanField && "Invalid projections should have been diagnosed");
|
auto eltAdj = getAdjointValue(bb, sei);
|
||||||
// Accumulate adjoint for the `struct_extract` operand.
|
|
||||||
auto av = getAdjointValue(bb, sei);
|
switch (eltAdj.getKind()) {
|
||||||
switch (av.getKind()) {
|
case AdjointValueKind::Zero: {
|
||||||
case AdjointValueKind::Zero:
|
|
||||||
addAdjointValue(bb, sei->getOperand(),
|
addAdjointValue(bb, sei->getOperand(),
|
||||||
makeZeroAdjointValue(tangentVectorSILTy), loc);
|
makeZeroAdjointValue(tangentVectorSILTy), loc);
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
|
case AdjointValueKind::Aggregate:
|
||||||
case AdjointValueKind::Concrete:
|
case AdjointValueKind::Concrete:
|
||||||
case AdjointValueKind::Aggregate: {
|
case AdjointValueKind::AddElement: {
|
||||||
SmallVector<AdjointValue, 8> eltVals;
|
auto baseAdj = makeZeroAdjointValue(tangentVectorSILTy);
|
||||||
for (auto *field : tangentVectorDecl->getStoredProperties()) {
|
|
||||||
if (field == tanField) {
|
|
||||||
eltVals.push_back(av);
|
|
||||||
} else {
|
|
||||||
auto substMap = tangentVectorTy->getMemberSubstitutionMap(
|
|
||||||
field->getModuleContext(), field);
|
|
||||||
auto fieldTy = field->getInterfaceType().subst(substMap);
|
|
||||||
auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
|
|
||||||
assert(fieldSILTy.isObject());
|
|
||||||
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
addAdjointValue(bb, sei->getOperand(),
|
addAdjointValue(bb, sei->getOperand(),
|
||||||
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
|
makeAddElementAdjointValue(baseAdj, eltAdj, tanField),
|
||||||
loc);
|
loc);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
@@ -1320,7 +1384,7 @@ public:
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case AdjointValueKind::Aggregate:
|
case AdjointValueKind::Aggregate: {
|
||||||
unsigned adjIndex = 0;
|
unsigned adjIndex = 0;
|
||||||
for (auto i : range(ti->getElements().size())) {
|
for (auto i : range(ti->getElements().size())) {
|
||||||
if (!getTangentSpace(ti->getElement(i)->getType().getASTType()))
|
if (!getTangentSpace(ti->getElement(i)->getType().getASTType()))
|
||||||
@@ -1330,6 +1394,11 @@ public:
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case AdjointValueKind::AddElement: {
|
||||||
|
llvm_unreachable(
|
||||||
|
"Adjoint of `TupleInst` cannot be of kind `AddElement`");
|
||||||
|
}
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case SILValueCategory::Address: {
|
case SILValueCategory::Address: {
|
||||||
@@ -1358,42 +1427,42 @@ public:
|
|||||||
/// index corresponding to n
|
/// index corresponding to n
|
||||||
void visitTupleExtractInst(TupleExtractInst *tei) {
|
void visitTupleExtractInst(TupleExtractInst *tei) {
|
||||||
auto *bb = tei->getParent();
|
auto *bb = tei->getParent();
|
||||||
|
auto loc = tei->getLoc();
|
||||||
auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
|
auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
|
||||||
auto av = getAdjointValue(bb, tei);
|
auto eltAdj = getAdjointValue(bb, tei);
|
||||||
switch (av.getKind()) {
|
switch (eltAdj.getKind()) {
|
||||||
case AdjointValueKind::Zero:
|
case AdjointValueKind::Zero: {
|
||||||
addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy),
|
addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy),
|
||||||
tei->getLoc());
|
loc);
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
case AdjointValueKind::Aggregate:
|
case AdjointValueKind::Aggregate:
|
||||||
case AdjointValueKind::Concrete: {
|
case AdjointValueKind::Concrete:
|
||||||
|
case AdjointValueKind::AddElement: {
|
||||||
auto tupleTy = tei->getTupleType();
|
auto tupleTy = tei->getTupleType();
|
||||||
auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>();
|
auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>();
|
||||||
if (!tupleTanTupleTy) {
|
if (!tupleTanTupleTy) {
|
||||||
addAdjointValue(bb, tei->getOperand(), av, tei->getLoc());
|
addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
SmallVector<AdjointValue, 8> elements;
|
|
||||||
unsigned adjIdx = 0;
|
unsigned elements = 0;
|
||||||
for (unsigned i : range(tupleTy->getNumElements())) {
|
for (unsigned i : range(tupleTy->getNumElements())) {
|
||||||
if (!getTangentSpace(
|
if (!getTangentSpace(
|
||||||
tupleTy->getElement(i).getType()->getCanonicalType()))
|
tupleTy->getElement(i).getType()->getCanonicalType()))
|
||||||
continue;
|
continue;
|
||||||
if (tei->getFieldIndex() == i)
|
elements++;
|
||||||
elements.push_back(av);
|
|
||||||
else
|
|
||||||
elements.push_back(makeZeroAdjointValue(
|
|
||||||
getRemappedTangentType(SILType::getPrimitiveObjectType(
|
|
||||||
tupleTanTupleTy->getElementType(adjIdx++)
|
|
||||||
->getCanonicalType()))));
|
|
||||||
}
|
}
|
||||||
if (elements.size() == 1) {
|
|
||||||
addAdjointValue(bb, tei->getOperand(), elements.front(), tei->getLoc());
|
if (elements == 1) {
|
||||||
break;
|
addAdjointValue(bb, tei->getOperand(), eltAdj, loc);
|
||||||
|
} else {
|
||||||
|
auto baseAdj = makeZeroAdjointValue(tupleTanTy);
|
||||||
|
addAdjointValue(
|
||||||
|
bb, tei->getOperand(),
|
||||||
|
makeAddElementAdjointValue(baseAdj, eltAdj, tei->getFieldIndex()),
|
||||||
|
loc);
|
||||||
}
|
}
|
||||||
addAdjointValue(bb, tei->getOperand(),
|
|
||||||
makeAggregateAdjointValue(tupleTanTy, elements),
|
|
||||||
tei->getLoc());
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2719,7 +2788,12 @@ bool PullbackCloner::Implementation::runForSemanticMemberGetter() {
|
|||||||
addAdjointValue(origEntry, origSelf,
|
addAdjointValue(origEntry, origSelf,
|
||||||
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
|
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
|
||||||
pbLoc);
|
pbLoc);
|
||||||
|
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
case AdjointValueKind::AddElement:
|
||||||
|
llvm_unreachable("Adjoint of an aggregate type's field cannot be of kind "
|
||||||
|
"`AddElement`");
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -2959,7 +3033,7 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
|
|||||||
|
|
||||||
AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
|
AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
|
||||||
AdjointValue lhs, AdjointValue rhs, SILLocation loc) {
|
AdjointValue lhs, AdjointValue rhs, SILLocation loc) {
|
||||||
LLVM_DEBUG(getADDebugStream() << "Materializing adjoint directly.\nLHS: "
|
LLVM_DEBUG(getADDebugStream() << "Accumulating adjoint directly.\nLHS: "
|
||||||
<< lhs << "\nRHS: " << rhs << '\n');
|
<< lhs << "\nRHS: " << rhs << '\n');
|
||||||
switch (lhs.getKind()) {
|
switch (lhs.getKind()) {
|
||||||
// x
|
// x
|
||||||
@@ -2976,7 +3050,7 @@ AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
|
|||||||
case AdjointValueKind::Zero:
|
case AdjointValueKind::Zero:
|
||||||
return lhs;
|
return lhs;
|
||||||
// x + (y, z) => (x.0 + y, x.1 + z)
|
// x + (y, z) => (x.0 + y, x.1 + z)
|
||||||
case AdjointValueKind::Aggregate:
|
case AdjointValueKind::Aggregate: {
|
||||||
SmallVector<AdjointValue, 8> newElements;
|
SmallVector<AdjointValue, 8> newElements;
|
||||||
auto lhsTy = lhsVal->getType().getASTType();
|
auto lhsTy = lhsVal->getType().getASTType();
|
||||||
auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal);
|
auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal);
|
||||||
@@ -3004,13 +3078,24 @@ AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
|
|||||||
}
|
}
|
||||||
return makeAggregateAdjointValue(lhsVal->getType(), newElements);
|
return makeAggregateAdjointValue(lhsVal->getType(), newElements);
|
||||||
}
|
}
|
||||||
|
// x + (baseAdjoint, index, eltToAdd) => (x+baseAdjoint, index, eltToAdd)
|
||||||
|
case AdjointValueKind::AddElement: {
|
||||||
|
auto *addElementValue = rhs.getAddElementValue();
|
||||||
|
auto baseAdjoint = addElementValue->baseAdjoint;
|
||||||
|
auto eltToAdd = addElementValue->eltToAdd;
|
||||||
|
|
||||||
|
auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
|
||||||
|
return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
|
||||||
|
addElementValue->fieldLocator);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// 0
|
// 0
|
||||||
case AdjointValueKind::Zero:
|
case AdjointValueKind::Zero:
|
||||||
// 0 + x => x
|
// 0 + x => x
|
||||||
return rhs;
|
return rhs;
|
||||||
// (x, y)
|
// (x, y)
|
||||||
case AdjointValueKind::Aggregate:
|
case AdjointValueKind::Aggregate: {
|
||||||
switch (rhs.getKind()) {
|
switch (rhs.getKind()) {
|
||||||
// (x, y) + z => (z.0 + x, z.1 + y)
|
// (x, y) + z => (z.0 + x, z.1 + y)
|
||||||
case AdjointValueKind::Concrete:
|
case AdjointValueKind::Concrete:
|
||||||
@@ -3026,6 +3111,50 @@ AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect(
|
|||||||
lhs.getAggregateElement(i), rhs.getAggregateElement(i), loc));
|
lhs.getAggregateElement(i), rhs.getAggregateElement(i), loc));
|
||||||
return makeAggregateAdjointValue(lhs.getType(), newElements);
|
return makeAggregateAdjointValue(lhs.getType(), newElements);
|
||||||
}
|
}
|
||||||
|
// (x.0, ..., x.n) + (baseAdjoint, index, eltToAdd) => (x + baseAdjoint,
|
||||||
|
// index, eltToAdd)
|
||||||
|
case AdjointValueKind::AddElement: {
|
||||||
|
auto *addElementValue = rhs.getAddElementValue();
|
||||||
|
auto baseAdjoint = addElementValue->baseAdjoint;
|
||||||
|
auto eltToAdd = addElementValue->eltToAdd;
|
||||||
|
auto newBaseAdjoint = accumulateAdjointsDirect(lhs, baseAdjoint, loc);
|
||||||
|
|
||||||
|
return makeAddElementAdjointValue(newBaseAdjoint, eltToAdd,
|
||||||
|
addElementValue->fieldLocator);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// (baseAdjoint, index, eltToAdd)
|
||||||
|
case AdjointValueKind::AddElement: {
|
||||||
|
switch (rhs.getKind()) {
|
||||||
|
case AdjointValueKind::Zero:
|
||||||
|
return lhs;
|
||||||
|
// (baseAdjoint, index, eltToAdd) + x => (x + baseAdjoint, index, eltToAdd)
|
||||||
|
case AdjointValueKind::Concrete:
|
||||||
|
// (baseAdjoint, index, eltToAdd) + (x.0, ..., x.n) => (x + baseAdjoint,
|
||||||
|
// index, eltToAdd)
|
||||||
|
case AdjointValueKind::Aggregate:
|
||||||
|
return accumulateAdjointsDirect(rhs, lhs, loc);
|
||||||
|
// (baseAdjoint1, index1, eltToAdd1) + (baseAdjoint2, index2, eltToAdd2)
|
||||||
|
// => ((baseAdjoint1 + baseAdjoint2, index1, eltToAdd1), index2, eltToAdd2)
|
||||||
|
case AdjointValueKind::AddElement: {
|
||||||
|
auto *addElementValueLhs = lhs.getAddElementValue();
|
||||||
|
auto baseAdjointLhs = addElementValueLhs->baseAdjoint;
|
||||||
|
auto eltToAddLhs = addElementValueLhs->eltToAdd;
|
||||||
|
|
||||||
|
auto *addElementValueRhs = rhs.getAddElementValue();
|
||||||
|
auto baseAdjointRhs = addElementValueRhs->baseAdjoint;
|
||||||
|
auto eltToAddRhs = addElementValueRhs->eltToAdd;
|
||||||
|
|
||||||
|
auto sumOfBaseAdjoints =
|
||||||
|
accumulateAdjointsDirect(baseAdjointLhs, baseAdjointRhs, loc);
|
||||||
|
auto newBaseAdjoint = makeAddElementAdjointValue(
|
||||||
|
sumOfBaseAdjoints, eltToAddLhs, addElementValueLhs->fieldLocator);
|
||||||
|
|
||||||
|
return makeAddElementAdjointValue(newBaseAdjoint, eltToAddRhs,
|
||||||
|
addElementValueRhs->fieldLocator);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
|
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
|
||||||
|
|||||||
186
test/AutoDiff/SILOptimizer/pullback_generation.sil
Normal file
186
test/AutoDiff/SILOptimizer/pullback_generation.sil
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
// Pullback generation tests written in SIL for features
|
||||||
|
// that may not be directly supported by the Swift frontend
|
||||||
|
|
||||||
|
// RUN: %target-sil-opt --differentiation -emit-sorted-sil %s 2>&1 | %FileCheck %s
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - `struct_extract`
|
||||||
|
// - Input to pullback has non-owned ownership semantics which requires copying
|
||||||
|
// this value to stack before lifetime-ending uses.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
sil_stage raw
|
||||||
|
|
||||||
|
import Builtin
|
||||||
|
import Swift
|
||||||
|
import SwiftShims
|
||||||
|
|
||||||
|
import _Differentiation
|
||||||
|
|
||||||
|
struct X {
|
||||||
|
@_hasStorage var a: Float { get set }
|
||||||
|
@_hasStorage var b: String { get set }
|
||||||
|
init(a: Float, b: String)
|
||||||
|
}
|
||||||
|
|
||||||
|
extension X : Differentiable, Equatable, AdditiveArithmetic {
|
||||||
|
public typealias TangentVector = X
|
||||||
|
mutating func move(by offset: X)
|
||||||
|
public static var zero: X { get }
|
||||||
|
public static func + (lhs: X, rhs: X) -> X
|
||||||
|
public static func - (lhs: X, rhs: X) -> X
|
||||||
|
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X, _ b: X) -> Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Y {
|
||||||
|
@_hasStorage var a: X { get set }
|
||||||
|
@_hasStorage var b: String { get set }
|
||||||
|
init(a: X, b: String)
|
||||||
|
}
|
||||||
|
|
||||||
|
extension Y : Differentiable, Equatable, AdditiveArithmetic {
|
||||||
|
public typealias TangentVector = Y
|
||||||
|
mutating func move(by offset: Y)
|
||||||
|
public static var zero: Y { get }
|
||||||
|
public static func + (lhs: Y, rhs: Y) -> Y
|
||||||
|
public static func - (lhs: Y, rhs: Y) -> Y
|
||||||
|
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Y, _ b: Y) -> Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
|
||||||
|
}
|
||||||
|
|
||||||
|
sil hidden [ossa] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
|
||||||
|
bb0(%0 : @guaranteed $Y):
|
||||||
|
%1 = struct_extract %0 : $Y, #Y.a
|
||||||
|
%2 = copy_value %1 : $X
|
||||||
|
return %2 : $X
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: sil private [ossa] @$function_with_struct_extract_1TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned Y {
|
||||||
|
// CHECK: bb0(%0 : @guaranteed $X):
|
||||||
|
// CHECK: %1 = alloc_stack $Y
|
||||||
|
// CHECK: %2 = witness_method $Y, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %3 = metatype $@thick Y.Type
|
||||||
|
// CHECK: %4 = apply %2<Y>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %5 = struct_element_addr %1 : $*Y, #Y.a
|
||||||
|
|
||||||
|
// Since input parameter $0 has non-owned ownership semantics, it
|
||||||
|
// needs to be copied before a lifetime-ending use.
|
||||||
|
// CHECK: %6 = copy_value %0 : $X
|
||||||
|
|
||||||
|
// CHECK: %7 = alloc_stack $X
|
||||||
|
// CHECK: store %6 to [init] %7 : $*X
|
||||||
|
// CHECK: %9 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %10 = metatype $@thick X.Type
|
||||||
|
// CHECK: %11 = apply %9<X>(%5, %7, %10) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %7 : $*X
|
||||||
|
// CHECK: dealloc_stack %7 : $*X
|
||||||
|
// CHECK: %14 = load [take] %1 : $*Y
|
||||||
|
// CHECK: dealloc_stack %1 : $*Y
|
||||||
|
// CHECK: %16 = copy_value %14 : $Y
|
||||||
|
// CHECK: destroy_value %14 : $Y
|
||||||
|
// CHECK: return %16 : $Y
|
||||||
|
// CHECK: } // end sil function '$function_with_struct_extract_1TJpSpSr'
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - `tuple_extract`
|
||||||
|
// - Tuples as differentiable input arguments are not supported yet, so creating
|
||||||
|
// a basic test in SIL instead.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
|
||||||
|
}
|
||||||
|
|
||||||
|
sil hidden [ossa] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
|
||||||
|
bb0(%0 : $(Float, Float)):
|
||||||
|
%1 = tuple_extract %0 : $(Float, Float), 0
|
||||||
|
return %1 : $Float
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_1TJpSpSr : $@convention(thin) (Float) -> (Float, Float) {
|
||||||
|
// CHECK: bb0(%0 : $Float):
|
||||||
|
// CHECK: %1 = alloc_stack $(Float, Float)
|
||||||
|
// CHECK: %2 = tuple_element_addr %1 : $*(Float, Float), 0
|
||||||
|
// CHECK: %3 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %4 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %5 = apply %3<Float>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %6 = tuple_element_addr %1 : $*(Float, Float), 1
|
||||||
|
// CHECK: %7 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %8 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %9 = apply %7<Float>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %10 = tuple_element_addr %1 : $*(Float, Float), 0
|
||||||
|
// CHECK: %11 = alloc_stack $Float
|
||||||
|
// CHECK: store %0 to [trivial] %11 : $*Float
|
||||||
|
// CHECK: %13 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %14 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %15 = apply %13<Float>(%10, %11, %14) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %11 : $*Float
|
||||||
|
// CHECK: dealloc_stack %11 : $*Float
|
||||||
|
// CHECK: %18 = load [trivial] %1 : $*(Float, Float)
|
||||||
|
// CHECK: dealloc_stack %1 : $*(Float, Float)
|
||||||
|
// CHECK: return %18 : $(Float, Float)
|
||||||
|
// CHECK: } // end sil function 'function_with_tuple_extract_1TJpSpSr'
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - Inner values of concrete adjoints must be copied
|
||||||
|
// during direct materialization.
|
||||||
|
// - If the input to pullback BB has non-owned ownership semantics we cannot
|
||||||
|
// perform a lifetime-ending operation on it.
|
||||||
|
// - If the input to the pullback BB is an owned, non-trivial value we must
|
||||||
|
// copy it or there will be a double consume when all owned parameters are
|
||||||
|
// destroyed at the end of the basic block.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
|
||||||
|
}
|
||||||
|
|
||||||
|
sil hidden [ossa] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
|
||||||
|
bb0(%0 : @guaranteed $(X, X)):
|
||||||
|
%1 = tuple_extract %0 : $(X, X), 0
|
||||||
|
%2 = copy_value %1: $X
|
||||||
|
return %2 : $X
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_2TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned (X, X) {
|
||||||
|
// CHECK: bb0(%0 : @guaranteed $X):
|
||||||
|
// CHECK: %1 = alloc_stack $(X, X)
|
||||||
|
// CHECK: %2 = tuple_element_addr %1 : $*(X, X), 0
|
||||||
|
// CHECK: %3 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %4 = metatype $@thick X.Type
|
||||||
|
// CHECK: %5 = apply %3<X>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %6 = tuple_element_addr %1 : $*(X, X), 1
|
||||||
|
// CHECK: %7 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %8 = metatype $@thick X.Type
|
||||||
|
// CHECK: %9 = apply %7<X>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %10 = tuple_element_addr %1 : $*(X, X), 0
|
||||||
|
// CHECK: %11 = copy_value %0 : $X
|
||||||
|
// CHECK: %12 = alloc_stack $X
|
||||||
|
// CHECK: store %11 to [init] %12 : $*X
|
||||||
|
// CHECK: %14 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %15 = metatype $@thick X.Type
|
||||||
|
// CHECK: %16 = apply %14<X>(%10, %12, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %12 : $*X
|
||||||
|
// CHECK: dealloc_stack %12 : $*X
|
||||||
|
// CHECK: %19 = load [take] %1 : $*(X, X)
|
||||||
|
// CHECK: dealloc_stack %1 : $*(X, X)
|
||||||
|
// CHECK: %21 = copy_value %19 : $(X, X)
|
||||||
|
// CHECK: destroy_value %19 : $(X, X)
|
||||||
|
// CHECK: return %21 : $(X, X)
|
||||||
|
// CHECK: } // end sil function 'function_with_tuple_extract_2TJpSpSr'
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - `tuple_extract`
|
||||||
|
// - Adjoint of extracted element can be `AddElement`
|
||||||
|
// - Just need to make sure that we are able to generate a pullback
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
|
||||||
|
}
|
||||||
|
|
||||||
|
sil hidden [ossa] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
|
||||||
|
bb0(%0 : $((Float, Float), Float)):
|
||||||
|
%1 = tuple_extract %0 : $((Float, Float), Float), 0
|
||||||
|
%2 = tuple_extract %1 : $(Float, Float), 0
|
||||||
|
return %2 : $Float
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_3TJpSpSr : $@convention(thin) (Float) -> ((Float, Float), Float) {
|
||||||
200
test/AutoDiff/SILOptimizer/pullback_generation.swift
Normal file
200
test/AutoDiff/SILOptimizer/pullback_generation.swift
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
// Pullback generation tests written in Swift
|
||||||
|
|
||||||
|
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm --sil-print-after=differentiation %s 2>&1 | %FileCheck %s
|
||||||
|
|
||||||
|
import _Differentiation
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - `struct_extract`
|
||||||
|
// - Operand is piecewise materializable
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
struct PiecewiseMaterializable: Differentiable, Equatable, AdditiveArithmetic {
|
||||||
|
public typealias TangentVector = Self
|
||||||
|
|
||||||
|
var a: Float
|
||||||
|
var b: Double
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
func f1(v: PiecewiseMaterializable) -> Float {
|
||||||
|
v.a
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f11vSfAA23PiecewiseMaterializableV_tFTJpSpSr : $@convention(thin) (Float) -> PiecewiseMaterializable {
|
||||||
|
// CHECK: bb0(%0 : $Float):
|
||||||
|
// CHECK: %1 = alloc_stack $PiecewiseMaterializable
|
||||||
|
// CHECK: %2 = witness_method $PiecewiseMaterializable, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %3 = metatype $@thick PiecewiseMaterializable.Type
|
||||||
|
// CHECK: %4 = apply %2<PiecewiseMaterializable>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %5 = struct_element_addr %1 : $*PiecewiseMaterializable, #PiecewiseMaterializable.a
|
||||||
|
// CHECK: %6 = alloc_stack $Float
|
||||||
|
// CHECK: store %0 to [trivial] %6 : $*Float
|
||||||
|
// CHECK: %8 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %9 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %10 = apply %8<Float>(%5, %6, %9) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %6 : $*Float
|
||||||
|
// CHECK: dealloc_stack %6 : $*Float
|
||||||
|
// CHECK: %13 = load [trivial] %1 : $*PiecewiseMaterializable
|
||||||
|
// CHECK: dealloc_stack %1 : $*PiecewiseMaterializable
|
||||||
|
// CHECK: debug_value %13 : $PiecewiseMaterializable, let, name "v", argno 1
|
||||||
|
// CHECK: return %13 : $PiecewiseMaterializable
|
||||||
|
// CHECK: } // end sil function '$s19pullback_generation2f11vSfAA23PiecewiseMaterializableV_tFTJpSpSr'
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - `struct_extract`
|
||||||
|
// - Operand is non-piecewise materializable
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
struct NonPiecewiseMaterializable {
|
||||||
|
var a: Float
|
||||||
|
var b: String
|
||||||
|
}
|
||||||
|
|
||||||
|
extension NonPiecewiseMaterializable: Differentiable, Equatable, AdditiveArithmetic {
|
||||||
|
public typealias TangentVector = Self
|
||||||
|
mutating func move(by offset: TangentVector) {fatalError()}
|
||||||
|
public static var zero: Self {fatalError()}
|
||||||
|
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
|
||||||
|
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
func f2(v: NonPiecewiseMaterializable) -> Float {
|
||||||
|
v.a
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f21vSfAA26NonPiecewiseMaterializableV_tFTJpSpSr : $@convention(thin) (Float) -> @owned NonPiecewiseMaterializable {
|
||||||
|
// CHECK: bb0(%0 : $Float):
|
||||||
|
// CHECK: %1 = alloc_stack $NonPiecewiseMaterializable
|
||||||
|
// CHECK: %2 = witness_method $NonPiecewiseMaterializable, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %3 = metatype $@thick NonPiecewiseMaterializable.Type
|
||||||
|
// CHECK: %4 = apply %2<NonPiecewiseMaterializable>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %5 = struct_element_addr %1 : $*NonPiecewiseMaterializable, #NonPiecewiseMaterializable.a
|
||||||
|
// CHECK: %6 = alloc_stack $Float
|
||||||
|
// CHECK: store %0 to [trivial] %6 : $*Float
|
||||||
|
// CHECK: %8 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %9 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %10 = apply %8<Float>(%5, %6, %9) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %6 : $*Float
|
||||||
|
// CHECK: dealloc_stack %6 : $*Float
|
||||||
|
// CHECK: %13 = load [take] %1 : $*NonPiecewiseMaterializable
|
||||||
|
// CHECK: dealloc_stack %1 : $*NonPiecewiseMaterializable
|
||||||
|
// CHECK: debug_value %13 : $NonPiecewiseMaterializable, let, name "v", argno 1
|
||||||
|
// CHECK: %16 = copy_value %13 : $NonPiecewiseMaterializable
|
||||||
|
// CHECK: destroy_value %13 : $NonPiecewiseMaterializable
|
||||||
|
// CHECK: return %16 : $NonPiecewiseMaterializable
|
||||||
|
// CHECK: } // end sil function '$s19pullback_generation2f21vSfAA26NonPiecewiseMaterializableV_tFTJpSpSr'
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - `struct_extract`
|
||||||
|
// - Operand is non-piecewise materializable with an aggregate, piecewise
|
||||||
|
// materializable, differentiable field
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
struct NonPiecewiseMaterializableWithAggDifferentiableField {
|
||||||
|
var a: PiecewiseMaterializable
|
||||||
|
var b: String
|
||||||
|
}
|
||||||
|
|
||||||
|
extension NonPiecewiseMaterializableWithAggDifferentiableField: Differentiable, Equatable, AdditiveArithmetic {
|
||||||
|
public typealias TangentVector = Self
|
||||||
|
mutating func move(by offset: TangentVector) {fatalError()}
|
||||||
|
public static var zero: Self {fatalError()}
|
||||||
|
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
|
||||||
|
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
func f3(v: NonPiecewiseMaterializableWithAggDifferentiableField) -> PiecewiseMaterializable {
|
||||||
|
v.a
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f31vAA23PiecewiseMaterializableVAA03NondE26WithAggDifferentiableFieldV_tFTJpSpSr : $@convention(thin) (PiecewiseMaterializable) -> @owned NonPiecewiseMaterializableWithAggDifferentiableField {
|
||||||
|
// CHECK: bb0(%0 : $PiecewiseMaterializable):
|
||||||
|
// CHECK: %1 = alloc_stack $NonPiecewiseMaterializableWithAggDifferentiableField
|
||||||
|
// CHECK: %2 = witness_method $NonPiecewiseMaterializableWithAggDifferentiableField, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %3 = metatype $@thick NonPiecewiseMaterializableWithAggDifferentiableField.Type
|
||||||
|
// CHECK: %4 = apply %2<NonPiecewiseMaterializableWithAggDifferentiableField>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %5 = struct_element_addr %1 : $*NonPiecewiseMaterializableWithAggDifferentiableField, #NonPiecewiseMaterializableWithAggDifferentiableField.a
|
||||||
|
// CHECK: %6 = alloc_stack $PiecewiseMaterializable
|
||||||
|
// CHECK: store %0 to [trivial] %6 : $*PiecewiseMaterializable
|
||||||
|
// CHECK: %8 = witness_method $PiecewiseMaterializable, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %9 = metatype $@thick PiecewiseMaterializable.Type
|
||||||
|
// CHECK: %10 = apply %8<PiecewiseMaterializable>(%5, %6, %9) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %6 : $*PiecewiseMaterializable
|
||||||
|
// CHECK: dealloc_stack %6 : $*PiecewiseMaterializable
|
||||||
|
// CHECK: %13 = load [take] %1 : $*NonPiecewiseMaterializableWithAggDifferentiableField
|
||||||
|
// CHECK: dealloc_stack %1 : $*NonPiecewiseMaterializableWithAggDifferentiableField
|
||||||
|
// CHECK: debug_value %13 : $NonPiecewiseMaterializableWithAggDifferentiableField, let, name "v", argno 1
|
||||||
|
// CHECK: %16 = copy_value %13 : $NonPiecewiseMaterializableWithAggDifferentiableField
|
||||||
|
// CHECK: destroy_value %13 : $NonPiecewiseMaterializableWithAggDifferentiableField
|
||||||
|
// CHECK: return %16 : $NonPiecewiseMaterializableWithAggDifferentiableField
|
||||||
|
// CHECK: } // end sil function '$s19pullback_generation2f31vAA23PiecewiseMaterializableVAA03NondE26WithAggDifferentiableFieldV_tFTJpSpSr'
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - `struct_extract`
|
||||||
|
// - Adjoint of extracted element can be `AddElement`
|
||||||
|
// - Just need to make sure that we are able to generate a pullback for B.x's
|
||||||
|
// getter
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
struct A: Differentiable {
|
||||||
|
public var x: Float
|
||||||
|
}
|
||||||
|
|
||||||
|
struct B: Differentiable {
|
||||||
|
var y: A
|
||||||
|
|
||||||
|
public init(a: A) {
|
||||||
|
self.y = a
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
public var x: Float {
|
||||||
|
get { return self.y.x }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation1BV1xSfvgTJpSpSr : $@convention(thin) (Float) -> B.TangentVector {
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - Inner values of concrete adjoints must be copied
|
||||||
|
// during indirect materialization
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
struct NonTrivial {
|
||||||
|
var x: Float
|
||||||
|
var y: String
|
||||||
|
}
|
||||||
|
|
||||||
|
extension NonTrivial: Differentiable, Equatable, AdditiveArithmetic {
|
||||||
|
public typealias TangentVector = Self
|
||||||
|
mutating func move(by offset: TangentVector) {fatalError()}
|
||||||
|
public static var zero: Self {fatalError()}
|
||||||
|
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
|
||||||
|
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
func f4(a: NonTrivial) -> Float {
|
||||||
|
var sum: Float = 0
|
||||||
|
for _ in 0..<1 {
|
||||||
|
sum += a.x
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial {
|
||||||
|
// CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
|
||||||
|
// CHECK: %88 = alloc_stack $NonTrivial
|
||||||
|
|
||||||
|
// Non-trivial value must be copied or there will be a
|
||||||
|
// double consume when all owned parameters are destroyed
|
||||||
|
// at the end of the basic block.
|
||||||
|
// CHECK: %89 = copy_value %67 : $NonTrivial
|
||||||
|
|
||||||
|
// CHECK: store %89 to [init] %88 : $*NonTrivial
|
||||||
|
// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x
|
||||||
|
// CHECK: %92 = alloc_stack $Float
|
||||||
|
// CHECK: store %86 to [trivial] %92 : $*Float
|
||||||
|
// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %95 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_value %67 : $NonTrivial
|
||||||
@@ -0,0 +1,141 @@
|
|||||||
|
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm --sil-print-after=differentiation -Xllvm --debug-only=differentiation %s 2>&1 | %FileCheck %s
|
||||||
|
|
||||||
|
import _Differentiation
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - `struct_extract`
|
||||||
|
// - Nested AddElement adjoint kind is created due to the extraction of same
|
||||||
|
// field multiple times
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
struct SmallTestModel : Differentiable {
|
||||||
|
public var stored1: Float = 3.0
|
||||||
|
public var stored2: Float = 3.0
|
||||||
|
public var stored3: Float = 3.0
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
func multipleExtractionOfSameField(_ model: SmallTestModel) -> Float{
|
||||||
|
return model.stored1 + model.stored1 + model.stored1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: [AD] Accumulating adjoint directly.
|
||||||
|
// CHECK: LHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored1), Concrete[$Float]((%5, **%6**) = destructure_tuple %3 : $(Float, Float)
|
||||||
|
// CHECK: )]
|
||||||
|
// CHECK: RHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored1), Concrete[$Float]((%9, **%10**) = destructure_tuple %7 : $(Float, Float)
|
||||||
|
// CHECK: )]
|
||||||
|
|
||||||
|
// CHECK-LABEL: [AD] Accumulating adjoint directly.
|
||||||
|
// CHECK: LHS: AddElement[AddElement[Zero[$SmallTestModel.TangentVector], Field(stored1), Concrete[$Float]((%5, **%6**) = destructure_tuple %3 : $(Float, Float)
|
||||||
|
// CHECK: )], Field(stored1), Concrete[$Float]((%9, **%10**) = destructure_tuple %7 : $(Float, Float)
|
||||||
|
// CHECK: )]
|
||||||
|
// CHECK: RHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored1), Concrete[$Float]((**%9**, %10) = destructure_tuple %7 : $(Float, Float)
|
||||||
|
// CHECK: )]
|
||||||
|
|
||||||
|
// CHECK-LABEL: sil private [ossa] @$s46pullback_generation_nested_addelement_adjoints29multipleExtractionOfSameFieldySfAA14SmallTestModelVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> SmallTestModel.TangentVector {
|
||||||
|
// CHECK: bb0(%0 : $Float, %1 : @owned $@callee_guaranteed (Float) -> (Float, Float), %2 : @owned $@callee_guaranteed (Float) -> (Float, Float)):
|
||||||
|
// CHECK: %3 = apply %2(%0) : $@callee_guaranteed (Float) -> (Float, Float)
|
||||||
|
// CHECK: destroy_value %2 : $@callee_guaranteed (Float) -> (Float, Float)
|
||||||
|
// CHECK: (%5, %6) = destructure_tuple %3 : $(Float, Float)
|
||||||
|
// CHECK: %7 = apply %1(%5) : $@callee_guaranteed (Float) -> (Float, Float)
|
||||||
|
// CHECK: destroy_value %1 : $@callee_guaranteed (Float) -> (Float, Float)
|
||||||
|
// CHECK: (%9, %10) = destructure_tuple %7 : $(Float, Float)
|
||||||
|
// CHECK: %11 = alloc_stack $SmallTestModel.TangentVector
|
||||||
|
// CHECK: %12 = witness_method $SmallTestModel.TangentVector, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %13 = metatype $@thick SmallTestModel.TangentVector.Type
|
||||||
|
// CHECK: %14 = apply %12<SmallTestModel.TangentVector>(%11, %13) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %15 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored1
|
||||||
|
// CHECK: %16 = alloc_stack $Float
|
||||||
|
// CHECK: store %9 to [trivial] %16 : $*Float
|
||||||
|
// CHECK: %18 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %19 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %20 = apply %18<Float>(%15, %16, %19) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %16 : $*Float
|
||||||
|
// CHECK: dealloc_stack %16 : $*Float
|
||||||
|
// CHECK: %23 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored1
|
||||||
|
// CHECK: %24 = alloc_stack $Float
|
||||||
|
// CHECK: store %10 to [trivial] %24 : $*Float
|
||||||
|
// CHECK: %26 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %27 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %28 = apply %26<Float>(%23, %24, %27) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %24 : $*Float
|
||||||
|
// CHECK: dealloc_stack %24 : $*Float
|
||||||
|
// CHECK: %31 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored1
|
||||||
|
// CHECK: %32 = alloc_stack $Float
|
||||||
|
// CHECK: store %6 to [trivial] %32 : $*Float
|
||||||
|
// CHECK: %34 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %35 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %36 = apply %34<Float>(%31, %32, %35) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %32 : $*Float
|
||||||
|
// CHECK: dealloc_stack %32 : $*Float
|
||||||
|
// CHECK: %39 = load [trivial] %11 : $*SmallTestModel.TangentVector
|
||||||
|
// CHECK: dealloc_stack %11 : $*SmallTestModel.TangentVector
|
||||||
|
// CHECK: debug_value %39 : $SmallTestModel.TangentVector, let, name "model", argno 1
|
||||||
|
// CHECK: return %39 : $SmallTestModel.TangentVector
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pullback generation - `struct_extract`
|
||||||
|
// - Nested AddElement adjoint kind is created due to the extraction of multiple
|
||||||
|
// fields from the same struct
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
func multipleExtractionsFromSameStruct(_ model: SmallTestModel) -> Float{
|
||||||
|
return model.stored1 + model.stored2 + model.stored3
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: [AD] Accumulating adjoint directly.
|
||||||
|
// CHECK: LHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored3), Concrete[$Float]((%5, **%6**) = destructure_tuple %3 : $(Float, Float)
|
||||||
|
// CHECK: )]
|
||||||
|
// CHECK: RHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored2), Concrete[$Float]((%9, **%10**) = destructure_tuple %7 : $(Float, Float)
|
||||||
|
// CHECK: )]
|
||||||
|
|
||||||
|
// CHECK-LABEL: [AD] Accumulating adjoint directly.
|
||||||
|
// CHECK: LHS: AddElement[AddElement[Zero[$SmallTestModel.TangentVector], Field(stored3), Concrete[$Float]((%5, **%6**) = destructure_tuple %3 : $(Float, Float)
|
||||||
|
// CHECK: )], Field(stored2), Concrete[$Float]((%9, **%10**) = destructure_tuple %7 : $(Float, Float)
|
||||||
|
// CHECK: )]
|
||||||
|
// CHECK: RHS: AddElement[Zero[$SmallTestModel.TangentVector], Field(stored1), Concrete[$Float]((**%9**, %10) = destructure_tuple %7 : $(Float, Float)
|
||||||
|
// CHECK: )]
|
||||||
|
|
||||||
|
// CHECK-LABEL: sil private [ossa] @$s46pullback_generation_nested_addelement_adjoints33multipleExtractionsFromSameStructySfAA14SmallTestModelVFTJpSpSr : $@convention(thin) (Float, @owned @callee_guaranteed (Float) -> (Float, Float), @owned @callee_guaranteed (Float) -> (Float, Float)) -> SmallTestModel.TangentVector {
|
||||||
|
// CHECK: bb0(%0 : $Float, %1 : @owned $@callee_guaranteed (Float) -> (Float, Float), %2 : @owned $@callee_guaranteed (Float) -> (Float, Float)):
|
||||||
|
// CHECK: %3 = apply %2(%0) : $@callee_guaranteed (Float) -> (Float, Float)
|
||||||
|
// CHECK: destroy_value %2 : $@callee_guaranteed (Float) -> (Float, Float)
|
||||||
|
// CHECK: (%5, %6) = destructure_tuple %3 : $(Float, Float)
|
||||||
|
// CHECK: %7 = apply %1(%5) : $@callee_guaranteed (Float) -> (Float, Float)
|
||||||
|
// CHECK: destroy_value %1 : $@callee_guaranteed (Float) -> (Float, Float)
|
||||||
|
// CHECK: (%9, %10) = destructure_tuple %7 : $(Float, Float)
|
||||||
|
// CHECK: %11 = alloc_stack $SmallTestModel.TangentVector
|
||||||
|
// CHECK: %12 = witness_method $SmallTestModel.TangentVector, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %13 = metatype $@thick SmallTestModel.TangentVector.Type
|
||||||
|
// CHECK: %14 = apply %12<SmallTestModel.TangentVector>(%11, %13) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
|
||||||
|
// CHECK: %15 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored1
|
||||||
|
// CHECK: %16 = alloc_stack $Float
|
||||||
|
// CHECK: store %9 to [trivial] %16 : $*Float
|
||||||
|
// CHECK: %18 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %19 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %20 = apply %18<Float>(%15, %16, %19) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %16 : $*Float
|
||||||
|
// CHECK: dealloc_stack %16 : $*Float
|
||||||
|
// CHECK: %23 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored2
|
||||||
|
// CHECK: %24 = alloc_stack $Float
|
||||||
|
// CHECK: store %10 to [trivial] %24 : $*Float
|
||||||
|
// CHECK: %26 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %27 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %28 = apply %26<Float>(%23, %24, %27) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %24 : $*Float
|
||||||
|
// CHECK: dealloc_stack %24 : $*Float
|
||||||
|
// CHECK: %31 = struct_element_addr %11 : $*SmallTestModel.TangentVector, #SmallTestModel.TangentVector.stored3
|
||||||
|
// CHECK: %32 = alloc_stack $Float
|
||||||
|
// CHECK: store %6 to [trivial] %32 : $*Float
|
||||||
|
// CHECK: %34 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: %35 = metatype $@thick Float.Type
|
||||||
|
// CHECK: %36 = apply %34<Float>(%31, %32, %35) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
|
// CHECK: destroy_addr %32 : $*Float
|
||||||
|
// CHECK: dealloc_stack %32 : $*Float
|
||||||
|
// CHECK: %39 = load [trivial] %11 : $*SmallTestModel.TangentVector
|
||||||
|
// CHECK: dealloc_stack %11 : $*SmallTestModel.TangentVector
|
||||||
|
// CHECK: debug_value %39 : $SmallTestModel.TangentVector, let, name "model", argno 1
|
||||||
|
// CHECK: return %39 : $SmallTestModel.TangentVector
|
||||||
|
// CHECK: }
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
// RUN: %target-swift-frontend -emit-sil -O %s
|
||||||
|
|
||||||
|
import _Differentiation
|
||||||
|
|
||||||
|
// Issue #66522:
|
||||||
|
// Pullback generation for a function mapping a differentiable input type to
|
||||||
|
// one of its differentiable fields fails when the input's tangent vector contains
|
||||||
|
// non-differentiable fields.
|
||||||
|
public struct P<Value>: Differentiable
|
||||||
|
where
|
||||||
|
Value: Differentiable,
|
||||||
|
Value.TangentVector == Value,
|
||||||
|
Value: AdditiveArithmetic {
|
||||||
|
// `P` is its own `TangentVector`
|
||||||
|
public typealias TangentVector = Self
|
||||||
|
|
||||||
|
// Non-differentiable field in `P`'s `TangentVector`.
|
||||||
|
public let name: String = ""
|
||||||
|
var value: Value
|
||||||
|
}
|
||||||
|
|
||||||
|
extension P: Equatable, AdditiveArithmetic
|
||||||
|
where Value: AdditiveArithmetic {
|
||||||
|
public static var zero: Self {fatalError()}
|
||||||
|
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
|
||||||
|
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
|
||||||
|
}
|
||||||
|
|
||||||
|
@differentiable(reverse)
|
||||||
|
internal func testFunction(data: P<Double>) -> Double {
|
||||||
|
data.value
|
||||||
|
}
|
||||||
@@ -504,17 +504,14 @@ SimpleMathTests.test("Adjoint value accumulation for aggregate lhs and concrete
|
|||||||
// CHECK: [[ADJ_TUPLE:%.*]] = apply [[PB1]]([[DX]]) : $@callee_guaranteed (Float) -> (Float, Float)
|
// CHECK: [[ADJ_TUPLE:%.*]] = apply [[PB1]]([[DX]]) : $@callee_guaranteed (Float) -> (Float, Float)
|
||||||
// CHECK: ([[TMP0:%.*]], [[ADJ_CONCRETE:%.*]]) = destructure_tuple [[ADJ_TUPLE]] : $(Float, Float)
|
// CHECK: ([[TMP0:%.*]], [[ADJ_CONCRETE:%.*]]) = destructure_tuple [[ADJ_TUPLE]] : $(Float, Float)
|
||||||
// CHECK: [[TMP1:%.*]] = apply [[PB0]]([[TMP0]]) : $@callee_guaranteed (Float) -> SmallTestModel.TangentVector
|
// CHECK: [[TMP1:%.*]] = apply [[PB0]]([[TMP0]]) : $@callee_guaranteed (Float) -> SmallTestModel.TangentVector
|
||||||
// CHECK: [[ADJ_STRUCT_FIELD:%.*]] = destructure_struct [[TMP1]] : $SmallTestModel.TangentVector
|
// CHECK: [[TMP_RES_ADJ_STRUCT:%.*]] = alloc_stack $SmallTestModel.TangentVector
|
||||||
// CHECK: [[TMP_RES:%.*]] = alloc_stack $Float
|
// CHECK: store [[TMP1]] to [trivial] [[TMP_RES_ADJ_STRUCT]] : $*SmallTestModel.TangentVector
|
||||||
// CHECK: [[TMP_ADJ_STRUCT_FIELD:%.*]] = alloc_stack $Float
|
// CHECK: [[TMP_RES_ADJ_STRUCT_FIELD:%.*]] = struct_element_addr [[TMP_RES_ADJ_STRUCT]] : $*SmallTestModel.TangentVector, #{{.*}}SmallTestModel.TangentVector.stored
|
||||||
// CHECK: [[TMP_ADJ_CONCRETE:%.*]] = alloc_stack $Float
|
// CHECK: [[TMP_RES_ADJ_STRUCT_ADD_ELT:%.*]] = alloc_stack $Float
|
||||||
// CHECK: store [[ADJ_STRUCT_FIELD]] to [trivial] [[TMP_ADJ_STRUCT_FIELD]] : $*Float
|
// CHECK: store [[ADJ_CONCRETE]] to [trivial] [[TMP_RES_ADJ_STRUCT_ADD_ELT]] : $*Float
|
||||||
// CHECK: store [[ADJ_CONCRETE]] to [trivial] [[TMP_ADJ_CONCRETE]] : $*Float
|
// CHECK: [[PLUS_EQUAL:%.*]] = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
// CHECK: [[PLUS_EQUAL:%.*]] = witness_method $Float, #AdditiveArithmetic."+"
|
// CHECK: {{.*}} = apply [[PLUS_EQUAL]]<Float>([[TMP_RES_ADJ_STRUCT_FIELD]], [[TMP_RES_ADJ_STRUCT_ADD_ELT]], {{.*}}) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
|
||||||
// CHECK: %{{.*}} = apply [[PLUS_EQUAL]]<Float>([[TMP_RES]], [[TMP_ADJ_CONCRETE]], [[TMP_ADJ_STRUCT_FIELD]], {{.*}})
|
// CHECK: [[RES_STRUCT:%.*]] = load [trivial] [[TMP_RES_ADJ_STRUCT]] : $*SmallTestModel.TangentVector
|
||||||
// CHECK: [[RES:%.*]] = load [trivial] [[TMP_RES]] : $*Float
|
|
||||||
// CHECK: [[RES_STRUCT:%.*]] = struct $SmallTestModel.TangentVector ([[RES]] : $Float)
|
|
||||||
// CHECK: return [[RES_STRUCT]] : $SmallTestModel.TangentVector
|
// CHECK: return [[RES_STRUCT]] : $SmallTestModel.TangentVector
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
runAllTests()
|
runAllTests()
|
||||||
|
|||||||
Reference in New Issue
Block a user