mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[AST] Properly handle substitution of generic function types.
If substitution into the generic parameters of a generic function type changes those generic parameters, build a new generic signature from the resulting parameters. Otherwise, we can end up building ill-formed generic signatures that fail validation.
This commit is contained in:
244
lib/AST/Type.cpp
244
lib/AST/Type.cpp
@@ -18,6 +18,7 @@
|
|||||||
#include "ForeignRepresentationInfo.h"
|
#include "ForeignRepresentationInfo.h"
|
||||||
#include "swift/AST/ASTContext.h"
|
#include "swift/AST/ASTContext.h"
|
||||||
#include "swift/AST/ExistentialLayout.h"
|
#include "swift/AST/ExistentialLayout.h"
|
||||||
|
#include "swift/AST/GenericSignatureBuilder.h"
|
||||||
#include "swift/AST/TypeVisitor.h"
|
#include "swift/AST/TypeVisitor.h"
|
||||||
#include "swift/AST/TypeWalker.h"
|
#include "swift/AST/TypeWalker.h"
|
||||||
#include "swift/AST/Decl.h"
|
#include "swift/AST/Decl.h"
|
||||||
@@ -3089,15 +3090,110 @@ Type DependentMemberType::substBaseType(Type substBase,
|
|||||||
getAssocType(), getName(), None);
|
getAssocType(), getName(), None);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Type substGenericFunctionType(GenericFunctionType *genericFnType,
|
||||||
|
TypeSubstitutionFn substitutions,
|
||||||
|
LookupConformanceFn lookupConformances,
|
||||||
|
SubstOptions options) {
|
||||||
|
// Substitute into the function type (without generic signature).
|
||||||
|
auto *bareFnType = FunctionType::get(genericFnType->getInput(),
|
||||||
|
genericFnType->getResult(),
|
||||||
|
genericFnType->getExtInfo());
|
||||||
|
Type result =
|
||||||
|
Type(bareFnType).subst(substitutions, lookupConformances, options);
|
||||||
|
if (!result || result->is<ErrorType>()) return result;
|
||||||
|
|
||||||
|
auto *fnType = result->castTo<FunctionType>();
|
||||||
|
// Substitute generic parameters.
|
||||||
|
bool anySemanticChanges = false;
|
||||||
|
SmallVector<GenericTypeParamType *, 4> genericParams;
|
||||||
|
for (auto param : genericFnType->getGenericParams()) {
|
||||||
|
Type paramTy =
|
||||||
|
Type(param).subst(substitutions, lookupConformances, options);
|
||||||
|
if (!paramTy)
|
||||||
|
return Type();
|
||||||
|
|
||||||
|
if (auto newParam = paramTy->getAs<GenericTypeParamType>()) {
|
||||||
|
if (!newParam->isEqual(param))
|
||||||
|
anySemanticChanges = true;
|
||||||
|
|
||||||
|
genericParams.push_back(newParam);
|
||||||
|
} else {
|
||||||
|
anySemanticChanges = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no generic parameters remain, this is a non-generic function type.
|
||||||
|
if (genericParams.empty())
|
||||||
|
return result;
|
||||||
|
|
||||||
|
// Transform requirements.
|
||||||
|
SmallVector<Requirement, 4> requirements;
|
||||||
|
for (const auto &req : genericFnType->getRequirements()) {
|
||||||
|
// Substitute into the requirement.
|
||||||
|
auto substReqt = req.subst(substitutions, lookupConformances, options);
|
||||||
|
if (!substReqt) {
|
||||||
|
anySemanticChanges = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Did anything change?
|
||||||
|
if (!anySemanticChanges &&
|
||||||
|
(!req.getFirstType()->isEqual(substReqt->getFirstType()) ||
|
||||||
|
(req.getKind() != RequirementKind::Layout &&
|
||||||
|
!req.getSecondType()->isEqual(substReqt->getSecondType())))) {
|
||||||
|
anySemanticChanges = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip any erroneous requirements.
|
||||||
|
if (substReqt->getFirstType()->hasError() ||
|
||||||
|
(substReqt->getKind() != RequirementKind::Layout &&
|
||||||
|
substReqt->getSecondType()->hasError()))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
requirements.push_back(*substReqt);
|
||||||
|
}
|
||||||
|
|
||||||
|
GenericSignature *genericSig = nullptr;
|
||||||
|
if (anySemanticChanges) {
|
||||||
|
// If there were semantic changes, we need to build a new generic
|
||||||
|
// signature.
|
||||||
|
GenericSignatureBuilder builder(genericFnType->getASTContext());
|
||||||
|
|
||||||
|
// Add the generic parameters to the builder.
|
||||||
|
for (auto gp : genericParams)
|
||||||
|
builder.addGenericParameter(gp);
|
||||||
|
|
||||||
|
// Add the requirements to the builder.
|
||||||
|
auto source =
|
||||||
|
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
|
||||||
|
for (const auto &req : requirements)
|
||||||
|
builder.addRequirement(req, source, /*inferForModule=*/nullptr);
|
||||||
|
|
||||||
|
// Form the generic signature.
|
||||||
|
genericSig = std::move(builder).computeGenericSignature(SourceLoc());
|
||||||
|
} else {
|
||||||
|
// Use the mapped generic signature.
|
||||||
|
genericSig = GenericSignature::get(genericParams, requirements);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Produce the new generic function type.
|
||||||
|
return GenericFunctionType::get(genericSig, fnType->getInput(),
|
||||||
|
fnType->getResult(), fnType->getExtInfo());
|
||||||
|
}
|
||||||
|
|
||||||
static Type substType(Type derivedType,
|
static Type substType(Type derivedType,
|
||||||
TypeSubstitutionFn substitutions,
|
TypeSubstitutionFn substitutions,
|
||||||
LookupConformanceFn lookupConformances,
|
LookupConformanceFn lookupConformances,
|
||||||
SubstOptions options) {
|
SubstOptions options) {
|
||||||
|
// Handle substitutions into generic function types.
|
||||||
|
if (auto genericFnType = derivedType->getAs<GenericFunctionType>()) {
|
||||||
|
return substGenericFunctionType(genericFnType, substitutions,
|
||||||
|
lookupConformances, options);
|
||||||
|
}
|
||||||
|
|
||||||
// FIXME: Change getTypeOfMember() to not pass GenericFunctionType here
|
// FIXME: Change getTypeOfMember() to not pass GenericFunctionType here
|
||||||
if (!derivedType->hasArchetype() &&
|
if (!derivedType->hasArchetype() &&
|
||||||
!derivedType->hasTypeParameter() &&
|
!derivedType->hasTypeParameter())
|
||||||
!derivedType->is<GenericFunctionType>())
|
|
||||||
return derivedType;
|
return derivedType;
|
||||||
|
|
||||||
return derivedType.transformRec([&](TypeBase *type) -> Optional<Type> {
|
return derivedType.transformRec([&](TypeBase *type) -> Optional<Type> {
|
||||||
@@ -3409,6 +3505,7 @@ Type TypeBase::getTypeOfMember(ModuleDecl *module, const ValueDecl *member,
|
|||||||
|
|
||||||
assert(memberType);
|
assert(memberType);
|
||||||
|
|
||||||
|
// Perform the substitution.
|
||||||
auto substitutions = getMemberSubstitutionMap(module, member);
|
auto substitutions = getMemberSubstitutionMap(module, member);
|
||||||
return memberType.subst(substitutions, SubstFlags::UseErrorType);
|
return memberType.subst(substitutions, SubstFlags::UseErrorType);
|
||||||
}
|
}
|
||||||
@@ -3777,6 +3874,7 @@ case TypeKind::Id:
|
|||||||
return DependentMemberType::get(dependentBase, dependent->getName());
|
return DependentMemberType::get(dependentBase, dependent->getName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case TypeKind::GenericFunction:
|
||||||
case TypeKind::Function: {
|
case TypeKind::Function: {
|
||||||
auto function = cast<AnyFunctionType>(base);
|
auto function = cast<AnyFunctionType>(base);
|
||||||
auto inputTy = function->getInput().transformRec(fn);
|
auto inputTy = function->getInput().transformRec(fn);
|
||||||
@@ -3786,131 +3884,33 @@ case TypeKind::Id:
|
|||||||
if (!resultTy)
|
if (!resultTy)
|
||||||
return Type();
|
return Type();
|
||||||
|
|
||||||
if (inputTy.getPointer() == function->getInput().getPointer() &&
|
bool isUnchanged =
|
||||||
resultTy.getPointer() == function->getResult().getPointer())
|
inputTy.getPointer() == function->getInput().getPointer() &&
|
||||||
return *this;
|
resultTy.getPointer() == function->getResult().getPointer();
|
||||||
|
|
||||||
|
if (auto genericFnType = dyn_cast<GenericFunctionType>(base)) {
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
// Check that generic parameters won't be trasnformed.
|
||||||
|
// Transform generic parameters.
|
||||||
|
for (auto param : genericFnType->getGenericParams()) {
|
||||||
|
assert(Type(param).transformRec(fn).getPointer() == param &&
|
||||||
|
"GenericFunctionType transform() changes type parameter");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
if (isUnchanged) return *this;
|
||||||
|
|
||||||
|
auto genericSig = genericFnType->getGenericSignature();
|
||||||
|
return GenericFunctionType::get(genericSig, inputTy, resultTy,
|
||||||
|
function->getExtInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isUnchanged) return *this;
|
||||||
|
|
||||||
return FunctionType::get(inputTy, resultTy,
|
return FunctionType::get(inputTy, resultTy,
|
||||||
function->getExtInfo());
|
function->getExtInfo());
|
||||||
}
|
}
|
||||||
|
|
||||||
case TypeKind::GenericFunction: {
|
|
||||||
GenericFunctionType *function = cast<GenericFunctionType>(base);
|
|
||||||
bool anyChanges = false;
|
|
||||||
|
|
||||||
// Transform generic parameters.
|
|
||||||
SmallVector<GenericTypeParamType *, 4> genericParams;
|
|
||||||
for (auto param : function->getGenericParams()) {
|
|
||||||
Type paramTy = Type(param).transformRec(fn);
|
|
||||||
if (!paramTy)
|
|
||||||
return Type();
|
|
||||||
|
|
||||||
if (auto newParam = paramTy->getAs<GenericTypeParamType>()) {
|
|
||||||
if (newParam != param)
|
|
||||||
anyChanges = true;
|
|
||||||
|
|
||||||
genericParams.push_back(newParam);
|
|
||||||
} else {
|
|
||||||
anyChanges = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transform requirements.
|
|
||||||
SmallVector<Requirement, 4> requirements;
|
|
||||||
for (const auto &req : function->getRequirements()) {
|
|
||||||
auto firstType = req.getFirstType().transformRec(fn);
|
|
||||||
if (!firstType)
|
|
||||||
return Type();
|
|
||||||
|
|
||||||
if (firstType.getPointer() != req.getFirstType().getPointer())
|
|
||||||
anyChanges = true;
|
|
||||||
|
|
||||||
if (req.getKind() == RequirementKind::Layout) {
|
|
||||||
if (!firstType->isTypeParameter())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
requirements.push_back(Requirement(req.getKind(), firstType,
|
|
||||||
req.getLayoutConstraint()));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Type secondType = req.getSecondType();
|
|
||||||
if (secondType) {
|
|
||||||
secondType = secondType.transformRec(fn);
|
|
||||||
if (!secondType)
|
|
||||||
return Type();
|
|
||||||
|
|
||||||
if (secondType.getPointer() != req.getSecondType().getPointer())
|
|
||||||
anyChanges = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!firstType->isTypeParameter()) {
|
|
||||||
if (!secondType || !secondType->isTypeParameter())
|
|
||||||
continue;
|
|
||||||
std::swap(firstType, secondType);
|
|
||||||
}
|
|
||||||
|
|
||||||
requirements.push_back(Requirement(req.getKind(), firstType,
|
|
||||||
secondType));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transform input type.
|
|
||||||
auto inputTy = function->getInput().transformRec(fn);
|
|
||||||
if (!inputTy)
|
|
||||||
return Type();
|
|
||||||
|
|
||||||
// Transform result type.
|
|
||||||
auto resultTy = function->getResult().transformRec(fn);
|
|
||||||
if (!resultTy)
|
|
||||||
return Type();
|
|
||||||
|
|
||||||
// Check whether anything changed.
|
|
||||||
if (!anyChanges &&
|
|
||||||
inputTy.getPointer() == function->getInput().getPointer() &&
|
|
||||||
resultTy.getPointer() == function->getResult().getPointer())
|
|
||||||
return *this;
|
|
||||||
|
|
||||||
// If no generic parameters remain, this is a non-generic function type.
|
|
||||||
if (genericParams.empty()) {
|
|
||||||
return FunctionType::get(inputTy, resultTy, function->getExtInfo());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort/unique the generic parameters by depth/index.
|
|
||||||
using llvm::array_pod_sort;
|
|
||||||
array_pod_sort(genericParams.begin(), genericParams.end(),
|
|
||||||
[](GenericTypeParamType * const * gpp1,
|
|
||||||
GenericTypeParamType * const * gpp2) {
|
|
||||||
auto gp1 = *gpp1;
|
|
||||||
auto gp2 = *gpp2;
|
|
||||||
|
|
||||||
if (gp1->getDepth() < gp2->getDepth())
|
|
||||||
return -1;
|
|
||||||
|
|
||||||
if (gp1->getDepth() > gp2->getDepth())
|
|
||||||
return 1;
|
|
||||||
|
|
||||||
if (gp1->getIndex() < gp2->getIndex())
|
|
||||||
return -1;
|
|
||||||
|
|
||||||
if (gp1->getIndex() > gp2->getIndex())
|
|
||||||
return 1;
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
});
|
|
||||||
genericParams.erase(std::unique(genericParams.begin(), genericParams.end(),
|
|
||||||
[](GenericTypeParamType *gp1,
|
|
||||||
GenericTypeParamType *gp2) {
|
|
||||||
return gp1->getDepth() == gp2->getDepth()
|
|
||||||
&& gp1->getIndex() == gp2->getIndex();
|
|
||||||
}),
|
|
||||||
genericParams.end());
|
|
||||||
|
|
||||||
// Produce the new generic function type.
|
|
||||||
auto sig = GenericSignature::get(genericParams, requirements);
|
|
||||||
return GenericFunctionType::get(sig, inputTy, resultTy,
|
|
||||||
function->getExtInfo());
|
|
||||||
}
|
|
||||||
|
|
||||||
case TypeKind::ArraySlice: {
|
case TypeKind::ArraySlice: {
|
||||||
auto slice = cast<ArraySliceType>(base);
|
auto slice = cast<ArraySliceType>(base);
|
||||||
auto baseTy = slice->getBaseType().transformRec(fn);
|
auto baseTy = slice->getBaseType().transformRec(fn);
|
||||||
|
|||||||
Reference in New Issue
Block a user