First cut of parameter pack differentiation

This commit is contained in:
Anton Korobeynikov
2025-07-21 16:10:26 -07:00
parent 9758be89c7
commit 595abee2c1
11 changed files with 143 additions and 36 deletions

View File

@@ -676,9 +676,13 @@ struct InferRequirementsWalker : public TypeWalker {
auto *tangentVectorAssocType =
differentiableProtocol->getAssociatedType(ctx.Id_TangentVector);
auto addRequirements = [&](Type type, bool isLinear) {
addConformanceConstraint(type, differentiableProtocol);
if (isLinear)
addSameTypeConstraint(type, tangentVectorAssocType);
// Pack is differentiable if each pattern type is differentiable
if (auto packExpansion = type->getAs<PackExpansionType>())
type = packExpansion->getPatternType();
addConformanceConstraint(type, differentiableProtocol);
if (isLinear)
addSameTypeConstraint(type, tangentVectorAssocType);
};
auto constrainParametersAndResult = [&](bool isLinear) {
for (auto &param : fnTy->getParams())