Merge remote-tracking branch 'origin/master' into master-next

This commit is contained in:
swift_jenkins
2019-12-13 19:20:21 -08:00
4 changed files with 155 additions and 54 deletions

View File

@@ -897,6 +897,20 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
break;
}
case DAK_Derivative: {
Printer.printAttrName("@derivative");
Printer << "(of: ";
auto *attr = cast<DerivativeAttr>(this);
Printer << attr->getOriginalFunctionName().Name;
auto *derivative = cast<AbstractFunctionDecl>(D);
auto diffParamsString = getDifferentiationParametersClauseString(
derivative, attr->getParameterIndices(), attr->getParsedParameters());
if (!diffParamsString.empty())
Printer << ", " << diffParamsString;
Printer << ')';
break;
}
case DAK_ImplicitlySynthesizesNestedRequirement:
Printer.printAttrName("@_implicitly_synthesizes_nested_requirement");
Printer << "(\"" << cast<ImplicitlySynthesizesNestedRequirementAttr>(this)->Value << "\")";

View File

@@ -2124,6 +2124,21 @@ getActualReadWriteImplKind(unsigned rawKind) {
return None;
}
/// Translate from the serialization DifferentiabilityKind enumerators, which
/// are guaranteed to be stable, to the AST ones.
static Optional<swift::AutoDiffDerivativeFunctionKind>
getActualAutoDiffDerivativeFunctionKind(uint8_t raw) {
switch (serialization::AutoDiffDerivativeFunctionKind(raw)) {
#define CASE(ID) \
case serialization::AutoDiffDerivativeFunctionKind::ID: \
return {swift::AutoDiffDerivativeFunctionKind::ID};
CASE(JVP)
CASE(VJP)
#undef CASE
}
return None;
}
void ModuleFile::configureStorage(AbstractStorageDecl *decl,
uint8_t rawOpaqueReadOwnership,
uint8_t rawReadImplKind,
@@ -4164,6 +4179,37 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
break;
}
case decls_block::Derivative_DECL_ATTR: {
bool isImplicit;
uint64_t origNameId;
DeclID origDeclId;
uint64_t rawDerivativeKind;
ArrayRef<uint64_t> parameters;
serialization::decls_block::DerivativeDeclAttrLayout::readRecord(
scratch, isImplicit, origNameId, origDeclId, rawDerivativeKind,
parameters);
DeclNameRefWithLoc origName{
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
auto derivativeKind =
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
if (!derivativeKind)
MF.fatal();
llvm::SmallBitVector parametersBitVector(parameters.size());
for (unsigned i : indices(parameters))
parametersBitVector[i] = parameters[i];
auto *indices = IndexSubset::get(ctx, parametersBitVector);
auto *derivAttr = DerivativeAttr::create(
ctx, isImplicit, SourceLoc(), SourceRange(), origName, indices);
derivAttr->setOriginalFunction(origDecl);
derivAttr->setDerivativeKind(*derivativeKind);
Attr = derivAttr;
break;
}
case decls_block::ImplicitlySynthesizesNestedRequirement_DECL_ATTR: {
serialization::decls_block::ImplicitlySynthesizesNestedRequirementDeclAttrLayout
::readRecord(scratch);

View File

@@ -2349,21 +2349,17 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction());
auto derivativeKind =
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
auto paramIndices = attr->getParameterIndices();
// NOTE(TF-837): `@derivative` attribute serialization is blocked by
// `@derivative` attribute type-checking (TF-829), which resolves
// parameter indices (`IndexSubset *`).
if (!paramIndices)
return;
assert(paramIndices && "Parameter indices must be resolved");
auto *parameterIndices = attr->getParameterIndices();
assert(parameterIndices && "Parameter indices must be resolved");
SmallVector<bool, 4> indices;
for (unsigned i : range(paramIndices->getCapacity()))
indices.push_back(paramIndices->contains(i));
for (unsigned i : range(parameterIndices->getCapacity()))
indices.push_back(parameterIndices->contains(i));
DerivativeDeclAttrLayout::emitRecord(
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
origDeclID, derivativeKind, indices);
return;
}
case DAK_ImplicitlySynthesizesNestedRequirement: {
auto *theAttr = cast<ImplicitlySynthesizesNestedRequirementAttr>(DA);
auto abbrCode = S.DeclTypeAbbrCodes[ImplicitlySynthesizesNestedRequirementDeclAttrLayout::Code];

View File

@@ -1,63 +1,108 @@
// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
// RUN: %target-swift-frontend -enable-experimental-differentiable-programming %s -emit-module -parse-as-library -o %t
// RUN: llvm-bcanalyzer %t/derivative_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/derivative_attr.swiftmodule -o - | %FileCheck %s
// RUN: %target-sil-opt -enable-experimental-differentiable-programming -disable-sil-linking -enable-sil-verify-all %t/derivative_attr.swiftmodule -o - | %FileCheck %s
// BCANALYZER-NOT: UnknownCode
// TODO(TF-837): Enable this test.
// Blocked by TF-829: `@derivative` attribute type-checking.
// XFAIL: *
// REQUIRES: differentiable_programming
func add(x: Float, y: Float) -> Float {
return x + y
}
// CHECK: @derivative(of: add, wrt: x)
@derivative(of: add, wrt: x)
func jvpAddWrtX(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x + y, { $0 })
}
// CHECK: @derivative(of: add, wrt: (x, y))
@derivative(of: add)
func vjpAdd(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
import _Differentiation
// Dummy `Differentiable`-conforming type.
struct S: Differentiable & AdditiveArithmetic {
static var zero: S { S() }
static func + (_: S, _: S) -> S { S() }
static func - (_: S, _: S) -> S { S() }
typealias TangentVector = S
}
func generic<T : Numeric>(x: T) -> T {
return x
// Test top-level functions.
func top1(_ x: S) -> S {
x
}
// CHECK: @derivative(of: generic, wrt: x)
@derivative(of: generic)
func vjpGeneric<T>(x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
where T : Numeric, T : Differentiable
{
return (x, { v in v })
// CHECK: @derivative(of: top1, wrt: x)
@derivative(of: top1, wrt: x)
func derivativeTop1(_ x: S) -> (value: S, differential: (S) -> S) {
(x, { $0 })
}
protocol InstanceMethod : Differentiable {
func foo(_ x: Self) -> Self
func bar<T : Differentiable>(_ x: T) -> Self
func top2<T, U>(_ x: T, _ i: Int, _ y: U) -> U {
y
}
extension InstanceMethod {
// CHECK: @derivative(of: foo, wrt: (self, x))
@derivative(of: foo)
func vjpFoo(x: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (x, { ($0, $0) })
// CHECK: @derivative(of: top2, wrt: (x, y))
@derivative(of: top2, wrt: (x, y))
func derivativeTop2<T: Differentiable, U: Differentiable>(
_ x: T, _ i: Int, _ y: U
) -> (value: U, differential: (T.TangentVector, U.TangentVector) -> U.TangentVector) {
(y, { (dx, dy) in dy })
}
// Test instance methods.
extension S {
func instanceMethod(_ x: S) -> S {
self
}
// CHECK: @derivative(of: bar, wrt: (self, x))
@derivative(of: bar, wrt: (self, x))
func jvpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T) -> TangentVector)
where T == T.TangentVector
{
return (self, { dself, dx in dself })
// CHECK: @derivative(of: instanceMethod, wrt: x)
@derivative(of: instanceMethod, wrt: x)
func derivativeInstanceMethodWrtX(_ x: S) -> (value: S, differential: (S) -> S) {
(self, { _ in .zero })
}
// CHECK: @derivative(of: bar, wrt: (self, x))
@derivative(of: bar, wrt: (self, x))
func vjpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, pullback: (TangentVector) -> (TangentVector, T))
where T == T.TangentVector
{
return (self, { v in (v, .zero) })
// CHECK: @derivative(of: instanceMethod, wrt: self)
@derivative(of: instanceMethod, wrt: self)
func derivativeInstanceMethodWrtSelf(_ x: S) -> (value: S, differential: (S) -> S) {
(self, { $0 })
}
// CHECK: @derivative(of: instanceMethod, wrt: (self, x))
@derivative(of: instanceMethod, wrt: (self, x))
func derivativeInstanceMethodWrtAll(_ x: S) -> (value: S, differential: (S, S) -> S) {
(self, { (dself, dx) in self })
}
}
// Test static methods.
extension S {
static func staticMethod(_ x: S) -> S {
x
}
// CHECK: @derivative(of: staticMethod, wrt: x)
@derivative(of: staticMethod, wrt: x)
static func derivativeStaticMethod(_ x: S) -> (value: S, differential: (S) -> S) {
(x, { $0 })
}
}
// Test computed properties.
extension S {
var computedProperty: S {
self
}
// CHECK: @derivative(of: computedProperty, wrt: self)
@derivative(of: computedProperty, wrt: self)
func derivativeProperty() -> (value: S, differential: (S) -> S) {
(self, { $0 })
}
}
// Test subscripts.
extension S {
subscript<T: Differentiable>(x: T) -> S {
self
}
// CHECK: @derivative(of: subscript, wrt: self)
@derivative(of: subscript(_:), wrt: self)
func derivativeSubscript<T: Differentiable>(x: T) -> (value: S, differential: (S) -> S) {
(self, { $0 })
}
}