ModulePrinting: Merge synthesized extensions' bodies if they have the common constraints.

This commit is contained in:
Xi Ge
2016-03-10 17:43:39 -08:00
parent 74afb7800c
commit c22e9bcf53
5 changed files with 228 additions and 122 deletions

View File

@@ -31,6 +31,7 @@ class Type;
enum DeclAttrKind : unsigned;
class PrinterArchetypeTransformer;
class SynthesizedExtensionAnalyzer;
struct PrintOptions;
/// Necessary information for archetype transformation during printing.
struct ArchetypeTransformContext {
@@ -47,29 +48,30 @@ struct ArchetypeTransformContext {
SynthesizedExtensionAnalyzer *Analyzer);
Type transform(Type Input);
StringRef transform(StringRef Input);
bool shouldPrintRequirement(ExtensionDecl *ED, StringRef Req);
bool shouldOpenExtension;
bool shouldCloseExtension;
~ArchetypeTransformContext();
private:
struct Implementation;
Implementation &Impl;
};
struct SynthesizedExtensionInfo {
ExtensionDecl *Ext = nullptr;
std::vector<StringRef> KnownSatisfiedRequirements;
operator bool() const { return Ext; }
};
class SynthesizedExtensionAnalyzer {
struct Implementation;
Implementation &Impl;
public:
SynthesizedExtensionAnalyzer(NominalTypeDecl *Target,
PrintOptions Options,
bool IncludeUnconditional = true);
~SynthesizedExtensionAnalyzer();
void forEachSynthesizedExtension(
llvm::function_ref<void(ExtensionDecl*)> Fn);
void forEachSynthesizedExtensionMergeGroup(
llvm::function_ref<void(ArrayRef<ExtensionDecl*>)> Fn);
bool isInSynthesizedExtension(const ValueDecl *VD);
bool shouldPrintRequirement(ExtensionDecl *ED, StringRef Req);
};

View File

@@ -206,20 +206,30 @@ public:
};
struct SynthesizedExtensionAnalyzer::Implementation {
struct SynthesizedExtensionInfo {
ExtensionDecl *Ext = nullptr;
std::vector<StringRef> KnownSatisfiedRequirements;
unsigned MergeGroup;
operator bool() const { return Ext; }
};
typedef llvm::MapVector<ExtensionDecl*, SynthesizedExtensionInfo> ExtMap;
NominalTypeDecl *Target;
Type BaseType;
DeclContext *DC;
std::unique_ptr<ArchetypeSelfTransformer> pTransform;
bool IncludeUnconditional;
PrintOptions Options;
std::unique_ptr<ExtMap> Results;
Implementation(NominalTypeDecl *Target, bool IncludeUnconditional):
Implementation(NominalTypeDecl *Target, bool IncludeUnconditional,
PrintOptions Options):
Target(Target),
BaseType(Target->getDeclaredTypeInContext()),
DC(Target),
pTransform(new ArchetypeSelfTransformer(Target)),
IncludeUnconditional(IncludeUnconditional),
Options(Options),
Results(collectSynthesizedExtensionInfo()) {}
Type checkElementType(StringRef Text) {
@@ -284,12 +294,46 @@ struct SynthesizedExtensionAnalyzer::Implementation {
return MType->getAs<AnyMetatypeType>()->getInstanceType();
}
SynthesizedExtensionInfo isApplicable(ExtensionDecl *Ext) {
struct ExtensionMergeInfo {
struct Requirement {
Type First;
Type Second;
RequirementReprKind Kind;
bool operator< (const Requirement& Rhs) const {
if (Kind != Rhs.Kind)
return Kind < Rhs.Kind;
else if (First.getPointer() != Rhs.First.getPointer())
return First.getPointer() < Rhs.First.getPointer();
else
return Second.getPointer() < Rhs.Second.getPointer();
}
bool operator== (const Requirement& Rhs) const {
return (!(*this < Rhs)) && (!(Rhs < *this));
}
};
bool HasDocComment;
std::set<Requirement> Requirements;
void addRequirement(Type First, Type Second, RequirementReprKind Kind) {
Requirements.insert({First, Second, Kind});
}
bool operator== (const ExtensionMergeInfo& Another) const {
// Trivially unmergable.
if (HasDocComment || Another.HasDocComment)
return false;
return Requirements == Another.Requirements;
}
};
std::pair<SynthesizedExtensionInfo, std::unique_ptr<ExtensionMergeInfo>>
isApplicable(ExtensionDecl *Ext) {
SynthesizedExtensionInfo Result;
std::unique_ptr<ExtensionMergeInfo> MergeInfo(new ExtensionMergeInfo());
MergeInfo->HasDocComment = false;
if (!Ext->isConstrainedExtension()) {
if (IncludeUnconditional)
Result.Ext = Ext;
return Result;
return {Result, std::move(MergeInfo)};
}
assert(Ext->getGenericParams() && "No generic params.");
for (auto Req : Ext->getGenericParams()->getRequirements()){
@@ -308,30 +352,61 @@ struct SynthesizedExtensionAnalyzer::Implementation {
if (First && Second) {
First = First->getDesugaredType();
Second = Second->getDesugaredType();
auto Written = Req.getAsWrittenString();
switch (Kind) {
case RequirementReprKind::TypeConstraint:
if(!canPossiblyConvertTo(First, Second, *DC))
return Result;
return {Result, std::move(MergeInfo)};
else if (isConvertibleTo(First, Second, *DC))
Result.KnownSatisfiedRequirements.push_back(Req.getAsWrittenString());
Result.KnownSatisfiedRequirements.push_back(Written);
else
MergeInfo->addRequirement(First, Second, Kind);
break;
case RequirementReprKind::SameType:
if (!canPossiblyEqual(First, Second, *DC))
return Result;
return {Result, std::move(MergeInfo)};
else if (isEqual(First, Second, *DC))
Result.KnownSatisfiedRequirements.push_back(Req.getAsWrittenString());
Result.KnownSatisfiedRequirements.push_back(Written);
else
MergeInfo->addRequirement(First, Second, Kind);
break;
}
}
}
Result.Ext = Ext;
return Result;
return {Result, std::move(MergeInfo)};
}
typedef llvm::MapVector<ExtensionDecl*, std::unique_ptr<ExtensionMergeInfo>>
ExtensionMergeInfoMap;
typedef llvm::MapVector<ExtensionDecl*, unsigned> ExtensionMergeGroupMap;
void calculateMergeGroup(ExtensionMergeInfoMap &MergeInfoMap,
ExtensionMergeGroupMap &MergeGroupMap) {
std::vector<std::pair<ExtensionMergeInfo*, unsigned>> KnownGroups;
unsigned NewGroupNum = 0;
for (auto &Pair : MergeInfoMap) {
auto Key = Pair.first;
auto Found = std::find_if(KnownGroups.begin(), KnownGroups.end(),
[&](std::pair<ExtensionMergeInfo*, unsigned> LHS) {
return (*LHS.first) == (*Pair.second);
});
if (Found != KnownGroups.end()) {
MergeGroupMap[Key] = (*Found).second;
continue;
}
MergeGroupMap[Key] = NewGroupNum;
KnownGroups.push_back({Pair.second.get(), NewGroupNum});
NewGroupNum ++;
}
}
std::unique_ptr<ExtMap> collectSynthesizedExtensionInfo() {
std::unique_ptr<ExtMap> pMap(new ExtMap());
if (Target->getKind() == DeclKind::Protocol)
return pMap;
ExtensionMergeInfoMap MergeInfoMap;
std::vector<NominalTypeDecl*> Unhandled;
auto addTypeLocNominal = [&](TypeLoc TL){
if (TL.getType()) {
@@ -347,21 +422,32 @@ struct SynthesizedExtensionAnalyzer::Implementation {
NominalTypeDecl* Back = Unhandled.back();
Unhandled.pop_back();
for (ExtensionDecl *E : Back->getExtensions()) {
if (auto Info = isApplicable(E))
(*pMap)[E] = Info;
if (!shouldPrint(E, Options))
continue;
auto Pair = isApplicable(E);
if (Pair.first) {
(*pMap)[E] = Pair.first;
MergeInfoMap[E] = std::move(Pair.second);
}
for (auto TL : Back->getInherited()) {
addTypeLocNominal(TL);
}
}
}
ExtensionMergeGroupMap GroupMap;
calculateMergeGroup(MergeInfoMap, GroupMap);
for(auto It : *pMap) {
(*pMap)[It.first].MergeGroup = GroupMap[It.first];
}
return pMap;
}
};
SynthesizedExtensionAnalyzer::
SynthesizedExtensionAnalyzer(NominalTypeDecl *Target,
PrintOptions Options,
bool IncludeUnconditional):
Impl(*(new Implementation(Target, IncludeUnconditional))) {}
Impl(*(new Implementation(Target, IncludeUnconditional, Options))) {}
SynthesizedExtensionAnalyzer::~SynthesizedExtensionAnalyzer() {delete &Impl;}
@@ -381,6 +467,21 @@ forEachSynthesizedExtension(llvm::function_ref<void(ExtensionDecl*)> Fn) {
}
}
void SynthesizedExtensionAnalyzer::
forEachSynthesizedExtensionMergeGroup(
llvm::function_ref<void(ArrayRef<ExtensionDecl*>)> Fn) {
llvm::DenseMap<unsigned, std::vector<ExtensionDecl*>> GroupBags;
for (auto It : *Impl.Results) {
unsigned Group = It.second.MergeGroup;
if (GroupBags.find(Group) == GroupBags.end())
GroupBags[Group] = std::vector<ExtensionDecl*>();
GroupBags[Group].push_back(It.first);
}
for (auto It : GroupBags) {
Fn(llvm::makeArrayRef(It.second));
}
}
bool SynthesizedExtensionAnalyzer::
shouldPrintRequirement(ExtensionDecl *ED, StringRef Req) {
auto Found = Impl.Results->find(ED);
@@ -865,8 +966,10 @@ private:
bool shouldPrintPattern(const Pattern *P);
void printPatternType(const Pattern *P);
void printAccessors(AbstractStorageDecl *ASD);
void printMembersOfDecl(Decl * NTD, bool needComma = false);
void printMembers(ArrayRef<Decl *> members, bool needComma = false);
void printMembersOfDecl(Decl * NTD, bool needComma = false,
bool openBracket = true, bool closeBracket = true);
void printMembers(ArrayRef<Decl *> members, bool needComma = false,
bool openBracket = true, bool closeBracket = true);
void printNominalDeclGenericParams(NominalTypeDecl *decl);
void printInherited(const Decl *decl,
ArrayRef<TypeLoc> inherited,
@@ -1522,7 +1625,9 @@ void PrintAST::printAccessors(AbstractStorageDecl *ASD) {
Printer << "}";
}
void PrintAST::printMembersOfDecl(Decl *D, bool needComma) {
void PrintAST::printMembersOfDecl(Decl *D, bool needComma,
bool openBracket,
bool closeBracket) {
llvm::SmallVector<Decl *, 3> Members;
auto AddDeclFunc = [&](DeclRange Range) {
for (auto RD : Range)
@@ -1538,10 +1643,12 @@ void PrintAST::printMembersOfDecl(Decl *D, bool needComma) {
AddDeclFunc(Ext->getMembers());
}
}
printMembers(Members, needComma);
printMembers(Members, needComma, openBracket, closeBracket);
}
void PrintAST::printMembers(ArrayRef<Decl *> members, bool needComma) {
void PrintAST::printMembers(ArrayRef<Decl *> members, bool needComma,
bool openBracket, bool closeBracket) {
if (openBracket)
Printer << " {";
Printer.printNewline();
{
@@ -1565,6 +1672,7 @@ void PrintAST::printMembers(ArrayRef<Decl *> members, bool needComma) {
}
}
indent();
if (closeBracket)
Printer << "}";
}
@@ -1795,6 +1903,7 @@ static void printExtendedTypeName(Type ExtendedType, ASTPrinter &Printer,
void PrintAST::
printSynthesizedExtension(NominalTypeDecl* Decl, ExtensionDecl *ExtDecl) {
if (Options.TransformContext->shouldOpenExtension) {
Printer << "/// Synthesized extension from " <<
ExtDecl->getExtendedType()->getAnyNominal()->getName().str() << "\n";
printDocumentationComment(ExtDecl);
@@ -1813,9 +1922,11 @@ printSynthesizedExtension(NominalTypeDecl* Decl, ExtensionDecl *ExtDecl) {
}
printWhereClause(ReqsToPrint);
}
}
if (Options.TypeDefinitions) {
printMembersOfDecl(ExtDecl);
printMembersOfDecl(ExtDecl, false,
Options.TransformContext->shouldOpenExtension,
Options.TransformContext->shouldCloseExtension);
}
}

View File

@@ -479,15 +479,22 @@ void swift::ide::printSubmoduleInterface(
continue;
// Print synthesized extensions.
SynthesizedExtensionAnalyzer Analyzer(NTD);
SynthesizedExtensionAnalyzer Analyzer(NTD, AdjustedOptions);
AdjustedOptions.initArchetypeTransformerForSynthesizedExtensions(NTD,
&Analyzer);
Analyzer.forEachSynthesizedExtension([&](ExtensionDecl *ET){
if (!shouldPrint(ET, AdjustedOptions))
return;
Analyzer.forEachSynthesizedExtensionMergeGroup(
[&](ArrayRef<ExtensionDecl*> Decls){
for (auto ET : Decls) {
AdjustedOptions.TransformContext->shouldOpenExtension =
Decls.front() == ET;
AdjustedOptions.TransformContext->shouldCloseExtension =
Decls.back() == ET;
if (AdjustedOptions.TransformContext->shouldOpenExtension)
Printer << "\n";
ET->print(Printer, AdjustedOptions);
if (AdjustedOptions.TransformContext->shouldCloseExtension)
Printer << "\n";
}
});
AdjustedOptions.clearArchetypeTransformerForSynthesizedExtensions();
}

View File

@@ -125,12 +125,10 @@ public struct S10 : P1 {
// CHECK: <synthesized>/// Synthesized extension from P2
// CHECK-NEXT: extension <ref:Struct>S1</ref> where T : P2 {
// CHECK-NEXT: <decl:Func>public func <loc>p2member()</loc></decl>
// CHECK-NEXT: }</synthesized>
// CHECK: <synthesized>/// Synthesized extension from P1
// CHECK-NEXT: extension <ref:Struct>S1</ref> where T == Int {
// CHECK-NEXT: <decl:Func>public func <loc>p1IntFunc(<decl:Param>i: <ref:Struct>Int</ref></decl>)</loc> -> <ref:Struct>Int</ref></decl>
// CHECK-NEXT: <decl:Func>public func <loc>p2member()</loc></decl></synthesized>
// CHECK-NEXT: <synthesized>
// CHECK-NEXT: <decl:Func>public func <loc>ef1(<decl:Param>t: T</decl>)</loc></decl>
// CHECK-NEXT: <decl:Func>public func <loc>ef2(<decl:Param>t: <ref:Struct>S2</ref></decl>)</loc></decl>
// CHECK-NEXT: }</synthesized>
// CHECK: <synthesized>/// Synthesized extension from P1
@@ -139,9 +137,8 @@ public struct S10 : P1 {
// CHECK-NEXT: }</synthesized>
// CHECK: <synthesized>/// Synthesized extension from P1
// CHECK-NEXT: extension <ref:Struct>S1</ref> where T : P2 {
// CHECK-NEXT: <decl:Func>public func <loc>ef1(<decl:Param>t: T</decl>)</loc></decl>
// CHECK-NEXT: <decl:Func>public func <loc>ef2(<decl:Param>t: <ref:Struct>S2</ref></decl>)</loc></decl>
// CHECK-NEXT: extension <ref:Struct>S1</ref> where T == Int {
// CHECK-NEXT: <decl:Func>public func <loc>p1IntFunc(<decl:Param>i: <ref:Struct>Int</ref></decl>)</loc> -> <ref:Struct>Int</ref></decl>
// CHECK-NEXT: }</synthesized>
// CHECK: <synthesized>/// Synthesized extension from P1
@@ -151,16 +148,10 @@ public struct S10 : P1 {
// CHECK: <synthesized>/// Synthesized extension from P1
// CHECK-NEXT: extension <ref:Struct>S10</ref> {
// CHECK-NEXT: <decl:Func>public func <loc>p3Func(<decl:Param>i: <ref:Struct>Int</ref></decl>)</loc> -> <ref:Struct>Int</ref></decl>
// CHECK-NEXT: }</synthesized>
// CHECK: <synthesized>/// Synthesized extension from P1
// CHECK-NEXT: extension <ref:Struct>S10</ref> {
// CHECK-NEXT: <decl:Func>public func <loc>ef5(<decl:Param>t: <ref:Struct>S9</ref><<ref:Struct>Int</ref>></decl>)</loc></decl>
// CHECK-NEXT: }</synthesized>
// CHECK: <synthesized>/// Synthesized extension from P1
// CHECK-NEXT: extension <ref:Struct>S10</ref> {
// CHECK-NEXT: <decl:Func>public func <loc>p3Func(<decl:Param>i: <ref:Struct>Int</ref></decl>)</loc> -> <ref:Struct>Int</ref></decl></synthesized>
// CHECK-NEXT: <synthesized>
// CHECK-NEXT: <decl:Func>public func <loc>ef5(<decl:Param>t: <ref:Struct>S9</ref><<ref:Struct>Int</ref>></decl>)</loc></decl></synthesized>
// CHECK-NEXT: <synthesized>
// CHECK-NEXT: <decl:Func>public func <loc>S9IntFunc()</loc></decl>
// CHECK-NEXT: }</synthesized>
@@ -171,20 +162,14 @@ public struct S10 : P1 {
// CHECK: <synthesized>/// Synthesized extension from P1
// CHECK-NEXT: extension <ref:Struct>S6</ref> {
// CHECK-NEXT: <decl:Func>public func <loc>p3Func(<decl:Param>i: <ref:Struct>Int</ref></decl>)</loc> -> <ref:Struct>Int</ref></decl>
// CHECK-NEXT: }</synthesized>
// CHECK: <synthesized>/// Synthesized extension from P1
// CHECK-NEXT: extension <ref:Struct>S6</ref> {
// CHECK-NEXT: <decl:Func>public func <loc>p3Func(<decl:Param>i: <ref:Struct>Int</ref></decl>)</loc> -> <ref:Struct>Int</ref></decl></synthesized>
// CHECK-NEXT: <synthesized>
// CHECK-NEXT: <decl:Func>public func <loc>ef5(<decl:Param>t: <ref:Struct>S5</ref></decl>)</loc></decl>
// CHECK-NEXT: }</synthesized>
// CHECK: <synthesized>/// Synthesized extension from P1
// CHECK-NEXT: extension <ref:module>print_synthesized_extensions</ref>.<ref:Struct>S7</ref>.<ref:Struct>S8</ref> {
// CHECK-NEXT: <decl:Func>public func <loc>p3Func(<decl:Param>i: <ref:Struct>Int</ref></decl>)</loc> -> <ref:Struct>Int</ref></decl>
// CHECK-NEXT: }</synthesized>
// CHECK: <synthesized>/// Synthesized extension from P1
// CHECK-NEXT: extension <ref:module>print_synthesized_extensions</ref>.<ref:Struct>S7</ref>.<ref:Struct>S8</ref> {
// CHECK-NEXT: <decl:Func>public func <loc>p3Func(<decl:Param>i: <ref:Struct>Int</ref></decl>)</loc> -> <ref:Struct>Int</ref></decl></synthesized>
// CHECK-NEXT: <synthesized>
// CHECK-NEXT: <decl:Func>public func <loc>ef5(<decl:Param>t: <ref:Struct>S5</ref></decl>)</loc></decl>
// CHECK-NEXT: }</synthesized>

View File

@@ -594,7 +594,8 @@ static bool passCursorInfoForDecl(const ValueDecl *VD,
bool InSynthesizedExtension = false;
if (BaseType) {
if(auto Target = BaseType->getAnyNominal()) {
SynthesizedExtensionAnalyzer Analyzer(Target);
SynthesizedExtensionAnalyzer Analyzer(Target,
PrintOptions::printInterface());
InSynthesizedExtension = Analyzer.isInSynthesizedExtension(VD);
}
}