mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Merge remote-tracking branch 'origin/master' into master-next
This commit is contained in:
@@ -897,6 +897,20 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
|
||||
break;
|
||||
}
|
||||
|
||||
case DAK_Derivative: {
|
||||
Printer.printAttrName("@derivative");
|
||||
Printer << "(of: ";
|
||||
auto *attr = cast<DerivativeAttr>(this);
|
||||
Printer << attr->getOriginalFunctionName().Name;
|
||||
auto *derivative = cast<AbstractFunctionDecl>(D);
|
||||
auto diffParamsString = getDifferentiationParametersClauseString(
|
||||
derivative, attr->getParameterIndices(), attr->getParsedParameters());
|
||||
if (!diffParamsString.empty())
|
||||
Printer << ", " << diffParamsString;
|
||||
Printer << ')';
|
||||
break;
|
||||
}
|
||||
|
||||
case DAK_ImplicitlySynthesizesNestedRequirement:
|
||||
Printer.printAttrName("@_implicitly_synthesizes_nested_requirement");
|
||||
Printer << "(\"" << cast<ImplicitlySynthesizesNestedRequirementAttr>(this)->Value << "\")";
|
||||
|
||||
@@ -2124,6 +2124,21 @@ getActualReadWriteImplKind(unsigned rawKind) {
|
||||
return None;
|
||||
}
|
||||
|
||||
/// Translate from the serialization DifferentiabilityKind enumerators, which
|
||||
/// are guaranteed to be stable, to the AST ones.
|
||||
static Optional<swift::AutoDiffDerivativeFunctionKind>
|
||||
getActualAutoDiffDerivativeFunctionKind(uint8_t raw) {
|
||||
switch (serialization::AutoDiffDerivativeFunctionKind(raw)) {
|
||||
#define CASE(ID) \
|
||||
case serialization::AutoDiffDerivativeFunctionKind::ID: \
|
||||
return {swift::AutoDiffDerivativeFunctionKind::ID};
|
||||
CASE(JVP)
|
||||
CASE(VJP)
|
||||
#undef CASE
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
void ModuleFile::configureStorage(AbstractStorageDecl *decl,
|
||||
uint8_t rawOpaqueReadOwnership,
|
||||
uint8_t rawReadImplKind,
|
||||
@@ -4164,6 +4179,37 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
|
||||
break;
|
||||
}
|
||||
|
||||
case decls_block::Derivative_DECL_ATTR: {
|
||||
bool isImplicit;
|
||||
uint64_t origNameId;
|
||||
DeclID origDeclId;
|
||||
uint64_t rawDerivativeKind;
|
||||
ArrayRef<uint64_t> parameters;
|
||||
|
||||
serialization::decls_block::DerivativeDeclAttrLayout::readRecord(
|
||||
scratch, isImplicit, origNameId, origDeclId, rawDerivativeKind,
|
||||
parameters);
|
||||
|
||||
DeclNameRefWithLoc origName{
|
||||
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
|
||||
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
|
||||
auto derivativeKind =
|
||||
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
|
||||
if (!derivativeKind)
|
||||
MF.fatal();
|
||||
llvm::SmallBitVector parametersBitVector(parameters.size());
|
||||
for (unsigned i : indices(parameters))
|
||||
parametersBitVector[i] = parameters[i];
|
||||
auto *indices = IndexSubset::get(ctx, parametersBitVector);
|
||||
|
||||
auto *derivAttr = DerivativeAttr::create(
|
||||
ctx, isImplicit, SourceLoc(), SourceRange(), origName, indices);
|
||||
derivAttr->setOriginalFunction(origDecl);
|
||||
derivAttr->setDerivativeKind(*derivativeKind);
|
||||
Attr = derivAttr;
|
||||
break;
|
||||
}
|
||||
|
||||
case decls_block::ImplicitlySynthesizesNestedRequirement_DECL_ATTR: {
|
||||
serialization::decls_block::ImplicitlySynthesizesNestedRequirementDeclAttrLayout
|
||||
::readRecord(scratch);
|
||||
|
||||
@@ -2349,21 +2349,17 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
|
||||
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction());
|
||||
auto derivativeKind =
|
||||
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
|
||||
auto paramIndices = attr->getParameterIndices();
|
||||
// NOTE(TF-837): `@derivative` attribute serialization is blocked by
|
||||
// `@derivative` attribute type-checking (TF-829), which resolves
|
||||
// parameter indices (`IndexSubset *`).
|
||||
if (!paramIndices)
|
||||
return;
|
||||
assert(paramIndices && "Parameter indices must be resolved");
|
||||
auto *parameterIndices = attr->getParameterIndices();
|
||||
assert(parameterIndices && "Parameter indices must be resolved");
|
||||
SmallVector<bool, 4> indices;
|
||||
for (unsigned i : range(paramIndices->getCapacity()))
|
||||
indices.push_back(paramIndices->contains(i));
|
||||
for (unsigned i : range(parameterIndices->getCapacity()))
|
||||
indices.push_back(parameterIndices->contains(i));
|
||||
DerivativeDeclAttrLayout::emitRecord(
|
||||
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
|
||||
origDeclID, derivativeKind, indices);
|
||||
return;
|
||||
}
|
||||
|
||||
case DAK_ImplicitlySynthesizesNestedRequirement: {
|
||||
auto *theAttr = cast<ImplicitlySynthesizesNestedRequirementAttr>(DA);
|
||||
auto abbrCode = S.DeclTypeAbbrCodes[ImplicitlySynthesizesNestedRequirementDeclAttrLayout::Code];
|
||||
|
||||
@@ -1,63 +1,108 @@
|
||||
// RUN: %empty-directory(%t)
|
||||
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
|
||||
// RUN: %target-swift-frontend -enable-experimental-differentiable-programming %s -emit-module -parse-as-library -o %t
|
||||
// RUN: llvm-bcanalyzer %t/derivative_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
|
||||
// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/derivative_attr.swiftmodule -o - | %FileCheck %s
|
||||
// RUN: %target-sil-opt -enable-experimental-differentiable-programming -disable-sil-linking -enable-sil-verify-all %t/derivative_attr.swiftmodule -o - | %FileCheck %s
|
||||
|
||||
// BCANALYZER-NOT: UnknownCode
|
||||
|
||||
// TODO(TF-837): Enable this test.
|
||||
// Blocked by TF-829: `@derivative` attribute type-checking.
|
||||
// XFAIL: *
|
||||
// REQUIRES: differentiable_programming
|
||||
|
||||
func add(x: Float, y: Float) -> Float {
|
||||
return x + y
|
||||
}
|
||||
// CHECK: @derivative(of: add, wrt: x)
|
||||
@derivative(of: add, wrt: x)
|
||||
func jvpAddWrtX(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) {
|
||||
return (x + y, { $0 })
|
||||
}
|
||||
// CHECK: @derivative(of: add, wrt: (x, y))
|
||||
@derivative(of: add)
|
||||
func vjpAdd(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
|
||||
return (x + y, { ($0, $0) })
|
||||
import _Differentiation
|
||||
|
||||
// Dummy `Differentiable`-conforming type.
|
||||
struct S: Differentiable & AdditiveArithmetic {
|
||||
static var zero: S { S() }
|
||||
static func + (_: S, _: S) -> S { S() }
|
||||
static func - (_: S, _: S) -> S { S() }
|
||||
typealias TangentVector = S
|
||||
}
|
||||
|
||||
func generic<T : Numeric>(x: T) -> T {
|
||||
return x
|
||||
// Test top-level functions.
|
||||
|
||||
func top1(_ x: S) -> S {
|
||||
x
|
||||
}
|
||||
// CHECK: @derivative(of: generic, wrt: x)
|
||||
@derivative(of: generic)
|
||||
func vjpGeneric<T>(x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
|
||||
where T : Numeric, T : Differentiable
|
||||
{
|
||||
return (x, { v in v })
|
||||
// CHECK: @derivative(of: top1, wrt: x)
|
||||
@derivative(of: top1, wrt: x)
|
||||
func derivativeTop1(_ x: S) -> (value: S, differential: (S) -> S) {
|
||||
(x, { $0 })
|
||||
}
|
||||
|
||||
protocol InstanceMethod : Differentiable {
|
||||
func foo(_ x: Self) -> Self
|
||||
func bar<T : Differentiable>(_ x: T) -> Self
|
||||
func top2<T, U>(_ x: T, _ i: Int, _ y: U) -> U {
|
||||
y
|
||||
}
|
||||
extension InstanceMethod {
|
||||
// CHECK: @derivative(of: foo, wrt: (self, x))
|
||||
@derivative(of: foo)
|
||||
func vjpFoo(x: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
|
||||
return (x, { ($0, $0) })
|
||||
// CHECK: @derivative(of: top2, wrt: (x, y))
|
||||
@derivative(of: top2, wrt: (x, y))
|
||||
func derivativeTop2<T: Differentiable, U: Differentiable>(
|
||||
_ x: T, _ i: Int, _ y: U
|
||||
) -> (value: U, differential: (T.TangentVector, U.TangentVector) -> U.TangentVector) {
|
||||
(y, { (dx, dy) in dy })
|
||||
}
|
||||
|
||||
// Test instance methods.
|
||||
|
||||
extension S {
|
||||
func instanceMethod(_ x: S) -> S {
|
||||
self
|
||||
}
|
||||
|
||||
// CHECK: @derivative(of: bar, wrt: (self, x))
|
||||
@derivative(of: bar, wrt: (self, x))
|
||||
func jvpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T) -> TangentVector)
|
||||
where T == T.TangentVector
|
||||
{
|
||||
return (self, { dself, dx in dself })
|
||||
// CHECK: @derivative(of: instanceMethod, wrt: x)
|
||||
@derivative(of: instanceMethod, wrt: x)
|
||||
func derivativeInstanceMethodWrtX(_ x: S) -> (value: S, differential: (S) -> S) {
|
||||
(self, { _ in .zero })
|
||||
}
|
||||
|
||||
// CHECK: @derivative(of: bar, wrt: (self, x))
|
||||
@derivative(of: bar, wrt: (self, x))
|
||||
func vjpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, pullback: (TangentVector) -> (TangentVector, T))
|
||||
where T == T.TangentVector
|
||||
{
|
||||
return (self, { v in (v, .zero) })
|
||||
// CHECK: @derivative(of: instanceMethod, wrt: self)
|
||||
@derivative(of: instanceMethod, wrt: self)
|
||||
func derivativeInstanceMethodWrtSelf(_ x: S) -> (value: S, differential: (S) -> S) {
|
||||
(self, { $0 })
|
||||
}
|
||||
|
||||
// CHECK: @derivative(of: instanceMethod, wrt: (self, x))
|
||||
@derivative(of: instanceMethod, wrt: (self, x))
|
||||
func derivativeInstanceMethodWrtAll(_ x: S) -> (value: S, differential: (S, S) -> S) {
|
||||
(self, { (dself, dx) in self })
|
||||
}
|
||||
}
|
||||
|
||||
// Test static methods.
|
||||
|
||||
extension S {
|
||||
static func staticMethod(_ x: S) -> S {
|
||||
x
|
||||
}
|
||||
|
||||
// CHECK: @derivative(of: staticMethod, wrt: x)
|
||||
@derivative(of: staticMethod, wrt: x)
|
||||
static func derivativeStaticMethod(_ x: S) -> (value: S, differential: (S) -> S) {
|
||||
(x, { $0 })
|
||||
}
|
||||
}
|
||||
|
||||
// Test computed properties.
|
||||
|
||||
extension S {
|
||||
var computedProperty: S {
|
||||
self
|
||||
}
|
||||
|
||||
// CHECK: @derivative(of: computedProperty, wrt: self)
|
||||
@derivative(of: computedProperty, wrt: self)
|
||||
func derivativeProperty() -> (value: S, differential: (S) -> S) {
|
||||
(self, { $0 })
|
||||
}
|
||||
}
|
||||
|
||||
// Test subscripts.
|
||||
|
||||
extension S {
|
||||
subscript<T: Differentiable>(x: T) -> S {
|
||||
self
|
||||
}
|
||||
|
||||
// CHECK: @derivative(of: subscript, wrt: self)
|
||||
@derivative(of: subscript(_:), wrt: self)
|
||||
func derivativeSubscript<T: Differentiable>(x: T) -> (value: S, differential: (S) -> S) {
|
||||
(self, { $0 })
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user