[AST] Properly handle substitution of generic function types.

If substitution into the generic parameters of a generic function type
changes those generic parameters, build a new generic signature from
the resulting parameters. Otherwise, we can end up building ill-formed
generic signatures that fail validation.
This commit is contained in:
Doug Gregor
2017-11-06 21:19:46 -08:00
parent 42868b3bd6
commit 06d1679524

View File

@@ -18,6 +18,7 @@
#include "ForeignRepresentationInfo.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/ExistentialLayout.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/TypeVisitor.h"
#include "swift/AST/TypeWalker.h"
#include "swift/AST/Decl.h"
@@ -3089,15 +3090,110 @@ Type DependentMemberType::substBaseType(Type substBase,
getAssocType(), getName(), None);
}
static Type substGenericFunctionType(GenericFunctionType *genericFnType,
TypeSubstitutionFn substitutions,
LookupConformanceFn lookupConformances,
SubstOptions options) {
// Substitute into the function type (without generic signature).
auto *bareFnType = FunctionType::get(genericFnType->getInput(),
genericFnType->getResult(),
genericFnType->getExtInfo());
Type result =
Type(bareFnType).subst(substitutions, lookupConformances, options);
if (!result || result->is<ErrorType>()) return result;
auto *fnType = result->castTo<FunctionType>();
// Substitute generic parameters.
bool anySemanticChanges = false;
SmallVector<GenericTypeParamType *, 4> genericParams;
for (auto param : genericFnType->getGenericParams()) {
Type paramTy =
Type(param).subst(substitutions, lookupConformances, options);
if (!paramTy)
return Type();
if (auto newParam = paramTy->getAs<GenericTypeParamType>()) {
if (!newParam->isEqual(param))
anySemanticChanges = true;
genericParams.push_back(newParam);
} else {
anySemanticChanges = true;
}
}
// If no generic parameters remain, this is a non-generic function type.
if (genericParams.empty())
return result;
// Transform requirements.
SmallVector<Requirement, 4> requirements;
for (const auto &req : genericFnType->getRequirements()) {
// Substitute into the requirement.
auto substReqt = req.subst(substitutions, lookupConformances, options);
if (!substReqt) {
anySemanticChanges = true;
continue;
}
// Did anything change?
if (!anySemanticChanges &&
(!req.getFirstType()->isEqual(substReqt->getFirstType()) ||
(req.getKind() != RequirementKind::Layout &&
!req.getSecondType()->isEqual(substReqt->getSecondType())))) {
anySemanticChanges = true;
}
// Skip any erroneous requirements.
if (substReqt->getFirstType()->hasError() ||
(substReqt->getKind() != RequirementKind::Layout &&
substReqt->getSecondType()->hasError()))
continue;
requirements.push_back(*substReqt);
}
GenericSignature *genericSig = nullptr;
if (anySemanticChanges) {
// If there were semantic changes, we need to build a new generic
// signature.
GenericSignatureBuilder builder(genericFnType->getASTContext());
// Add the generic parameters to the builder.
for (auto gp : genericParams)
builder.addGenericParameter(gp);
// Add the requirements to the builder.
auto source =
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
for (const auto &req : requirements)
builder.addRequirement(req, source, /*inferForModule=*/nullptr);
// Form the generic signature.
genericSig = std::move(builder).computeGenericSignature(SourceLoc());
} else {
// Use the mapped generic signature.
genericSig = GenericSignature::get(genericParams, requirements);
}
// Produce the new generic function type.
return GenericFunctionType::get(genericSig, fnType->getInput(),
fnType->getResult(), fnType->getExtInfo());
}
static Type substType(Type derivedType,
TypeSubstitutionFn substitutions,
LookupConformanceFn lookupConformances,
SubstOptions options) {
// Handle substitutions into generic function types.
if (auto genericFnType = derivedType->getAs<GenericFunctionType>()) {
return substGenericFunctionType(genericFnType, substitutions,
lookupConformances, options);
}
// FIXME: Change getTypeOfMember() to not pass GenericFunctionType here
if (!derivedType->hasArchetype() &&
!derivedType->hasTypeParameter() &&
!derivedType->is<GenericFunctionType>())
!derivedType->hasTypeParameter())
return derivedType;
return derivedType.transformRec([&](TypeBase *type) -> Optional<Type> {
@@ -3409,6 +3505,7 @@ Type TypeBase::getTypeOfMember(ModuleDecl *module, const ValueDecl *member,
assert(memberType);
// Perform the substitution.
auto substitutions = getMemberSubstitutionMap(module, member);
return memberType.subst(substitutions, SubstFlags::UseErrorType);
}
@@ -3777,6 +3874,7 @@ case TypeKind::Id:
return DependentMemberType::get(dependentBase, dependent->getName());
}
case TypeKind::GenericFunction:
case TypeKind::Function: {
auto function = cast<AnyFunctionType>(base);
auto inputTy = function->getInput().transformRec(fn);
@@ -3786,131 +3884,33 @@ case TypeKind::Id:
if (!resultTy)
return Type();
if (inputTy.getPointer() == function->getInput().getPointer() &&
resultTy.getPointer() == function->getResult().getPointer())
return *this;
bool isUnchanged =
inputTy.getPointer() == function->getInput().getPointer() &&
resultTy.getPointer() == function->getResult().getPointer();
if (auto genericFnType = dyn_cast<GenericFunctionType>(base)) {
#ifndef NDEBUG
// Check that generic parameters won't be trasnformed.
// Transform generic parameters.
for (auto param : genericFnType->getGenericParams()) {
assert(Type(param).transformRec(fn).getPointer() == param &&
"GenericFunctionType transform() changes type parameter");
}
#endif
if (isUnchanged) return *this;
auto genericSig = genericFnType->getGenericSignature();
return GenericFunctionType::get(genericSig, inputTy, resultTy,
function->getExtInfo());
}
if (isUnchanged) return *this;
return FunctionType::get(inputTy, resultTy,
function->getExtInfo());
}
case TypeKind::GenericFunction: {
GenericFunctionType *function = cast<GenericFunctionType>(base);
bool anyChanges = false;
// Transform generic parameters.
SmallVector<GenericTypeParamType *, 4> genericParams;
for (auto param : function->getGenericParams()) {
Type paramTy = Type(param).transformRec(fn);
if (!paramTy)
return Type();
if (auto newParam = paramTy->getAs<GenericTypeParamType>()) {
if (newParam != param)
anyChanges = true;
genericParams.push_back(newParam);
} else {
anyChanges = true;
}
}
// Transform requirements.
SmallVector<Requirement, 4> requirements;
for (const auto &req : function->getRequirements()) {
auto firstType = req.getFirstType().transformRec(fn);
if (!firstType)
return Type();
if (firstType.getPointer() != req.getFirstType().getPointer())
anyChanges = true;
if (req.getKind() == RequirementKind::Layout) {
if (!firstType->isTypeParameter())
continue;
requirements.push_back(Requirement(req.getKind(), firstType,
req.getLayoutConstraint()));
continue;
}
Type secondType = req.getSecondType();
if (secondType) {
secondType = secondType.transformRec(fn);
if (!secondType)
return Type();
if (secondType.getPointer() != req.getSecondType().getPointer())
anyChanges = true;
}
if (!firstType->isTypeParameter()) {
if (!secondType || !secondType->isTypeParameter())
continue;
std::swap(firstType, secondType);
}
requirements.push_back(Requirement(req.getKind(), firstType,
secondType));
}
// Transform input type.
auto inputTy = function->getInput().transformRec(fn);
if (!inputTy)
return Type();
// Transform result type.
auto resultTy = function->getResult().transformRec(fn);
if (!resultTy)
return Type();
// Check whether anything changed.
if (!anyChanges &&
inputTy.getPointer() == function->getInput().getPointer() &&
resultTy.getPointer() == function->getResult().getPointer())
return *this;
// If no generic parameters remain, this is a non-generic function type.
if (genericParams.empty()) {
return FunctionType::get(inputTy, resultTy, function->getExtInfo());
}
// Sort/unique the generic parameters by depth/index.
using llvm::array_pod_sort;
array_pod_sort(genericParams.begin(), genericParams.end(),
[](GenericTypeParamType * const * gpp1,
GenericTypeParamType * const * gpp2) {
auto gp1 = *gpp1;
auto gp2 = *gpp2;
if (gp1->getDepth() < gp2->getDepth())
return -1;
if (gp1->getDepth() > gp2->getDepth())
return 1;
if (gp1->getIndex() < gp2->getIndex())
return -1;
if (gp1->getIndex() > gp2->getIndex())
return 1;
return 0;
});
genericParams.erase(std::unique(genericParams.begin(), genericParams.end(),
[](GenericTypeParamType *gp1,
GenericTypeParamType *gp2) {
return gp1->getDepth() == gp2->getDepth()
&& gp1->getIndex() == gp2->getIndex();
}),
genericParams.end());
// Produce the new generic function type.
auto sig = GenericSignature::get(genericParams, requirements);
return GenericFunctionType::get(sig, inputTy, resultTy,
function->getExtInfo());
}
case TypeKind::ArraySlice: {
auto slice = cast<ArraySliceType>(base);
auto baseTy = slice->getBaseType().transformRec(fn);