mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[AutoDiff] Make Differentiable derivation support property wrappers.
Differentiable conformance derivation now "peers through" property wrappers.
Synthesized TangentVector structs contain wrapped properties' TangentVectors as
stored properties, not wrappers' TangentVectors.
Property wrapper types are not required to conform to `Differentiable`.
Property wrapper types are required to provide `wrappedValue.set`, which is
needed to synthesize `mutating func move(along:)`.
```
import _Differentiation
@propertyWrapper
struct Wrapper<Value> {
var wrappedValue: Value
}
struct Struct: Differentiable {
@Wrapper var x: Float = 0
// Compiler now synthesizes:
// struct TangentVector: Differentiable & AdditiveArithmetic {
// var x: Float
// ...
// }
}
```
Resolves SR-12638.
This commit is contained in:
@@ -25,6 +25,7 @@
|
||||
#include "swift/AST/ParameterList.h"
|
||||
#include "swift/AST/Pattern.h"
|
||||
#include "swift/AST/ProtocolConformance.h"
|
||||
#include "swift/AST/PropertyWrappers.h"
|
||||
#include "swift/AST/Stmt.h"
|
||||
#include "swift/AST/Types.h"
|
||||
#include "DerivedConformances.h"
|
||||
@@ -39,14 +40,23 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
|
||||
auto &C = nominal->getASTContext();
|
||||
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
|
||||
for (auto *vd : nominal->getStoredProperties()) {
|
||||
// Peer through property wrappers: use original wrapped properties instead.
|
||||
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
|
||||
// Skip property wrappers that do not define `wrappedValue.set`.
|
||||
// `mutating func move(along:)` cannot be synthesized to update these
|
||||
// properties.
|
||||
auto *wrapperDecl =
|
||||
vd->getInterfaceType()->getNominalOrBoundGenericNominal();
|
||||
auto *wrappedValueDecl =
|
||||
wrapperDecl->getPropertyWrapperTypeInfo().valueVar;
|
||||
if (!wrappedValueDecl->getAccessor(AccessorKind::Set))
|
||||
continue;
|
||||
// Use the original wrapped property.
|
||||
vd = originalProperty;
|
||||
}
|
||||
// Skip stored properties with `@noDerivative` attribute.
|
||||
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
|
||||
continue;
|
||||
// For property wrapper backing storage properties, skip if original
|
||||
// property has `@noDerivative` attribute.
|
||||
if (auto *originalProperty = vd->getOriginalWrappedProperty())
|
||||
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
|
||||
continue;
|
||||
// Skip `let` stored properties. `mutating func move(along:)` cannot be
|
||||
// synthesized to update these properties.
|
||||
if (vd->isLet())
|
||||
@@ -224,15 +234,15 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
|
||||
if (confRef.isConcrete())
|
||||
memberMethodDecl = confRef.getConcrete()->getWitnessDecl(methodReq);
|
||||
assert(memberMethodDecl && "Member method declaration must exist");
|
||||
auto memberMethodDRE =
|
||||
auto *memberMethodDRE =
|
||||
new (C) DeclRefExpr(memberMethodDecl, DeclNameLoc(), /*Implicit*/ true);
|
||||
memberMethodDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
|
||||
|
||||
// Create reference to member method: `x.move(along:)`.
|
||||
auto memberExpr =
|
||||
Expr *memberExpr =
|
||||
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
|
||||
/*Implicit*/ true);
|
||||
auto memberMethodExpr =
|
||||
auto *memberMethodExpr =
|
||||
new (C) DotSyntaxCallExpr(memberMethodDRE, SourceLoc(), memberExpr);
|
||||
|
||||
// Create reference to parameter member: `direction.x`.
|
||||
@@ -483,20 +493,52 @@ static void addAssociatedTypeAliasDecl(Identifier name, DeclContext *sourceDC,
|
||||
static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
|
||||
NominalTypeDecl *nominal,
|
||||
DeclContext *DC) {
|
||||
auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
|
||||
// If nominal type can conform to `AdditiveArithmetic`, suggest adding a
|
||||
// conformance to `AdditiveArithmetic` in fix-its.
|
||||
// `Differentiable` protocol requirements all have default implementations
|
||||
// when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
|
||||
// derived conformances will no longer be necessary.
|
||||
bool nominalCanDeriveAdditiveArithmetic =
|
||||
DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC);
|
||||
auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
|
||||
// Check all stored properties.
|
||||
for (auto *vd : nominal->getStoredProperties()) {
|
||||
// Peer through property wrappers: use original wrapped properties.
|
||||
if (auto *originalProperty = vd->getOriginalWrappedProperty()) {
|
||||
// Skip wrapped properties with `@noDerivative` attribute.
|
||||
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
|
||||
continue;
|
||||
// Diagnose wrapped properties whose property wrappers do not define
|
||||
// `wrappedValue.set`. `mutating func move(along:)` cannot be synthesized
|
||||
// to update these properties.
|
||||
auto *wrapperDecl =
|
||||
vd->getInterfaceType()->getNominalOrBoundGenericNominal();
|
||||
auto *wrappedValueDecl =
|
||||
wrapperDecl->getPropertyWrapperTypeInfo().valueVar;
|
||||
if (!wrappedValueDecl->getAccessor(AccessorKind::Set)) {
|
||||
auto loc =
|
||||
originalProperty->getAttributeInsertionLoc(/*forModifier*/ false);
|
||||
Context.Diags
|
||||
.diagnose(
|
||||
loc,
|
||||
diag::
|
||||
differentiable_immutable_wrapper_implicit_noderivative_fixit,
|
||||
wrapperDecl->getNameStr(), nominal->getName(),
|
||||
nominalCanDeriveAdditiveArithmetic)
|
||||
.fixItInsert(loc, "@noDerivative ");
|
||||
// Add an implicit `@noDerivative` attribute.
|
||||
originalProperty->getAttrs().add(
|
||||
new (Context) NoDerivativeAttr(/*Implicit*/ true));
|
||||
continue;
|
||||
}
|
||||
// Use the original wrapped property.
|
||||
vd = originalProperty;
|
||||
}
|
||||
if (vd->getInterfaceType()->hasError())
|
||||
continue;
|
||||
// Skip stored properties with `@noDerivative` attribute.
|
||||
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
|
||||
continue;
|
||||
// For property wrapper backing storage properties, skip if original
|
||||
// property has `@noDerivative` attribute.
|
||||
if (auto *originalProperty = vd->getOriginalWrappedProperty())
|
||||
if (originalProperty->getAttrs().hasAttribute<NoDerivativeAttr>())
|
||||
continue;
|
||||
// Check whether to diagnose stored property.
|
||||
auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
|
||||
bool conformsToDifferentiable =
|
||||
@@ -508,14 +550,8 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
|
||||
// Otherwise, add an implicit `@noDerivative` attribute.
|
||||
vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true));
|
||||
auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false);
|
||||
if (auto *originalProperty = vd->getOriginalWrappedProperty())
|
||||
loc = originalProperty->getAttributeInsertionLoc(/*forModifier*/ false);
|
||||
assert(loc.isValid() && "Expected valid source location");
|
||||
// If nominal type can conform to `AdditiveArithmetic`, suggest conforming
|
||||
// adding a conformance to `AdditiveArithmetic`.
|
||||
// `Differentiable` protocol requirements all have default implementations
|
||||
// when `Self` conforms to `AdditiveArithmetic`, so `Differentiable`
|
||||
// derived conformances will no longer be necessary.
|
||||
// Diagnose properties that do not conform to `Differentiable`.
|
||||
if (!conformsToDifferentiable) {
|
||||
Context.Diags
|
||||
.diagnose(
|
||||
@@ -526,11 +562,11 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
|
||||
.fixItInsert(loc, "@noDerivative ");
|
||||
continue;
|
||||
}
|
||||
// Otherwise, diagnose `let` property.
|
||||
Context.Diags
|
||||
.diagnose(loc,
|
||||
diag::differentiable_let_property_implicit_noderivative_fixit,
|
||||
vd->getName(), nominal->getName(),
|
||||
nominalCanDeriveAdditiveArithmetic)
|
||||
nominal->getName(), nominalCanDeriveAdditiveArithmetic)
|
||||
.fixItInsert(loc, "@noDerivative ");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user