mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Merge pull request #29847 from tkachukandrew/add-equatable-conformance
This commit is contained in:
@@ -52,6 +52,8 @@ CURSOR_REFACTORING(TrailingClosure, "Convert To Trailing Closure", trailingclosu
|
||||
|
||||
CURSOR_REFACTORING(MemberwiseInitLocalRefactoring, "Generate Memberwise Initializer", memberwise.init.local.refactoring)
|
||||
|
||||
CURSOR_REFACTORING(AddEquatableConformance, "Add Equatable Conformance", add.equatable.conformance)
|
||||
|
||||
RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)
|
||||
|
||||
RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)
|
||||
|
||||
@@ -3172,6 +3172,250 @@ bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
|
||||
return false;
|
||||
}
|
||||
|
||||
class AddEquatableContext {
|
||||
|
||||
/// Declaration context
|
||||
DeclContext *DC;
|
||||
|
||||
/// Adopter type
|
||||
Type Adopter;
|
||||
|
||||
/// Start location of declaration context brace
|
||||
SourceLoc StartLoc;
|
||||
|
||||
/// Array of all inherited protocols' locations
|
||||
ArrayRef<TypeLoc> ProtocolsLocations;
|
||||
|
||||
/// Array of all conformed protocols
|
||||
SmallVector<swift::ProtocolDecl *, 2> Protocols;
|
||||
|
||||
/// Start location of declaration,
|
||||
/// a place to write protocol name
|
||||
SourceLoc ProtInsertStartLoc;
|
||||
|
||||
/// Stored properties of extending adopter
|
||||
ArrayRef<VarDecl *> StoredProperties;
|
||||
|
||||
/// Range of internal members in declaration
|
||||
DeclRange Range;
|
||||
|
||||
bool conformsToEquatableProtocol() {
|
||||
for (ProtocolDecl *Protocol : Protocols) {
|
||||
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool isRequirementValid() {
|
||||
auto Reqs = getProtocolRequirements();
|
||||
if (Reqs.empty()) {
|
||||
return false;
|
||||
}
|
||||
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
|
||||
return Req && Req->getParameters()->size() == 2;
|
||||
}
|
||||
|
||||
bool isPropertiesListValid() {
|
||||
return !getUserAccessibleProperties().empty();
|
||||
}
|
||||
|
||||
void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent,
|
||||
ParameterList *Params);
|
||||
|
||||
std::vector<ValueDecl *> getProtocolRequirements();
|
||||
|
||||
std::vector<VarDecl *> getUserAccessibleProperties();
|
||||
|
||||
public:
|
||||
|
||||
AddEquatableContext(NominalTypeDecl *Decl) : DC(Decl),
|
||||
Adopter(Decl->getDeclaredType()), StartLoc(Decl->getBraces().Start),
|
||||
ProtocolsLocations(Decl->getInherited()),
|
||||
Protocols(Decl->getAllProtocols()), ProtInsertStartLoc(Decl->getNameLoc()),
|
||||
StoredProperties(Decl->getStoredProperties()), Range(Decl->getMembers()) {};
|
||||
|
||||
AddEquatableContext(ExtensionDecl *Decl) : DC(Decl),
|
||||
Adopter(Decl->getExtendedType()), StartLoc(Decl->getBraces().Start),
|
||||
ProtocolsLocations(Decl->getInherited()),
|
||||
Protocols(Decl->getExtendedNominal()->getAllProtocols()),
|
||||
ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()),
|
||||
StoredProperties(Decl->getExtendedNominal()->getStoredProperties()), Range(Decl->getMembers()) {};
|
||||
|
||||
AddEquatableContext() : DC(nullptr), Adopter(), ProtocolsLocations(),
|
||||
Protocols(), StoredProperties(), Range(nullptr, nullptr) {};
|
||||
|
||||
static AddEquatableContext getDeclarationContextFromInfo(ResolvedCursorInfo Info);
|
||||
|
||||
std::string getInsertionTextForProtocol();
|
||||
|
||||
std::string getInsertionTextForFunction(SourceManager &SM);
|
||||
|
||||
bool isValid() {
|
||||
// FIXME: Allow to generate explicit == method for declarations which already have
|
||||
// compiler-generated == method
|
||||
return StartLoc.isValid() && ProtInsertStartLoc.isValid() &&
|
||||
!conformsToEquatableProtocol() && isPropertiesListValid() &&
|
||||
isRequirementValid();
|
||||
}
|
||||
|
||||
SourceLoc getStartLocForProtocolDecl() {
|
||||
if (ProtocolsLocations.empty()) {
|
||||
return ProtInsertStartLoc;
|
||||
}
|
||||
return ProtocolsLocations.back().getSourceRange().Start;
|
||||
}
|
||||
|
||||
bool isMembersRangeEmpty() {
|
||||
return Range.empty();
|
||||
}
|
||||
|
||||
SourceLoc getInsertStartLoc();
|
||||
};
|
||||
|
||||
SourceLoc AddEquatableContext::
|
||||
getInsertStartLoc() {
|
||||
SourceLoc MaxLoc = StartLoc;
|
||||
for (auto Mem : Range) {
|
||||
if (Mem->getEndLoc().getOpaquePointerValue() >
|
||||
MaxLoc.getOpaquePointerValue()) {
|
||||
MaxLoc = Mem->getEndLoc();
|
||||
}
|
||||
}
|
||||
return MaxLoc;
|
||||
}
|
||||
|
||||
std::string AddEquatableContext::
|
||||
getInsertionTextForProtocol() {
|
||||
StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable);
|
||||
std::string Buffer;
|
||||
llvm::raw_string_ostream OS(Buffer);
|
||||
if (ProtocolsLocations.empty()) {
|
||||
OS << ": " << ProtocolName;
|
||||
return Buffer;
|
||||
}
|
||||
OS << ", " << ProtocolName;
|
||||
return Buffer;
|
||||
}
|
||||
|
||||
std::string AddEquatableContext::
|
||||
getInsertionTextForFunction(SourceManager &SM) {
|
||||
auto Reqs = getProtocolRequirements();
|
||||
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
|
||||
auto Params = Req->getParameters();
|
||||
StringRef ExtraIndent;
|
||||
StringRef CurrentIndent =
|
||||
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
|
||||
std::string Indent;
|
||||
if (isMembersRangeEmpty()) {
|
||||
Indent = (CurrentIndent + ExtraIndent).str();
|
||||
} else {
|
||||
Indent = CurrentIndent.str();
|
||||
}
|
||||
PrintOptions Options = PrintOptions::printVerbose();
|
||||
Options.PrintDocumentationComments = false;
|
||||
Options.setBaseType(Adopter);
|
||||
Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
|
||||
Printer << " {";
|
||||
Printer.printNewline();
|
||||
printFunctionBody(Printer, ExtraIndent, Params);
|
||||
Printer.printNewline();
|
||||
Printer << "}";
|
||||
};
|
||||
std::string Buffer;
|
||||
llvm::raw_string_ostream OS(Buffer);
|
||||
ExtraIndentStreamPrinter Printer(OS, Indent);
|
||||
Printer.printNewline();
|
||||
if (!isMembersRangeEmpty()) {
|
||||
Printer.printNewline();
|
||||
}
|
||||
Reqs[0]->print(Printer, Options);
|
||||
return Buffer;
|
||||
}
|
||||
|
||||
std::vector<VarDecl *> AddEquatableContext::
|
||||
getUserAccessibleProperties() {
|
||||
std::vector<VarDecl *> PublicProperties;
|
||||
for (VarDecl *Decl : StoredProperties) {
|
||||
if (Decl->Decl::isUserAccessible()) {
|
||||
PublicProperties.push_back(Decl);
|
||||
}
|
||||
}
|
||||
return PublicProperties;
|
||||
}
|
||||
|
||||
std::vector<ValueDecl *> AddEquatableContext::
|
||||
getProtocolRequirements() {
|
||||
std::vector<ValueDecl *> Collection;
|
||||
auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable);
|
||||
for (auto Member : Proto->getMembers()) {
|
||||
auto Req = dyn_cast<ValueDecl>(Member);
|
||||
if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) {
|
||||
continue;
|
||||
}
|
||||
Collection.push_back(Req);
|
||||
}
|
||||
return Collection;
|
||||
}
|
||||
|
||||
AddEquatableContext AddEquatableContext::
|
||||
getDeclarationContextFromInfo(ResolvedCursorInfo Info) {
|
||||
if (Info.isInvalid()) {
|
||||
return AddEquatableContext();
|
||||
}
|
||||
if (!Info.IsRef) {
|
||||
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD)) {
|
||||
return AddEquatableContext(NomDecl);
|
||||
}
|
||||
} else if (auto *ExtDecl = Info.ExtTyRef) {
|
||||
return AddEquatableContext(ExtDecl);
|
||||
}
|
||||
return AddEquatableContext();
|
||||
}
|
||||
|
||||
void AddEquatableContext::
|
||||
printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) {
|
||||
llvm::SmallString<128> Return;
|
||||
llvm::raw_svector_ostream SS(Return);
|
||||
SS << tok::kw_return;
|
||||
StringRef Space = " ";
|
||||
StringRef AdditionalSpace = " ";
|
||||
StringRef Point = ".";
|
||||
StringRef Join = " == ";
|
||||
StringRef And = " &&";
|
||||
auto Props = getUserAccessibleProperties();
|
||||
auto FParam = Params->get(0)->getName();
|
||||
auto SParam = Params->get(1)->getName();
|
||||
auto Prop = Props[0]->getName();
|
||||
Printer << ExtraIndent << Return << Space
|
||||
<< FParam << Point << Prop << Join << SParam << Point << Prop;
|
||||
if (Props.size() > 1) {
|
||||
std::for_each(Props.begin() + 1, Props.end(), [&](VarDecl *VD){
|
||||
auto Name = VD->getName();
|
||||
Printer << And;
|
||||
Printer.printNewline();
|
||||
Printer << ExtraIndent << AdditionalSpace << FParam << Point
|
||||
<< Name << Join << SParam << Point << Name;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
bool RefactoringActionAddEquatableConformance::
|
||||
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
|
||||
return AddEquatableContext::getDeclarationContextFromInfo(Tok).isValid();
|
||||
}
|
||||
|
||||
bool RefactoringActionAddEquatableConformance::
|
||||
performChange() {
|
||||
auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo);
|
||||
EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(),
|
||||
Context.getInsertionTextForProtocol());
|
||||
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(),
|
||||
Context.getInsertionTextForFunction(SM));
|
||||
return false;
|
||||
}
|
||||
|
||||
static CharSourceRange
|
||||
findSourceRangeToWrapInCatch(ResolvedCursorInfo CursorInfo,
|
||||
SourceFile *TheFile,
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
class TestAddEquatable: Equatable {
|
||||
var property = "test"
|
||||
private var prop = "test2"
|
||||
let pr = "test3"
|
||||
|
||||
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
|
||||
return lhs.property == rhs.property &&
|
||||
lhs.prop == rhs.prop &&
|
||||
lhs.pr == rhs.pr
|
||||
}
|
||||
}
|
||||
|
||||
extension TestAddEquatable {
|
||||
func test() -> Bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
extension TestAddEquatable {
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
class TestAddEquatable {
|
||||
var property = "test"
|
||||
private var prop = "test2"
|
||||
let pr = "test3"
|
||||
}
|
||||
|
||||
extension TestAddEquatable: Equatable {
|
||||
func test() -> Bool {
|
||||
return true
|
||||
}
|
||||
|
||||
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
|
||||
return lhs.property == rhs.property &&
|
||||
lhs.prop == rhs.prop &&
|
||||
lhs.pr == rhs.pr
|
||||
}
|
||||
}
|
||||
|
||||
extension TestAddEquatable {
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
class TestAddEquatable {
|
||||
var property = "test"
|
||||
private var prop = "test2"
|
||||
let pr = "test3"
|
||||
}
|
||||
|
||||
extension TestAddEquatable {
|
||||
func test() -> Bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
extension TestAddEquatable: Equatable {
|
||||
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
|
||||
return lhs.property == rhs.property &&
|
||||
lhs.prop == rhs.prop &&
|
||||
lhs.pr == rhs.pr
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
25
test/refactoring/AddEquatableConformance/basic.swift
Normal file
25
test/refactoring/AddEquatableConformance/basic.swift
Normal file
@@ -0,0 +1,25 @@
|
||||
class TestAddEquatable {
|
||||
var property = "test"
|
||||
private var prop = "test2"
|
||||
let pr = "test3"
|
||||
}
|
||||
|
||||
extension TestAddEquatable {
|
||||
func test() -> Bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
extension TestAddEquatable {
|
||||
}
|
||||
|
||||
// RUN: rm -rf %t.result && mkdir -p %t.result
|
||||
|
||||
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=1:16 > %t.result/first.swift
|
||||
// RUN: diff -u %S/Outputs/basic/first.swift.expected %t.result/first.swift
|
||||
|
||||
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=7:13 > %t.result/second.swift
|
||||
// RUN: diff -u %S/Outputs/basic/second.swift.expected %t.result/second.swift
|
||||
|
||||
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=13:13 > %t.result/third.swift
|
||||
// RUN: diff -u %S/Outputs/basic/third.swift.expected %t.result/third.swift
|
||||
@@ -296,6 +296,49 @@ struct S {
|
||||
}
|
||||
}
|
||||
|
||||
class TestAddEquatable {
|
||||
var property = "test"
|
||||
private var prop = "test2"
|
||||
let pr = "test3"
|
||||
}
|
||||
|
||||
struct TestAddEquatableStruct {
|
||||
var property = "test"
|
||||
private var prop = "test2"
|
||||
let pr = "test3"
|
||||
}
|
||||
|
||||
enum AddEquatableEnum {
|
||||
case first
|
||||
case second
|
||||
}
|
||||
|
||||
class TestAddEquatableConforming: Equatable {
|
||||
var property = "test"
|
||||
|
||||
public static func ==(lhs: TestAddEquatableConforming,
|
||||
rhs: TestAddEquatableConforming) -> Bool {
|
||||
return lhs.property == rhs.property
|
||||
}
|
||||
}
|
||||
|
||||
struct TestAddEquatableStructConforming: Equatable {
|
||||
var property = "test"
|
||||
}
|
||||
|
||||
extension TestAddEquatable {
|
||||
func test() -> Bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
extension TestAddEquatableStructConforming: Equatable {
|
||||
public static func ==(lhs: TestAddEquatableConforming,
|
||||
rhs: TestAddEquatableConforming) -> Bool {
|
||||
return lhs.property == rhs.property
|
||||
}
|
||||
}
|
||||
|
||||
// RUN: %refactor -source-filename %s -pos=2:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1
|
||||
// RUN: %refactor -source-filename %s -pos=3:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1
|
||||
// RUN: %refactor -source-filename %s -pos=4:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1
|
||||
@@ -397,6 +440,14 @@ struct S {
|
||||
// RUN: %refactor -source-filename %s -pos=291:3 -end-pos=291:18 | %FileCheck %s -check-prefix=CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY
|
||||
// RUN: %refactor -source-filename %s -pos=292:3 -end-pos=296:4 | %FileCheck %s -check-prefix=CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY
|
||||
|
||||
// RUN: %refactor -source-filename %s -pos=299:16 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE
|
||||
// RUN: %refactor -source-filename %s -pos=305:12 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE
|
||||
// RUN: %refactor -source-filename %s -pos=311:9 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED
|
||||
// RUN: %refactor -source-filename %s -pos=316:11 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED
|
||||
// RUN: %refactor -source-filename %s -pos=325:12 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED
|
||||
// RUN: %refactor -source-filename %s -pos=329:15 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE
|
||||
// RUN: %refactor -source-filename %s -pos=335:15 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED
|
||||
|
||||
// CHECK1: Action begins
|
||||
// CHECK1-NEXT: Extract Method
|
||||
// CHECK1-NEXT: Action ends
|
||||
@@ -454,3 +505,8 @@ struct S {
|
||||
// CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY-NOT: Convert To Computed Property
|
||||
// CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY: Action ends
|
||||
|
||||
// CHECK-ADD-EQUATABLE-CONFORMANCE: Add Equatable Conformance
|
||||
|
||||
// CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED: Action begins
|
||||
// CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED-NOT: Add Equatable Conformance
|
||||
// CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED: Action ends
|
||||
|
||||
@@ -72,6 +72,7 @@ Action(llvm::cl::desc("kind:"), llvm::cl::init(RefactoringKind::None),
|
||||
clEnumValN(RefactoringKind::ReplaceBodiesWithFatalError,
|
||||
"replace-bodies-with-fatalError", "Perform trailing closure refactoring"),
|
||||
clEnumValN(RefactoringKind::MemberwiseInitLocalRefactoring, "memberwise-init", "Generate member wise initializer"),
|
||||
clEnumValN(RefactoringKind::AddEquatableConformance, "add-equatable-conformance", "Add Equatable conformance"),
|
||||
clEnumValN(RefactoringKind::ConvertToComputedProperty,
|
||||
"convert-to-computed-property", "Convert from field initialization to computed property"),
|
||||
clEnumValN(RefactoringKind::ConvertToSwitchStmt, "convert-to-switch-stmt", "Perform convert to switch statement")));
|
||||
|
||||
Reference in New Issue
Block a user