mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[AST] Properly handle substitution of generic function types.
If substitution into the generic parameters of a generic function type changes those generic parameters, build a new generic signature from the resulting parameters. Otherwise, we can end up building ill-formed generic signatures that fail validation.
This commit is contained in:
244
lib/AST/Type.cpp
244
lib/AST/Type.cpp
@@ -18,6 +18,7 @@
|
||||
#include "ForeignRepresentationInfo.h"
|
||||
#include "swift/AST/ASTContext.h"
|
||||
#include "swift/AST/ExistentialLayout.h"
|
||||
#include "swift/AST/GenericSignatureBuilder.h"
|
||||
#include "swift/AST/TypeVisitor.h"
|
||||
#include "swift/AST/TypeWalker.h"
|
||||
#include "swift/AST/Decl.h"
|
||||
@@ -3089,15 +3090,110 @@ Type DependentMemberType::substBaseType(Type substBase,
|
||||
getAssocType(), getName(), None);
|
||||
}
|
||||
|
||||
static Type substGenericFunctionType(GenericFunctionType *genericFnType,
|
||||
TypeSubstitutionFn substitutions,
|
||||
LookupConformanceFn lookupConformances,
|
||||
SubstOptions options) {
|
||||
// Substitute into the function type (without generic signature).
|
||||
auto *bareFnType = FunctionType::get(genericFnType->getInput(),
|
||||
genericFnType->getResult(),
|
||||
genericFnType->getExtInfo());
|
||||
Type result =
|
||||
Type(bareFnType).subst(substitutions, lookupConformances, options);
|
||||
if (!result || result->is<ErrorType>()) return result;
|
||||
|
||||
auto *fnType = result->castTo<FunctionType>();
|
||||
// Substitute generic parameters.
|
||||
bool anySemanticChanges = false;
|
||||
SmallVector<GenericTypeParamType *, 4> genericParams;
|
||||
for (auto param : genericFnType->getGenericParams()) {
|
||||
Type paramTy =
|
||||
Type(param).subst(substitutions, lookupConformances, options);
|
||||
if (!paramTy)
|
||||
return Type();
|
||||
|
||||
if (auto newParam = paramTy->getAs<GenericTypeParamType>()) {
|
||||
if (!newParam->isEqual(param))
|
||||
anySemanticChanges = true;
|
||||
|
||||
genericParams.push_back(newParam);
|
||||
} else {
|
||||
anySemanticChanges = true;
|
||||
}
|
||||
}
|
||||
|
||||
// If no generic parameters remain, this is a non-generic function type.
|
||||
if (genericParams.empty())
|
||||
return result;
|
||||
|
||||
// Transform requirements.
|
||||
SmallVector<Requirement, 4> requirements;
|
||||
for (const auto &req : genericFnType->getRequirements()) {
|
||||
// Substitute into the requirement.
|
||||
auto substReqt = req.subst(substitutions, lookupConformances, options);
|
||||
if (!substReqt) {
|
||||
anySemanticChanges = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Did anything change?
|
||||
if (!anySemanticChanges &&
|
||||
(!req.getFirstType()->isEqual(substReqt->getFirstType()) ||
|
||||
(req.getKind() != RequirementKind::Layout &&
|
||||
!req.getSecondType()->isEqual(substReqt->getSecondType())))) {
|
||||
anySemanticChanges = true;
|
||||
}
|
||||
|
||||
// Skip any erroneous requirements.
|
||||
if (substReqt->getFirstType()->hasError() ||
|
||||
(substReqt->getKind() != RequirementKind::Layout &&
|
||||
substReqt->getSecondType()->hasError()))
|
||||
continue;
|
||||
|
||||
requirements.push_back(*substReqt);
|
||||
}
|
||||
|
||||
GenericSignature *genericSig = nullptr;
|
||||
if (anySemanticChanges) {
|
||||
// If there were semantic changes, we need to build a new generic
|
||||
// signature.
|
||||
GenericSignatureBuilder builder(genericFnType->getASTContext());
|
||||
|
||||
// Add the generic parameters to the builder.
|
||||
for (auto gp : genericParams)
|
||||
builder.addGenericParameter(gp);
|
||||
|
||||
// Add the requirements to the builder.
|
||||
auto source =
|
||||
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
|
||||
for (const auto &req : requirements)
|
||||
builder.addRequirement(req, source, /*inferForModule=*/nullptr);
|
||||
|
||||
// Form the generic signature.
|
||||
genericSig = std::move(builder).computeGenericSignature(SourceLoc());
|
||||
} else {
|
||||
// Use the mapped generic signature.
|
||||
genericSig = GenericSignature::get(genericParams, requirements);
|
||||
}
|
||||
|
||||
// Produce the new generic function type.
|
||||
return GenericFunctionType::get(genericSig, fnType->getInput(),
|
||||
fnType->getResult(), fnType->getExtInfo());
|
||||
}
|
||||
|
||||
static Type substType(Type derivedType,
|
||||
TypeSubstitutionFn substitutions,
|
||||
LookupConformanceFn lookupConformances,
|
||||
SubstOptions options) {
|
||||
// Handle substitutions into generic function types.
|
||||
if (auto genericFnType = derivedType->getAs<GenericFunctionType>()) {
|
||||
return substGenericFunctionType(genericFnType, substitutions,
|
||||
lookupConformances, options);
|
||||
}
|
||||
|
||||
// FIXME: Change getTypeOfMember() to not pass GenericFunctionType here
|
||||
if (!derivedType->hasArchetype() &&
|
||||
!derivedType->hasTypeParameter() &&
|
||||
!derivedType->is<GenericFunctionType>())
|
||||
!derivedType->hasTypeParameter())
|
||||
return derivedType;
|
||||
|
||||
return derivedType.transformRec([&](TypeBase *type) -> Optional<Type> {
|
||||
@@ -3409,6 +3505,7 @@ Type TypeBase::getTypeOfMember(ModuleDecl *module, const ValueDecl *member,
|
||||
|
||||
assert(memberType);
|
||||
|
||||
// Perform the substitution.
|
||||
auto substitutions = getMemberSubstitutionMap(module, member);
|
||||
return memberType.subst(substitutions, SubstFlags::UseErrorType);
|
||||
}
|
||||
@@ -3777,6 +3874,7 @@ case TypeKind::Id:
|
||||
return DependentMemberType::get(dependentBase, dependent->getName());
|
||||
}
|
||||
|
||||
case TypeKind::GenericFunction:
|
||||
case TypeKind::Function: {
|
||||
auto function = cast<AnyFunctionType>(base);
|
||||
auto inputTy = function->getInput().transformRec(fn);
|
||||
@@ -3786,131 +3884,33 @@ case TypeKind::Id:
|
||||
if (!resultTy)
|
||||
return Type();
|
||||
|
||||
if (inputTy.getPointer() == function->getInput().getPointer() &&
|
||||
resultTy.getPointer() == function->getResult().getPointer())
|
||||
return *this;
|
||||
bool isUnchanged =
|
||||
inputTy.getPointer() == function->getInput().getPointer() &&
|
||||
resultTy.getPointer() == function->getResult().getPointer();
|
||||
|
||||
if (auto genericFnType = dyn_cast<GenericFunctionType>(base)) {
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Check that generic parameters won't be trasnformed.
|
||||
// Transform generic parameters.
|
||||
for (auto param : genericFnType->getGenericParams()) {
|
||||
assert(Type(param).transformRec(fn).getPointer() == param &&
|
||||
"GenericFunctionType transform() changes type parameter");
|
||||
}
|
||||
#endif
|
||||
if (isUnchanged) return *this;
|
||||
|
||||
auto genericSig = genericFnType->getGenericSignature();
|
||||
return GenericFunctionType::get(genericSig, inputTy, resultTy,
|
||||
function->getExtInfo());
|
||||
}
|
||||
|
||||
if (isUnchanged) return *this;
|
||||
|
||||
return FunctionType::get(inputTy, resultTy,
|
||||
function->getExtInfo());
|
||||
}
|
||||
|
||||
case TypeKind::GenericFunction: {
|
||||
GenericFunctionType *function = cast<GenericFunctionType>(base);
|
||||
bool anyChanges = false;
|
||||
|
||||
// Transform generic parameters.
|
||||
SmallVector<GenericTypeParamType *, 4> genericParams;
|
||||
for (auto param : function->getGenericParams()) {
|
||||
Type paramTy = Type(param).transformRec(fn);
|
||||
if (!paramTy)
|
||||
return Type();
|
||||
|
||||
if (auto newParam = paramTy->getAs<GenericTypeParamType>()) {
|
||||
if (newParam != param)
|
||||
anyChanges = true;
|
||||
|
||||
genericParams.push_back(newParam);
|
||||
} else {
|
||||
anyChanges = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Transform requirements.
|
||||
SmallVector<Requirement, 4> requirements;
|
||||
for (const auto &req : function->getRequirements()) {
|
||||
auto firstType = req.getFirstType().transformRec(fn);
|
||||
if (!firstType)
|
||||
return Type();
|
||||
|
||||
if (firstType.getPointer() != req.getFirstType().getPointer())
|
||||
anyChanges = true;
|
||||
|
||||
if (req.getKind() == RequirementKind::Layout) {
|
||||
if (!firstType->isTypeParameter())
|
||||
continue;
|
||||
|
||||
requirements.push_back(Requirement(req.getKind(), firstType,
|
||||
req.getLayoutConstraint()));
|
||||
continue;
|
||||
}
|
||||
|
||||
Type secondType = req.getSecondType();
|
||||
if (secondType) {
|
||||
secondType = secondType.transformRec(fn);
|
||||
if (!secondType)
|
||||
return Type();
|
||||
|
||||
if (secondType.getPointer() != req.getSecondType().getPointer())
|
||||
anyChanges = true;
|
||||
}
|
||||
|
||||
if (!firstType->isTypeParameter()) {
|
||||
if (!secondType || !secondType->isTypeParameter())
|
||||
continue;
|
||||
std::swap(firstType, secondType);
|
||||
}
|
||||
|
||||
requirements.push_back(Requirement(req.getKind(), firstType,
|
||||
secondType));
|
||||
}
|
||||
|
||||
// Transform input type.
|
||||
auto inputTy = function->getInput().transformRec(fn);
|
||||
if (!inputTy)
|
||||
return Type();
|
||||
|
||||
// Transform result type.
|
||||
auto resultTy = function->getResult().transformRec(fn);
|
||||
if (!resultTy)
|
||||
return Type();
|
||||
|
||||
// Check whether anything changed.
|
||||
if (!anyChanges &&
|
||||
inputTy.getPointer() == function->getInput().getPointer() &&
|
||||
resultTy.getPointer() == function->getResult().getPointer())
|
||||
return *this;
|
||||
|
||||
// If no generic parameters remain, this is a non-generic function type.
|
||||
if (genericParams.empty()) {
|
||||
return FunctionType::get(inputTy, resultTy, function->getExtInfo());
|
||||
}
|
||||
|
||||
// Sort/unique the generic parameters by depth/index.
|
||||
using llvm::array_pod_sort;
|
||||
array_pod_sort(genericParams.begin(), genericParams.end(),
|
||||
[](GenericTypeParamType * const * gpp1,
|
||||
GenericTypeParamType * const * gpp2) {
|
||||
auto gp1 = *gpp1;
|
||||
auto gp2 = *gpp2;
|
||||
|
||||
if (gp1->getDepth() < gp2->getDepth())
|
||||
return -1;
|
||||
|
||||
if (gp1->getDepth() > gp2->getDepth())
|
||||
return 1;
|
||||
|
||||
if (gp1->getIndex() < gp2->getIndex())
|
||||
return -1;
|
||||
|
||||
if (gp1->getIndex() > gp2->getIndex())
|
||||
return 1;
|
||||
|
||||
return 0;
|
||||
});
|
||||
genericParams.erase(std::unique(genericParams.begin(), genericParams.end(),
|
||||
[](GenericTypeParamType *gp1,
|
||||
GenericTypeParamType *gp2) {
|
||||
return gp1->getDepth() == gp2->getDepth()
|
||||
&& gp1->getIndex() == gp2->getIndex();
|
||||
}),
|
||||
genericParams.end());
|
||||
|
||||
// Produce the new generic function type.
|
||||
auto sig = GenericSignature::get(genericParams, requirements);
|
||||
return GenericFunctionType::get(sig, inputTy, resultTy,
|
||||
function->getExtInfo());
|
||||
}
|
||||
|
||||
case TypeKind::ArraySlice: {
|
||||
auto slice = cast<ArraySliceType>(base);
|
||||
auto baseTy = slice->getBaseType().transformRec(fn);
|
||||
|
||||
Reference in New Issue
Block a user