Ensure we are adding T : Differentiable conformance from

protocol conditional conformance.

Fixes #75711
This commit is contained in:
Anton Korobeynikov
2024-11-07 16:50:33 +03:00
parent 11964e5a25
commit 4778bbce59
3 changed files with 56 additions and 0 deletions

View File

@@ -7447,6 +7447,14 @@ void SILGenFunction::emitProtocolWitness(
// Grab the type of our thunk.
auto thunkTy = F.getLoweredFunctionType();
// The protocol conditional conformance itself might bring some T :
// Differentiable conformances. They are already added to the derivative
// generic signature. Update witness substitution map generic signature to
// have them as well.
if (auto *derivativeId = witness.getDerivativeFunctionIdentifier())
witnessSubs = SubstitutionMap::get(derivativeId->getDerivativeGenericSignature(),
witnessSubs);
// Then get the type of the witness.
auto witnessKind = getWitnessDispatchKind(witness, isSelfConformance);
auto witnessInfo = getConstantInfo(getTypeExpansionContext(), witness);

View File

@@ -298,6 +298,29 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness,
return witness;
}
static GenericSignature maybeAddDifferentiableFromContext(DeclContext *dc,
GenericSignature derivativeGenSig) {
auto conformanceGenSig = dc->getGenericSignatureOfContext();
if (!conformanceGenSig)
return derivativeGenSig;
// The protocol conditional conformance itself might bring some T :
// Differentiable conformances. Add them the the derivative generic signature.
SmallVector<Requirement, 4> diffRequirements;
llvm::copy_if(conformanceGenSig.getRequirements(),
std::back_inserter(diffRequirements),
[](const Requirement &requirement) {
if (requirement.getKind() != RequirementKind::Conformance)
return false;
auto protoKind = requirement.getProtocolDecl()->getKnownProtocolKind();
return protoKind && *protoKind == KnownProtocolKind::Differentiable;
});
return buildGenericSignature(dc->getASTContext(), derivativeGenSig,
{}, std::move(diffRequirements), /*allowInverses=*/true);
}
/// Given a witness, a requirement, and an existing `RequirementMatch` result,
/// check if the requirement's `@differentiable` attributes are met by the
/// witness.
@@ -429,6 +452,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
auto derivativeGenSig = witnessAFD->getGenericSignature();
if (supersetConfig)
derivativeGenSig = supersetConfig->derivativeGenericSignature;
derivativeGenSig = maybeAddDifferentiableFromContext(dc, derivativeGenSig);
// Use source location of the witness declaration as the source location
// of the implicit `@differentiable` attribute.
auto *newAttr = DifferentiableAttr::create(

View File

@@ -0,0 +1,22 @@
// RUN: %target-swift-frontend -emit-sil -verify %s
// https://github.com/swiftlang/swift/issues/75711
// Ensure we propagate T : Differentiable conditional conformance
import _Differentiation
struct Wrapper<T> {
func read(_ t: T) -> T {
return t
}
}
protocol P {
associatedtype T: Differentiable
@differentiable(reverse)
func read(_: T) -> T
}
extension Wrapper: P where T: Differentiable {}