mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
First cut of parameter pack differentiation
This commit is contained in:
@@ -353,7 +353,14 @@ GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
|
||||
// Require differentiability parameters to conform to `Differentiable`.
|
||||
for (unsigned paramIdx : diffParamIndices->getIndices()) {
|
||||
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
|
||||
addRequirement(paramType);
|
||||
if (auto silPackTy = dyn_cast<SILPackType>(paramType)) {
|
||||
for (auto elTy : silPackTy.getElementTypes())
|
||||
if (auto packExpansionTy = dyn_cast<PackExpansionType>(elTy))
|
||||
addRequirement(packExpansionTy.getPatternType());
|
||||
else
|
||||
addRequirement(elTy);
|
||||
} else
|
||||
addRequirement(paramType);
|
||||
}
|
||||
|
||||
// Require differentiability results to conform to `Differentiable`.
|
||||
@@ -485,6 +492,10 @@ Type TangentSpace::getType() const {
|
||||
return value.tangentVectorType;
|
||||
case Kind::Tuple:
|
||||
return value.tupleType;
|
||||
case Kind::PackExpansion:
|
||||
return value.packExpansionType;
|
||||
case Kind::SILPackType:
|
||||
return value.silPackType;
|
||||
}
|
||||
llvm_unreachable("invalid tangent space kind");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user