[AutoDiff] Make Differentiable derivation support property wrappers.

Differentiable conformance derivation now "peers through" property wrappers.

Synthesized TangentVector structs contain wrapped properties' TangentVectors as
stored properties, not wrappers' TangentVectors.

Property wrapper types are not required to conform to `Differentiable`.
Property wrapper types are required to provide `wrappedValue.set`, which is
needed to synthesize `mutating func move(along:)`.

```
import _Differentiation

@propertyWrapper
struct Wrapper<Value> {
  var wrappedValue: Value
}

struct Struct: Differentiable {
  @Wrapper var x: Float = 0

  // Compiler now synthesizes:
  // struct TangentVector: Differentiable & AdditiveArithmetic {
  //   var x: Float
  //   ...
  // }
}
```

Resolves SR-12638.
This commit is contained in:
Dan Zheng
2020-04-20 21:29:53 -07:00
parent 37657d0e06
commit d96b73a827
5 changed files with 158 additions and 54 deletions

View File

@@ -2734,13 +2734,21 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
"stored property %0 has no derivative because %1 does not conform to "
"'Differentiable'; add an explicit '@noDerivative' attribute"
"%select{|, or conform %2 to 'AdditiveArithmetic'}3",
(Identifier, Type, Identifier, bool))
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
(/*propName*/ Identifier, /*propType*/ Type, /*nominalName*/ Identifier,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
"requires all stored properties to be mutable; use 'var' instead, or add "
"an explicit '@noDerivative' attribute"
"requires all stored properties not marked with `@noDerivative` to be "
"mutable; add an explicit '@noDerivative' attribute"
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
(Identifier, Identifier, bool))
(/*wrapperType*/ StringRef, /*nominalName*/ Identifier,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
"requires all stored properties not marked with `@noDerivative` to be "
"mutable; use 'var' instead, or add an explicit '@noDerivative' attribute"
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))
NOTE(codable_extraneous_codingkey_case_here,none,
"CodingKey case %0 does not match any stored properties", (Identifier))

View File

@@ -25,6 +25,7 @@
#include "swift/AST/ParameterList.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/PropertyWrappers.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Types.h"
#include "DerivedConformances.h"
@@ -39,14 +40,23 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
auto &C = nominal->getASTContext();
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
for (auto *vd : nominal->getStoredProperties()) {
// Peer through property wrappers: use original wrapped properties instead.
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
// Skip property wrappers that do not define `wrappedValue.set`.
// `mutating func move(along:)` cannot be synthesized to update these
// properties.
auto *wrapperDecl =
vd->getInterfaceType()->getNominalOrBoundGenericNominal();
auto *wrappedValueDecl =
wrapperDecl->getPropertyWrapperTypeInfo().valueVar;
if (!wrappedValueDecl->getAccessor(AccessorKind::Set))
continue;
// Use the original wrapped property.
vd = originalProperty;
}
// Skip stored properties with `@noDerivative` attribute.
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
continue;
// For property wrapper backing storage properties, skip if original
// property has `@noDerivative` attribute.
if (auto *originalProperty = vd->getOriginalWrappedProperty())
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
continue;
// Skip `let` stored properties. `mutating func move(along:)` cannot be
// synthesized to update these properties.
if (vd->isLet())
@@ -224,15 +234,15 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
if (confRef.isConcrete())
memberMethodDecl = confRef.getConcrete()->getWitnessDecl(methodReq);
assert(memberMethodDecl && "Member method declaration must exist");
auto memberMethodDRE =
auto *memberMethodDRE =
new (C) DeclRefExpr(memberMethodDecl, DeclNameLoc(), /*Implicit*/ true);
memberMethodDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
// Create reference to member method: `x.move(along:)`.
auto memberExpr =
Expr *memberExpr =
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
/*Implicit*/ true);
auto memberMethodExpr =
auto *memberMethodExpr =
new (C) DotSyntaxCallExpr(memberMethodDRE, SourceLoc(), memberExpr);
// Create reference to parameter member: `direction.x`.
@@ -483,20 +493,52 @@ static void addAssociatedTypeAliasDecl(Identifier name, DeclContext *sourceDC,
static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
NominalTypeDecl *nominal,
DeclContext *DC) {
auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
// If nominal type can conform to `AdditiveArithmetic`, suggest adding a
// conformance to `AdditiveArithmetic` in fix-its.
// `Differentiable` protocol requirements all have default implementations
// when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
// derived conformances will no longer be necessary.
bool nominalCanDeriveAdditiveArithmetic =
DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC);
auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
// Check all stored properties.
for (auto *vd : nominal->getStoredProperties()) {
// Peer through property wrappers: use original wrapped properties.
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
// Skip wrapped properties with `@noDerivative` attribute.
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
continue;
// Diagnose wrapped properties whose property wrappers do not define
// `wrappedValue.set`. `mutating func move(along:)` cannot be synthesized
// to update these properties.
auto *wrapperDecl =
vd->getInterfaceType()->getNominalOrBoundGenericNominal();
auto *wrappedValueDecl =
wrapperDecl->getPropertyWrapperTypeInfo().valueVar;
if (!wrappedValueDecl->getAccessor(AccessorKind::Set)) {
auto loc =
originalProperty->getAttributeInsertionLoc(/*forModifier*/ false);
Context.Diags
.diagnose(
loc,
diag::
differentiable_immutable_wrapper_implicit_noderivative_fixit,
wrapperDecl->getNameStr(), nominal->getName(),
nominalCanDeriveAdditiveArithmetic)
.fixItInsert(loc, "@noDerivative ");
// Add an implicit `@noDerivative` attribute.
originalProperty->getAttrs().add(
new (Context) NoDerivativeAttr(/*Implicit*/ true));
continue;
}
// Use the original wrapped property.
vd = originalProperty;
}
if (vd->getInterfaceType()->hasError())
continue;
// Skip stored properties with `@noDerivative` attribute.
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
continue;
// For property wrapper backing storage properties, skip if original
// property has `@noDerivative` attribute.
if (auto *originalProperty = vd->getOriginalWrappedProperty())
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
continue;
// Check whether to diagnose stored property.
auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
bool conformsToDifferentiable =
@@ -508,14 +550,8 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
// Otherwise, add an implicit `@noDerivative` attribute.
vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true));
auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false);
if (auto *originalProperty = vd->getOriginalWrappedProperty())
loc = originalProperty->getAttributeInsertionLoc(/*forModifier*/ false);
assert(loc.isValid() && "Expected valid source location");
// If nominal type can conform to `AdditiveArithmetic`, suggest conforming
// adding a conformance to `AdditiveArithmetic`.
// `Differentiable` protocol requirements all have default implementations
// when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
// derived conformances will no longer be necessary.
// Diagnose properties that do not conform to `Differentiable`.
if (!conformsToDifferentiable) {
Context.Diags
.diagnose(
@@ -526,11 +562,11 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
.fixItInsert(loc, "@noDerivative ");
continue;
}
// Otherwise, diagnose `let` property.
Context.Diags
.diagnose(loc,
diag::differentiable_let_property_implicit_noderivative_fixit,
vd->getName(), nominal->getName(),
nominalCanDeriveAdditiveArithmetic)
nominal->getName(), nominalCanDeriveAdditiveArithmetic)
.fixItInsert(loc, "@noDerivative ");
}
}

View File

@@ -38,7 +38,7 @@ class ImmutableStoredProperties: Differentiable {
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let nondiff: Int
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let diff: Float
init() {
@@ -56,7 +56,8 @@ class MutableStoredPropertiesWithInitialValue: Differentiable {
}
// Test class with both an empty constructor and memberwise initializer.
class AllMixedStoredPropertiesHaveInitialValue: Differentiable {
let x = Float(1) // expected-warning {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let x = Float(1)
var y = Float(1)
// Memberwise initializer should be `init(y:)` since `x` is immutable.
static func testMemberwiseInitializer() {
@@ -506,26 +507,35 @@ where T: AdditiveArithmetic {}
extension NoMemberwiseInitializerExtended: Differentiable
where T: Differentiable & AdditiveArithmetic {}
// Test property wrappers.
// TF-1190: Test `@noDerivative` warning for property wrapper backing storage properties.
@propertyWrapper
struct Wrapper<Value> {
struct ImmutableWrapper<Value> {
private var value: Value
var wrappedValue: Value {
get { value }
set { value = newValue }
var wrappedValue: Value { value }
init(wrappedValue: Value) {
self.value = wrappedValue
}
}
struct TF_1190<T> {}
class TF_1190_Outer: Differentiable {
// expected-warning @+1 {{stored property '_x' has no derivative because 'Wrapper<TF_1190<Float>>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
@Wrapper var x: TF_1190<Float>
@noDerivative @Wrapper var y: TF_1190<Float>
init(x: TF_1190<Float>, y: TF_1190<Float>) {
self.x = x
self.y = y
}
@propertyWrapper
struct Wrapper<Value> {
var wrappedValue: Value
}
struct Generic<T> {}
extension Generic: Differentiable where T: Differentiable {}
class WrappedProperties: Differentiable {
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires all stored properties not marked with `@noDerivative` to be mutable; add an explicit '@noDerivative' attribute}}
@ImmutableWrapper var immutableInt: Generic<Int> = Generic()
// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
@Wrapper var mutableInt: Generic<Int> = Generic()
@Wrapper var float: Generic<Float> = Generic()
@noDerivative @ImmutableWrapper var nondiff: Generic<Int> = Generic()
}
// Test derived conformances in disallowed contexts.

View File

@@ -80,3 +80,40 @@ struct UsableFromInlineStruct: Differentiable {}
// CHECK-AST: internal init()
// CHECK-AST: @usableFromInline
// CHECK-AST: struct TangentVector : Differentiable, AdditiveArithmetic {
// Test property wrappers.
@propertyWrapper
struct Wrapper<Value> {
var wrappedValue: Value
}
struct WrappedPropertiesStruct: Differentiable {
@Wrapper @Wrapper var x: Float
@Wrapper var y: Float
var z: Float
@noDerivative @Wrapper var nondiff: Float
}
// CHECK-AST-LABEL: internal struct WrappedPropertiesStruct : Differentiable {
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
// CHECK-AST: internal var x: Float.TangentVector
// CHECK-AST: internal var y: Float.TangentVector
// CHECK-AST: internal var z: Float.TangentVector
// CHECK-AST: }
// CHECK-AST: }
class WrappedPropertiesClass: Differentiable {
@Wrapper @Wrapper var x: Float = 1
@Wrapper var y: Float = 2
var z: Float = 3
@noDerivative @Wrapper var noDeriv: Float = 4
}
// CHECK-AST-LABEL: internal class WrappedPropertiesClass : Differentiable {
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
// CHECK-AST: internal var x: Float.TangentVector
// CHECK-AST: internal var y: Float.TangentVector
// CHECK-AST: internal var z: Float.TangentVector
// CHECK-AST: }
// CHECK-AST: }

View File

@@ -24,7 +24,7 @@ struct ImmutableStoredProperties: Differentiable {
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }}
let nondiff: Int
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic}} {{3-3=@noDerivative }}
let diff: Float
}
func testImmutableStoredProperties() {
@@ -36,7 +36,8 @@ struct MutableStoredPropertiesWithInitialValue: Differentiable {
}
// Test struct with both an empty constructor and memberwise initializer.
struct AllMixedStoredPropertiesHaveInitialValue: Differentiable {
let x = Float(1) // expected-warning {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let x = Float(1)
var y = Float(1)
// Memberwise initializer should be `init(y:)` since `x` is immutable.
static func testMemberwiseInitializer() {
@@ -323,20 +324,32 @@ where T: AdditiveArithmetic {}
extension NoMemberwiseInitializerExtended: Differentiable
where T: Differentiable & AdditiveArithmetic {}
// Test property wrappers.
// TF-1190: Test `@noDerivative` warning for property wrapper backing storage properties.
@propertyWrapper
struct Wrapper<Value> {
struct ImmutableWrapper<Value> {
private var value: Value
var wrappedValue: Value {
value
}
var wrappedValue: Value { value }
}
struct TF_1190<T> {}
struct TF_1190_Outer: Differentiable {
// expected-warning @+1 {{stored property '_x' has no derivative because 'Wrapper<TF_1190<Float>>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
@Wrapper var x: TF_1190<Float>
@noDerivative @Wrapper var y: TF_1190<Float>
@propertyWrapper
struct Wrapper<Value> {
var wrappedValue: Value
}
struct Generic<T> {}
extension Generic: Differentiable where T: Differentiable {}
struct WrappedProperties: Differentiable {
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires all stored properties not marked with `@noDerivative` to be mutable; add an explicit '@noDerivative' attribute}}
@ImmutableWrapper var immutableInt: Generic<Int>
// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
@Wrapper var mutableInt: Generic<Int>
@Wrapper var float: Generic<Float>
@noDerivative @ImmutableWrapper var nondiff: Generic<Int>
}
// Verify that cross-file derived conformances are disallowed.