First cut of parameter pack differentiation

This commit is contained in:
Anton Korobeynikov
2025-07-21 16:10:26 -07:00
parent 9758be89c7
commit 595abee2c1
11 changed files with 143 additions and 36 deletions

View File

@@ -353,7 +353,14 @@ GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
// Require differentiability parameters to conform to `Differentiable`.
for (unsigned paramIdx : diffParamIndices->getIndices()) {
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
addRequirement(paramType);
if (auto silPackTy = dyn_cast<SILPackType>(paramType)) {
for (auto elTy : silPackTy.getElementTypes())
if (auto packExpansionTy = dyn_cast<PackExpansionType>(elTy))
addRequirement(packExpansionTy.getPatternType());
else
addRequirement(elTy);
} else
addRequirement(paramType);
}
// Require differentiability results to conform to `Differentiable`.
@@ -485,6 +492,10 @@ Type TangentSpace::getType() const {
return value.tangentVectorType;
case Kind::Tuple:
return value.tupleType;
case Kind::PackExpansion:
return value.packExpansionType;
case Kind::SILPackType:
return value.silPackType;
}
llvm_unreachable("invalid tangent space kind");
}