Merge pull request #29847 from tkachukandrew/add-equatable-conformance

This commit is contained in:
swift-ci
2020-04-27 14:01:17 -07:00
committed by GitHub
8 changed files with 399 additions and 0 deletions

View File

@@ -52,6 +52,8 @@ CURSOR_REFACTORING(TrailingClosure, "Convert To Trailing Closure", trailingclosu
CURSOR_REFACTORING(MemberwiseInitLocalRefactoring, "Generate Memberwise Initializer", memberwise.init.local.refactoring)
CURSOR_REFACTORING(AddEquatableConformance, "Add Equatable Conformance", add.equatable.conformance)
RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)
RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)

View File

@@ -3172,6 +3172,250 @@ bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
return false;
}
class AddEquatableContext {
/// Declaration context
DeclContext *DC;
/// Adopter type
Type Adopter;
/// Start location of declaration context brace
SourceLoc StartLoc;
/// Array of all inherited protocols' locations
ArrayRef<TypeLoc> ProtocolsLocations;
/// Array of all conformed protocols
SmallVector<swift::ProtocolDecl *, 2> Protocols;
/// Start location of declaration,
/// a place to write protocol name
SourceLoc ProtInsertStartLoc;
/// Stored properties of extending adopter
ArrayRef<VarDecl *> StoredProperties;
/// Range of internal members in declaration
DeclRange Range;
bool conformsToEquatableProtocol() {
for (ProtocolDecl *Protocol : Protocols) {
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) {
return true;
}
}
return false;
}
bool isRequirementValid() {
auto Reqs = getProtocolRequirements();
if (Reqs.empty()) {
return false;
}
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
return Req && Req->getParameters()->size() == 2;
}
bool isPropertiesListValid() {
return !getUserAccessibleProperties().empty();
}
void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent,
ParameterList *Params);
std::vector<ValueDecl *> getProtocolRequirements();
std::vector<VarDecl *> getUserAccessibleProperties();
public:
AddEquatableContext(NominalTypeDecl *Decl) : DC(Decl),
Adopter(Decl->getDeclaredType()), StartLoc(Decl->getBraces().Start),
ProtocolsLocations(Decl->getInherited()),
Protocols(Decl->getAllProtocols()), ProtInsertStartLoc(Decl->getNameLoc()),
StoredProperties(Decl->getStoredProperties()), Range(Decl->getMembers()) {};
AddEquatableContext(ExtensionDecl *Decl) : DC(Decl),
Adopter(Decl->getExtendedType()), StartLoc(Decl->getBraces().Start),
ProtocolsLocations(Decl->getInherited()),
Protocols(Decl->getExtendedNominal()->getAllProtocols()),
ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()),
StoredProperties(Decl->getExtendedNominal()->getStoredProperties()), Range(Decl->getMembers()) {};
AddEquatableContext() : DC(nullptr), Adopter(), ProtocolsLocations(),
Protocols(), StoredProperties(), Range(nullptr, nullptr) {};
static AddEquatableContext getDeclarationContextFromInfo(ResolvedCursorInfo Info);
std::string getInsertionTextForProtocol();
std::string getInsertionTextForFunction(SourceManager &SM);
bool isValid() {
// FIXME: Allow to generate explicit == method for declarations which already have
// compiler-generated == method
return StartLoc.isValid() && ProtInsertStartLoc.isValid() &&
!conformsToEquatableProtocol() && isPropertiesListValid() &&
isRequirementValid();
}
SourceLoc getStartLocForProtocolDecl() {
if (ProtocolsLocations.empty()) {
return ProtInsertStartLoc;
}
return ProtocolsLocations.back().getSourceRange().Start;
}
bool isMembersRangeEmpty() {
return Range.empty();
}
SourceLoc getInsertStartLoc();
};
SourceLoc AddEquatableContext::
getInsertStartLoc() {
SourceLoc MaxLoc = StartLoc;
for (auto Mem : Range) {
if (Mem->getEndLoc().getOpaquePointerValue() >
MaxLoc.getOpaquePointerValue()) {
MaxLoc = Mem->getEndLoc();
}
}
return MaxLoc;
}
std::string AddEquatableContext::
getInsertionTextForProtocol() {
StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable);
std::string Buffer;
llvm::raw_string_ostream OS(Buffer);
if (ProtocolsLocations.empty()) {
OS << ": " << ProtocolName;
return Buffer;
}
OS << ", " << ProtocolName;
return Buffer;
}
std::string AddEquatableContext::
getInsertionTextForFunction(SourceManager &SM) {
auto Reqs = getProtocolRequirements();
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
auto Params = Req->getParameters();
StringRef ExtraIndent;
StringRef CurrentIndent =
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
std::string Indent;
if (isMembersRangeEmpty()) {
Indent = (CurrentIndent + ExtraIndent).str();
} else {
Indent = CurrentIndent.str();
}
PrintOptions Options = PrintOptions::printVerbose();
Options.PrintDocumentationComments = false;
Options.setBaseType(Adopter);
Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
Printer << " {";
Printer.printNewline();
printFunctionBody(Printer, ExtraIndent, Params);
Printer.printNewline();
Printer << "}";
};
std::string Buffer;
llvm::raw_string_ostream OS(Buffer);
ExtraIndentStreamPrinter Printer(OS, Indent);
Printer.printNewline();
if (!isMembersRangeEmpty()) {
Printer.printNewline();
}
Reqs[0]->print(Printer, Options);
return Buffer;
}
std::vector<VarDecl *> AddEquatableContext::
getUserAccessibleProperties() {
std::vector<VarDecl *> PublicProperties;
for (VarDecl *Decl : StoredProperties) {
if (Decl->Decl::isUserAccessible()) {
PublicProperties.push_back(Decl);
}
}
return PublicProperties;
}
std::vector<ValueDecl *> AddEquatableContext::
getProtocolRequirements() {
std::vector<ValueDecl *> Collection;
auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable);
for (auto Member : Proto->getMembers()) {
auto Req = dyn_cast<ValueDecl>(Member);
if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) {
continue;
}
Collection.push_back(Req);
}
return Collection;
}
AddEquatableContext AddEquatableContext::
getDeclarationContextFromInfo(ResolvedCursorInfo Info) {
if (Info.isInvalid()) {
return AddEquatableContext();
}
if (!Info.IsRef) {
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD)) {
return AddEquatableContext(NomDecl);
}
} else if (auto *ExtDecl = Info.ExtTyRef) {
return AddEquatableContext(ExtDecl);
}
return AddEquatableContext();
}
void AddEquatableContext::
printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) {
llvm::SmallString<128> Return;
llvm::raw_svector_ostream SS(Return);
SS << tok::kw_return;
StringRef Space = " ";
StringRef AdditionalSpace = " ";
StringRef Point = ".";
StringRef Join = " == ";
StringRef And = " &&";
auto Props = getUserAccessibleProperties();
auto FParam = Params->get(0)->getName();
auto SParam = Params->get(1)->getName();
auto Prop = Props[0]->getName();
Printer << ExtraIndent << Return << Space
<< FParam << Point << Prop << Join << SParam << Point << Prop;
if (Props.size() > 1) {
std::for_each(Props.begin() + 1, Props.end(), [&](VarDecl *VD){
auto Name = VD->getName();
Printer << And;
Printer.printNewline();
Printer << ExtraIndent << AdditionalSpace << FParam << Point
<< Name << Join << SParam << Point << Name;
});
}
}
bool RefactoringActionAddEquatableConformance::
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
return AddEquatableContext::getDeclarationContextFromInfo(Tok).isValid();
}
bool RefactoringActionAddEquatableConformance::
performChange() {
auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo);
EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(),
Context.getInsertionTextForProtocol());
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(),
Context.getInsertionTextForFunction(SM));
return false;
}
static CharSourceRange
findSourceRangeToWrapInCatch(ResolvedCursorInfo CursorInfo,
SourceFile *TheFile,

View File

@@ -0,0 +1,24 @@
class TestAddEquatable: Equatable {
var property = "test"
private var prop = "test2"
let pr = "test3"
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
return lhs.property == rhs.property &&
lhs.prop == rhs.prop &&
lhs.pr == rhs.pr
}
}
extension TestAddEquatable {
func test() -> Bool {
return true
}
}
extension TestAddEquatable {
}

View File

@@ -0,0 +1,24 @@
class TestAddEquatable {
var property = "test"
private var prop = "test2"
let pr = "test3"
}
extension TestAddEquatable: Equatable {
func test() -> Bool {
return true
}
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
return lhs.property == rhs.property &&
lhs.prop == rhs.prop &&
lhs.pr == rhs.pr
}
}
extension TestAddEquatable {
}

View File

@@ -0,0 +1,23 @@
class TestAddEquatable {
var property = "test"
private var prop = "test2"
let pr = "test3"
}
extension TestAddEquatable {
func test() -> Bool {
return true
}
}
extension TestAddEquatable: Equatable {
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
return lhs.property == rhs.property &&
lhs.prop == rhs.prop &&
lhs.pr == rhs.pr
}
}

View File

@@ -0,0 +1,25 @@
class TestAddEquatable {
var property = "test"
private var prop = "test2"
let pr = "test3"
}
extension TestAddEquatable {
func test() -> Bool {
return true
}
}
extension TestAddEquatable {
}
// RUN: rm -rf %t.result && mkdir -p %t.result
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=1:16 > %t.result/first.swift
// RUN: diff -u %S/Outputs/basic/first.swift.expected %t.result/first.swift
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=7:13 > %t.result/second.swift
// RUN: diff -u %S/Outputs/basic/second.swift.expected %t.result/second.swift
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=13:13 > %t.result/third.swift
// RUN: diff -u %S/Outputs/basic/third.swift.expected %t.result/third.swift

View File

@@ -296,6 +296,49 @@ struct S {
}
}
class TestAddEquatable {
var property = "test"
private var prop = "test2"
let pr = "test3"
}
struct TestAddEquatableStruct {
var property = "test"
private var prop = "test2"
let pr = "test3"
}
enum AddEquatableEnum {
case first
case second
}
class TestAddEquatableConforming: Equatable {
var property = "test"
public static func ==(lhs: TestAddEquatableConforming,
rhs: TestAddEquatableConforming) -> Bool {
return lhs.property == rhs.property
}
}
struct TestAddEquatableStructConforming: Equatable {
var property = "test"
}
extension TestAddEquatable {
func test() -> Bool {
return false
}
}
extension TestAddEquatableStructConforming: Equatable {
public static func ==(lhs: TestAddEquatableConforming,
rhs: TestAddEquatableConforming) -> Bool {
return lhs.property == rhs.property
}
}
// RUN: %refactor -source-filename %s -pos=2:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1
// RUN: %refactor -source-filename %s -pos=3:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1
// RUN: %refactor -source-filename %s -pos=4:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1
@@ -397,6 +440,14 @@ struct S {
// RUN: %refactor -source-filename %s -pos=291:3 -end-pos=291:18 | %FileCheck %s -check-prefix=CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY
// RUN: %refactor -source-filename %s -pos=292:3 -end-pos=296:4 | %FileCheck %s -check-prefix=CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY
// RUN: %refactor -source-filename %s -pos=299:16 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE
// RUN: %refactor -source-filename %s -pos=305:12 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE
// RUN: %refactor -source-filename %s -pos=311:9 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED
// RUN: %refactor -source-filename %s -pos=316:11 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED
// RUN: %refactor -source-filename %s -pos=325:12 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED
// RUN: %refactor -source-filename %s -pos=329:15 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE
// RUN: %refactor -source-filename %s -pos=335:15 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED
// CHECK1: Action begins
// CHECK1-NEXT: Extract Method
// CHECK1-NEXT: Action ends
@@ -454,3 +505,8 @@ struct S {
// CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY-NOT: Convert To Computed Property
// CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY: Action ends
// CHECK-ADD-EQUATABLE-CONFORMANCE: Add Equatable Conformance
// CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED: Action begins
// CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED-NOT: Add Equatable Conformance
// CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED: Action ends

View File

@@ -72,6 +72,7 @@ Action(llvm::cl::desc("kind:"), llvm::cl::init(RefactoringKind::None),
clEnumValN(RefactoringKind::ReplaceBodiesWithFatalError,
"replace-bodies-with-fatalError", "Perform trailing closure refactoring"),
clEnumValN(RefactoringKind::MemberwiseInitLocalRefactoring, "memberwise-init", "Generate member wise initializer"),
clEnumValN(RefactoringKind::AddEquatableConformance, "add-equatable-conformance", "Add Equatable conformance"),
clEnumValN(RefactoringKind::ConvertToComputedProperty,
"convert-to-computed-property", "Convert from field initialization to computed property"),
clEnumValN(RefactoringKind::ConvertToSwitchStmt, "convert-to-switch-stmt", "Perform convert to switch statement")));