mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[AutoDiff] Make Differentiable derivation support property wrappers.
Differentiable conformance derivation now "peers through" property wrappers.
Synthesized TangentVector structs contain wrapped properties' TangentVectors as
stored properties, not wrappers' TangentVectors.
Property wrapper types are not required to conform to `Differentiable`.
Property wrapper types are required to provide `wrappedValue.set`, which is
needed to synthesize `mutating func move(along:)`.
```
import _Differentiation
@propertyWrapper
struct Wrapper<Value> {
var wrappedValue: Value
}
struct Struct: Differentiable {
@Wrapper var x: Float = 0
// Compiler now synthesizes:
// struct TangentVector: Differentiable & AdditiveArithmetic {
// var x: Float
// ...
// }
}
```
Resolves SR-12638.
This commit is contained in:
@@ -2734,13 +2734,21 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
|
||||
"stored property %0 has no derivative because %1 does not conform to "
|
||||
"'Differentiable'; add an explicit '@noDerivative' attribute"
|
||||
"%select{|, or conform %2 to 'AdditiveArithmetic'}3",
|
||||
(Identifier, Type, Identifier, bool))
|
||||
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
|
||||
(/*propName*/ Identifier, /*propType*/ Type, /*nominalName*/ Identifier,
|
||||
/*nominalCanDeriveAdditiveArithmetic*/ bool))
|
||||
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
|
||||
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
|
||||
"requires all stored properties to be mutable; use 'var' instead, or add "
|
||||
"an explicit '@noDerivative' attribute"
|
||||
"requires all stored properties not marked with `@noDerivative` to be "
|
||||
"mutable; add an explicit '@noDerivative' attribute"
|
||||
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
|
||||
(Identifier, Identifier, bool))
|
||||
(/*wrapperType*/ StringRef, /*nominalName*/ Identifier,
|
||||
/*nominalCanDeriveAdditiveArithmetic*/ bool))
|
||||
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
|
||||
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
|
||||
"requires all stored properties not marked with `@noDerivative` to be "
|
||||
"mutable; use 'var' instead, or add an explicit '@noDerivative' attribute"
|
||||
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
|
||||
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))
|
||||
|
||||
NOTE(codable_extraneous_codingkey_case_here,none,
|
||||
"CodingKey case %0 does not match any stored properties", (Identifier))
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "swift/AST/ParameterList.h"
|
||||
#include "swift/AST/Pattern.h"
|
||||
#include "swift/AST/ProtocolConformance.h"
|
||||
#include "swift/AST/PropertyWrappers.h"
|
||||
#include "swift/AST/Stmt.h"
|
||||
#include "swift/AST/Types.h"
|
||||
#include "DerivedConformances.h"
|
||||
@@ -39,14 +40,23 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
|
||||
auto &C = nominal->getASTContext();
|
||||
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
|
||||
for (auto *vd : nominal->getStoredProperties()) {
|
||||
// Peer through property wrappers: use original wrapped properties instead.
|
||||
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
|
||||
// Skip property wrappers that do not define `wrappedValue.set`.
|
||||
// `mutating func move(along:)` cannot be synthesized to update these
|
||||
// properties.
|
||||
auto *wrapperDecl =
|
||||
vd->getInterfaceType()->getNominalOrBoundGenericNominal();
|
||||
auto *wrappedValueDecl =
|
||||
wrapperDecl->getPropertyWrapperTypeInfo().valueVar;
|
||||
if (!wrappedValueDecl->getAccessor(AccessorKind::Set))
|
||||
continue;
|
||||
// Use the original wrapped property.
|
||||
vd = originalProperty;
|
||||
}
|
||||
// Skip stored properties with `@noDerivative` attribute.
|
||||
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
|
||||
continue;
|
||||
// For property wrapper backing storage properties, skip if original
|
||||
// property has `@noDerivative` attribute.
|
||||
if (auto *originalProperty = vd->getOriginalWrappedProperty())
|
||||
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
|
||||
continue;
|
||||
// Skip `let` stored properties. `mutating func move(along:)` cannot be
|
||||
// synthesized to update these properties.
|
||||
if (vd->isLet())
|
||||
@@ -224,15 +234,15 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
|
||||
if (confRef.isConcrete())
|
||||
memberMethodDecl = confRef.getConcrete()->getWitnessDecl(methodReq);
|
||||
assert(memberMethodDecl && "Member method declaration must exist");
|
||||
auto memberMethodDRE =
|
||||
auto *memberMethodDRE =
|
||||
new (C) DeclRefExpr(memberMethodDecl, DeclNameLoc(), /*Implicit*/ true);
|
||||
memberMethodDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
|
||||
|
||||
// Create reference to member method: `x.move(along:)`.
|
||||
auto memberExpr =
|
||||
Expr *memberExpr =
|
||||
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
|
||||
/*Implicit*/ true);
|
||||
auto memberMethodExpr =
|
||||
auto *memberMethodExpr =
|
||||
new (C) DotSyntaxCallExpr(memberMethodDRE, SourceLoc(), memberExpr);
|
||||
|
||||
// Create reference to parameter member: `direction.x`.
|
||||
@@ -483,20 +493,52 @@ static void addAssociatedTypeAliasDecl(Identifier name, DeclContext *sourceDC,
|
||||
static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
|
||||
NominalTypeDecl *nominal,
|
||||
DeclContext *DC) {
|
||||
auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
|
||||
// If nominal type can conform to `AdditiveArithmetic`, suggest adding a
|
||||
// conformance to `AdditiveArithmetic` in fix-its.
|
||||
// `Differentiable` protocol requirements all have default implementations
|
||||
// when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
|
||||
// derived conformances will no longer be necessary.
|
||||
bool nominalCanDeriveAdditiveArithmetic =
|
||||
DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC);
|
||||
auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
|
||||
// Check all stored properties.
|
||||
for (auto *vd : nominal->getStoredProperties()) {
|
||||
// Peer through property wrappers: use original wrapped properties.
|
||||
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
|
||||
// Skip wrapped properties with `@noDerivative` attribute.
|
||||
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
|
||||
continue;
|
||||
// Diagnose wrapped properties whose property wrappers do not define
|
||||
// `wrappedValue.set`. `mutating func move(along:)` cannot be synthesized
|
||||
// to update these properties.
|
||||
auto *wrapperDecl =
|
||||
vd->getInterfaceType()->getNominalOrBoundGenericNominal();
|
||||
auto *wrappedValueDecl =
|
||||
wrapperDecl->getPropertyWrapperTypeInfo().valueVar;
|
||||
if (!wrappedValueDecl->getAccessor(AccessorKind::Set)) {
|
||||
auto loc =
|
||||
originalProperty->getAttributeInsertionLoc(/*forModifier*/ false);
|
||||
Context.Diags
|
||||
.diagnose(
|
||||
loc,
|
||||
diag::
|
||||
differentiable_immutable_wrapper_implicit_noderivative_fixit,
|
||||
wrapperDecl->getNameStr(), nominal->getName(),
|
||||
nominalCanDeriveAdditiveArithmetic)
|
||||
.fixItInsert(loc, "@noDerivative ");
|
||||
// Add an implicit `@noDerivative` attribute.
|
||||
originalProperty->getAttrs().add(
|
||||
new (Context) NoDerivativeAttr(/*Implicit*/ true));
|
||||
continue;
|
||||
}
|
||||
// Use the original wrapped property.
|
||||
vd = originalProperty;
|
||||
}
|
||||
if (vd->getInterfaceType()->hasError())
|
||||
continue;
|
||||
// Skip stored properties with `@noDerivative` attribute.
|
||||
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
|
||||
continue;
|
||||
// For property wrapper backing storage properties, skip if original
|
||||
// property has `@noDerivative` attribute.
|
||||
if (auto *originalProperty = vd->getOriginalWrappedProperty())
|
||||
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
|
||||
continue;
|
||||
// Check whether to diagnose stored property.
|
||||
auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
|
||||
bool conformsToDifferentiable =
|
||||
@@ -508,14 +550,8 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
|
||||
// Otherwise, add an implicit `@noDerivative` attribute.
|
||||
vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true));
|
||||
auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false);
|
||||
if (auto *originalProperty = vd->getOriginalWrappedProperty())
|
||||
loc = originalProperty->getAttributeInsertionLoc(/*forModifier*/ false);
|
||||
assert(loc.isValid() && "Expected valid source location");
|
||||
// If nominal type can conform to `AdditiveArithmetic`, suggest conforming
|
||||
// adding a conformance to `AdditiveArithmetic`.
|
||||
// `Differentiable` protocol requirements all have default implementations
|
||||
// when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
|
||||
// derived conformances will no longer be necessary.
|
||||
// Diagnose properties that do not conform to `Differentiable`.
|
||||
if (!conformsToDifferentiable) {
|
||||
Context.Diags
|
||||
.diagnose(
|
||||
@@ -526,11 +562,11 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
|
||||
.fixItInsert(loc, "@noDerivative ");
|
||||
continue;
|
||||
}
|
||||
// Otherwise, diagnose `let` property.
|
||||
Context.Diags
|
||||
.diagnose(loc,
|
||||
diag::differentiable_let_property_implicit_noderivative_fixit,
|
||||
vd->getName(), nominal->getName(),
|
||||
nominalCanDeriveAdditiveArithmetic)
|
||||
nominal->getName(), nominalCanDeriveAdditiveArithmetic)
|
||||
.fixItInsert(loc, "@noDerivative ");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ class ImmutableStoredProperties: Differentiable {
|
||||
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
|
||||
let nondiff: Int
|
||||
|
||||
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
|
||||
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
|
||||
let diff: Float
|
||||
|
||||
init() {
|
||||
@@ -56,7 +56,8 @@ class MutableStoredPropertiesWithInitialValue: Differentiable {
|
||||
}
|
||||
// Test class with both an empty constructor and memberwise initializer.
|
||||
class AllMixedStoredPropertiesHaveInitialValue: Differentiable {
|
||||
let x = Float(1) // expected-warning {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
|
||||
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
|
||||
let x = Float(1)
|
||||
var y = Float(1)
|
||||
// Memberwise initializer should be `init(y:)` since `x` is immutable.
|
||||
static func testMemberwiseInitializer() {
|
||||
@@ -506,26 +507,35 @@ where T: AdditiveArithmetic {}
|
||||
extension NoMemberwiseInitializerExtended: Differentiable
|
||||
where T: Differentiable & AdditiveArithmetic {}
|
||||
|
||||
// Test property wrappers.
|
||||
// TF-1190: Test `@noDerivative` warning for property wrapper backing storage properties.
|
||||
|
||||
@propertyWrapper
|
||||
struct Wrapper<Value> {
|
||||
struct ImmutableWrapper<Value> {
|
||||
private var value: Value
|
||||
var wrappedValue: Value {
|
||||
get { value }
|
||||
set { value = newValue }
|
||||
var wrappedValue: Value { value }
|
||||
init(wrappedValue: Value) {
|
||||
self.value = wrappedValue
|
||||
}
|
||||
}
|
||||
struct TF_1190<T> {}
|
||||
class TF_1190_Outer: Differentiable {
|
||||
// expected-warning @+1 {{stored property '_x' has no derivative because 'Wrapper<TF_1190<Float>>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
|
||||
@Wrapper var x: TF_1190<Float>
|
||||
@noDerivative @Wrapper var y: TF_1190<Float>
|
||||
|
||||
init(x: TF_1190<Float>, y: TF_1190<Float>) {
|
||||
self.x = x
|
||||
self.y = y
|
||||
@propertyWrapper
|
||||
struct Wrapper<Value> {
|
||||
var wrappedValue: Value
|
||||
}
|
||||
|
||||
struct Generic<T> {}
|
||||
extension Generic: Differentiable where T: Differentiable {}
|
||||
|
||||
class WrappedProperties: Differentiable {
|
||||
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires all stored properties not marked with `@noDerivative` to be mutable; add an explicit '@noDerivative' attribute}}
|
||||
@ImmutableWrapper var immutableInt: Generic<Int> = Generic()
|
||||
|
||||
// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
|
||||
@Wrapper var mutableInt: Generic<Int> = Generic()
|
||||
|
||||
@Wrapper var float: Generic<Float> = Generic()
|
||||
@noDerivative @ImmutableWrapper var nondiff: Generic<Int> = Generic()
|
||||
}
|
||||
|
||||
// Test derived conformances in disallowed contexts.
|
||||
|
||||
@@ -80,3 +80,40 @@ struct UsableFromInlineStruct: Differentiable {}
|
||||
// CHECK-AST: internal init()
|
||||
// CHECK-AST: @usableFromInline
|
||||
// CHECK-AST: struct TangentVector : Differentiable, AdditiveArithmetic {
|
||||
|
||||
// Test property wrappers.
|
||||
|
||||
@propertyWrapper
|
||||
struct Wrapper<Value> {
|
||||
var wrappedValue: Value
|
||||
}
|
||||
|
||||
struct WrappedPropertiesStruct: Differentiable {
|
||||
@Wrapper @Wrapper var x: Float
|
||||
@Wrapper var y: Float
|
||||
var z: Float
|
||||
@noDerivative @Wrapper var nondiff: Float
|
||||
}
|
||||
|
||||
// CHECK-AST-LABEL: internal struct WrappedPropertiesStruct : Differentiable {
|
||||
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
|
||||
// CHECK-AST: internal var x: Float.TangentVector
|
||||
// CHECK-AST: internal var y: Float.TangentVector
|
||||
// CHECK-AST: internal var z: Float.TangentVector
|
||||
// CHECK-AST: }
|
||||
// CHECK-AST: }
|
||||
|
||||
class WrappedPropertiesClass: Differentiable {
|
||||
@Wrapper @Wrapper var x: Float = 1
|
||||
@Wrapper var y: Float = 2
|
||||
var z: Float = 3
|
||||
@noDerivative @Wrapper var noDeriv: Float = 4
|
||||
}
|
||||
|
||||
// CHECK-AST-LABEL: internal class WrappedPropertiesClass : Differentiable {
|
||||
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic {
|
||||
// CHECK-AST: internal var x: Float.TangentVector
|
||||
// CHECK-AST: internal var y: Float.TangentVector
|
||||
// CHECK-AST: internal var z: Float.TangentVector
|
||||
// CHECK-AST: }
|
||||
// CHECK-AST: }
|
||||
|
||||
@@ -24,7 +24,7 @@ struct ImmutableStoredProperties: Differentiable {
|
||||
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }}
|
||||
let nondiff: Int
|
||||
|
||||
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }}
|
||||
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic}} {{3-3=@noDerivative }}
|
||||
let diff: Float
|
||||
}
|
||||
func testImmutableStoredProperties() {
|
||||
@@ -36,7 +36,8 @@ struct MutableStoredPropertiesWithInitialValue: Differentiable {
|
||||
}
|
||||
// Test struct with both an empty constructor and memberwise initializer.
|
||||
struct AllMixedStoredPropertiesHaveInitialValue: Differentiable {
|
||||
let x = Float(1) // expected-warning {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
|
||||
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
|
||||
let x = Float(1)
|
||||
var y = Float(1)
|
||||
// Memberwise initializer should be `init(y:)` since `x` is immutable.
|
||||
static func testMemberwiseInitializer() {
|
||||
@@ -323,20 +324,32 @@ where T: AdditiveArithmetic {}
|
||||
extension NoMemberwiseInitializerExtended: Differentiable
|
||||
where T: Differentiable & AdditiveArithmetic {}
|
||||
|
||||
// Test property wrappers.
|
||||
// TF-1190: Test `@noDerivative` warning for property wrapper backing storage properties.
|
||||
|
||||
@propertyWrapper
|
||||
struct Wrapper<Value> {
|
||||
struct ImmutableWrapper<Value> {
|
||||
private var value: Value
|
||||
var wrappedValue: Value {
|
||||
value
|
||||
var wrappedValue: Value { value }
|
||||
}
|
||||
|
||||
@propertyWrapper
|
||||
struct Wrapper<Value> {
|
||||
var wrappedValue: Value
|
||||
}
|
||||
struct TF_1190<T> {}
|
||||
struct TF_1190_Outer: Differentiable {
|
||||
// expected-warning @+1 {{stored property '_x' has no derivative because 'Wrapper<TF_1190<Float>>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
|
||||
@Wrapper var x: TF_1190<Float>
|
||||
@noDerivative @Wrapper var y: TF_1190<Float>
|
||||
|
||||
struct Generic<T> {}
|
||||
extension Generic: Differentiable where T: Differentiable {}
|
||||
|
||||
struct WrappedProperties: Differentiable {
|
||||
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires all stored properties not marked with `@noDerivative` to be mutable; add an explicit '@noDerivative' attribute}}
|
||||
@ImmutableWrapper var immutableInt: Generic<Int>
|
||||
|
||||
// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
|
||||
@Wrapper var mutableInt: Generic<Int>
|
||||
|
||||
@Wrapper var float: Generic<Float>
|
||||
@noDerivative @ImmutableWrapper var nondiff: Generic<Int>
|
||||
}
|
||||
|
||||
// Verify that cross-file derived conformances are disallowed.
|
||||
|
||||
Reference in New Issue
Block a user