mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
Ensure we are adding T : Differentiable conformance from
protocol conditional conformance. Fixes #75711
This commit is contained in:
@@ -7447,6 +7447,14 @@ void SILGenFunction::emitProtocolWitness(
|
||||
// Grab the type of our thunk.
|
||||
auto thunkTy = F.getLoweredFunctionType();
|
||||
|
||||
// The protocol conditional conformance itself might bring some T :
|
||||
// Differentiable conformances. They are already added to the derivative
|
||||
// generic signature. Update witness substitution map generic signature to
|
||||
// have them as well.
|
||||
if (auto *derivativeId = witness.getDerivativeFunctionIdentifier())
|
||||
witnessSubs = SubstitutionMap::get(derivativeId->getDerivativeGenericSignature(),
|
||||
witnessSubs);
|
||||
|
||||
// Then get the type of the witness.
|
||||
auto witnessKind = getWitnessDispatchKind(witness, isSelfConformance);
|
||||
auto witnessInfo = getConstantInfo(getTypeExpansionContext(), witness);
|
||||
|
||||
@@ -298,6 +298,29 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness,
|
||||
return witness;
|
||||
}
|
||||
|
||||
static GenericSignature maybeAddDifferentiableFromContext(DeclContext *dc,
|
||||
GenericSignature derivativeGenSig) {
|
||||
auto conformanceGenSig = dc->getGenericSignatureOfContext();
|
||||
if (!conformanceGenSig)
|
||||
return derivativeGenSig;
|
||||
|
||||
// The protocol conditional conformance itself might bring some T :
|
||||
// Differentiable conformances. Add them the the derivative generic signature.
|
||||
SmallVector<Requirement, 4> diffRequirements;
|
||||
llvm::copy_if(conformanceGenSig.getRequirements(),
|
||||
std::back_inserter(diffRequirements),
|
||||
[](const Requirement &requirement) {
|
||||
if (requirement.getKind() != RequirementKind::Conformance)
|
||||
return false;
|
||||
|
||||
auto protoKind = requirement.getProtocolDecl()->getKnownProtocolKind();
|
||||
return protoKind && *protoKind == KnownProtocolKind::Differentiable;
|
||||
});
|
||||
|
||||
return buildGenericSignature(dc->getASTContext(), derivativeGenSig,
|
||||
{}, std::move(diffRequirements), /*allowInverses=*/true);
|
||||
}
|
||||
|
||||
/// Given a witness, a requirement, and an existing `RequirementMatch` result,
|
||||
/// check if the requirement's `@differentiable` attributes are met by the
|
||||
/// witness.
|
||||
@@ -429,6 +452,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
|
||||
auto derivativeGenSig = witnessAFD->getGenericSignature();
|
||||
if (supersetConfig)
|
||||
derivativeGenSig = supersetConfig->derivativeGenericSignature;
|
||||
|
||||
derivativeGenSig = maybeAddDifferentiableFromContext(dc, derivativeGenSig);
|
||||
|
||||
// Use source location of the witness declaration as the source location
|
||||
// of the implicit `@differentiable` attribute.
|
||||
auto *newAttr = DifferentiableAttr::create(
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
// RUN: %target-swift-frontend -emit-sil -verify %s
|
||||
|
||||
// https://github.com/swiftlang/swift/issues/75711
|
||||
|
||||
// Ensure we propagate T : Differentiable conditional conformance
|
||||
|
||||
import _Differentiation
|
||||
|
||||
struct Wrapper<T> {
|
||||
func read(_ t: T) -> T {
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
protocol P {
|
||||
associatedtype T: Differentiable
|
||||
|
||||
@differentiable(reverse)
|
||||
func read(_: T) -> T
|
||||
}
|
||||
|
||||
extension Wrapper: P where T: Differentiable {}
|
||||
Reference in New Issue
Block a user