diff --git a/include/swift/IDE/RefactoringKinds.def b/include/swift/IDE/RefactoringKinds.def index 0fbd03dbb70..6a6dff01c87 100644 --- a/include/swift/IDE/RefactoringKinds.def +++ b/include/swift/IDE/RefactoringKinds.def @@ -52,6 +52,8 @@ CURSOR_REFACTORING(TrailingClosure, "Convert To Trailing Closure", trailingclosu CURSOR_REFACTORING(MemberwiseInitLocalRefactoring, "Generate Memberwise Initializer", memberwise.init.local.refactoring) +CURSOR_REFACTORING(AddEquatableConformance, "Add Equatable Conformance", add.equatable.conformance) + RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr) RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function) diff --git a/lib/IDE/Refactoring.cpp b/lib/IDE/Refactoring.cpp index 30fd18e8e72..b9951d98927 100644 --- a/lib/IDE/Refactoring.cpp +++ b/lib/IDE/Refactoring.cpp @@ -3172,6 +3172,250 @@ bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() { return false; } +class AddEquatableContext { + + /// Declaration context + DeclContext *DC; + + /// Adopter type + Type Adopter; + + /// Start location of declaration context brace + SourceLoc StartLoc; + + /// Array of all inherited protocols' locations + ArrayRef ProtocolsLocations; + + /// Array of all conformed protocols + SmallVector Protocols; + + /// Start location of declaration, + /// a place to write protocol name + SourceLoc ProtInsertStartLoc; + + /// Stored properties of extending adopter + ArrayRef StoredProperties; + + /// Range of internal members in declaration + DeclRange Range; + + bool conformsToEquatableProtocol() { + for (ProtocolDecl *Protocol : Protocols) { + if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) { + return true; + } + } + return false; + } + + bool isRequirementValid() { + auto Reqs = getProtocolRequirements(); + if (Reqs.empty()) { + return false; + } + auto Req = dyn_cast(Reqs[0]); + return Req && Req->getParameters()->size() == 2; + } + + bool isPropertiesListValid() { + return !getUserAccessibleProperties().empty(); + } + + void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, + ParameterList *Params); + + std::vector getProtocolRequirements(); + + std::vector getUserAccessibleProperties(); + +public: + + AddEquatableContext(NominalTypeDecl *Decl) : DC(Decl), + Adopter(Decl->getDeclaredType()), StartLoc(Decl->getBraces().Start), + ProtocolsLocations(Decl->getInherited()), + Protocols(Decl->getAllProtocols()), ProtInsertStartLoc(Decl->getNameLoc()), + StoredProperties(Decl->getStoredProperties()), Range(Decl->getMembers()) {}; + + AddEquatableContext(ExtensionDecl *Decl) : DC(Decl), + Adopter(Decl->getExtendedType()), StartLoc(Decl->getBraces().Start), + ProtocolsLocations(Decl->getInherited()), + Protocols(Decl->getExtendedNominal()->getAllProtocols()), + ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()), + StoredProperties(Decl->getExtendedNominal()->getStoredProperties()), Range(Decl->getMembers()) {}; + + AddEquatableContext() : DC(nullptr), Adopter(), ProtocolsLocations(), + Protocols(), StoredProperties(), Range(nullptr, nullptr) {}; + + static AddEquatableContext getDeclarationContextFromInfo(ResolvedCursorInfo Info); + + std::string getInsertionTextForProtocol(); + + std::string getInsertionTextForFunction(SourceManager &SM); + + bool isValid() { + // FIXME: Allow to generate explicit == method for declarations which already have + // compiler-generated == method + return StartLoc.isValid() && ProtInsertStartLoc.isValid() && + !conformsToEquatableProtocol() && isPropertiesListValid() && + isRequirementValid(); + } + + SourceLoc getStartLocForProtocolDecl() { + if (ProtocolsLocations.empty()) { + return ProtInsertStartLoc; + } + return ProtocolsLocations.back().getSourceRange().Start; + } + + bool isMembersRangeEmpty() { + return Range.empty(); + } + + SourceLoc getInsertStartLoc(); +}; + +SourceLoc AddEquatableContext:: +getInsertStartLoc() { + SourceLoc MaxLoc = StartLoc; + for (auto Mem : Range) { + if (Mem->getEndLoc().getOpaquePointerValue() > + MaxLoc.getOpaquePointerValue()) { + MaxLoc = Mem->getEndLoc(); + } + } + return MaxLoc; +} + +std::string AddEquatableContext:: +getInsertionTextForProtocol() { + StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable); + std::string Buffer; + llvm::raw_string_ostream OS(Buffer); + if (ProtocolsLocations.empty()) { + OS << ": " << ProtocolName; + return Buffer; + } + OS << ", " << ProtocolName; + return Buffer; +} + +std::string AddEquatableContext:: +getInsertionTextForFunction(SourceManager &SM) { + auto Reqs = getProtocolRequirements(); + auto Req = dyn_cast(Reqs[0]); + auto Params = Req->getParameters(); + StringRef ExtraIndent; + StringRef CurrentIndent = + Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent); + std::string Indent; + if (isMembersRangeEmpty()) { + Indent = (CurrentIndent + ExtraIndent).str(); + } else { + Indent = CurrentIndent.str(); + } + PrintOptions Options = PrintOptions::printVerbose(); + Options.PrintDocumentationComments = false; + Options.setBaseType(Adopter); + Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) { + Printer << " {"; + Printer.printNewline(); + printFunctionBody(Printer, ExtraIndent, Params); + Printer.printNewline(); + Printer << "}"; + }; + std::string Buffer; + llvm::raw_string_ostream OS(Buffer); + ExtraIndentStreamPrinter Printer(OS, Indent); + Printer.printNewline(); + if (!isMembersRangeEmpty()) { + Printer.printNewline(); + } + Reqs[0]->print(Printer, Options); + return Buffer; +} + +std::vector AddEquatableContext:: +getUserAccessibleProperties() { + std::vector PublicProperties; + for (VarDecl *Decl : StoredProperties) { + if (Decl->Decl::isUserAccessible()) { + PublicProperties.push_back(Decl); + } + } + return PublicProperties; +} + +std::vector AddEquatableContext:: +getProtocolRequirements() { + std::vector Collection; + auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable); + for (auto Member : Proto->getMembers()) { + auto Req = dyn_cast(Member); + if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) { + continue; + } + Collection.push_back(Req); + } + return Collection; +} + +AddEquatableContext AddEquatableContext:: +getDeclarationContextFromInfo(ResolvedCursorInfo Info) { + if (Info.isInvalid()) { + return AddEquatableContext(); + } + if (!Info.IsRef) { + if (auto *NomDecl = dyn_cast(Info.ValueD)) { + return AddEquatableContext(NomDecl); + } + } else if (auto *ExtDecl = Info.ExtTyRef) { + return AddEquatableContext(ExtDecl); + } + return AddEquatableContext(); +} + +void AddEquatableContext:: +printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) { + llvm::SmallString<128> Return; + llvm::raw_svector_ostream SS(Return); + SS << tok::kw_return; + StringRef Space = " "; + StringRef AdditionalSpace = " "; + StringRef Point = "."; + StringRef Join = " == "; + StringRef And = " &&"; + auto Props = getUserAccessibleProperties(); + auto FParam = Params->get(0)->getName(); + auto SParam = Params->get(1)->getName(); + auto Prop = Props[0]->getName(); + Printer << ExtraIndent << Return << Space + << FParam << Point << Prop << Join << SParam << Point << Prop; + if (Props.size() > 1) { + std::for_each(Props.begin() + 1, Props.end(), [&](VarDecl *VD){ + auto Name = VD->getName(); + Printer << And; + Printer.printNewline(); + Printer << ExtraIndent << AdditionalSpace << FParam << Point + << Name << Join << SParam << Point << Name; + }); + } +} + +bool RefactoringActionAddEquatableConformance:: +isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) { + return AddEquatableContext::getDeclarationContextFromInfo(Tok).isValid(); +} + +bool RefactoringActionAddEquatableConformance:: +performChange() { + auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo); + EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(), + Context.getInsertionTextForProtocol()); + EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), + Context.getInsertionTextForFunction(SM)); + return false; +} + static CharSourceRange findSourceRangeToWrapInCatch(ResolvedCursorInfo CursorInfo, SourceFile *TheFile, diff --git a/test/refactoring/AddEquatableConformance/Outputs/basic/first.swift.expected b/test/refactoring/AddEquatableConformance/Outputs/basic/first.swift.expected new file mode 100644 index 00000000000..a4077abea94 --- /dev/null +++ b/test/refactoring/AddEquatableConformance/Outputs/basic/first.swift.expected @@ -0,0 +1,24 @@ +class TestAddEquatable: Equatable { + var property = "test" + private var prop = "test2" + let pr = "test3" + + static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool { + return lhs.property == rhs.property && + lhs.prop == rhs.prop && + lhs.pr == rhs.pr + } +} + +extension TestAddEquatable { + func test() -> Bool { + return true + } +} + +extension TestAddEquatable { +} + + + + diff --git a/test/refactoring/AddEquatableConformance/Outputs/basic/second.swift.expected b/test/refactoring/AddEquatableConformance/Outputs/basic/second.swift.expected new file mode 100644 index 00000000000..bc37aa869ca --- /dev/null +++ b/test/refactoring/AddEquatableConformance/Outputs/basic/second.swift.expected @@ -0,0 +1,24 @@ +class TestAddEquatable { + var property = "test" + private var prop = "test2" + let pr = "test3" +} + +extension TestAddEquatable: Equatable { + func test() -> Bool { + return true + } + + static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool { + return lhs.property == rhs.property && + lhs.prop == rhs.prop && + lhs.pr == rhs.pr + } +} + +extension TestAddEquatable { +} + + + + diff --git a/test/refactoring/AddEquatableConformance/Outputs/basic/third.swift.expected b/test/refactoring/AddEquatableConformance/Outputs/basic/third.swift.expected new file mode 100644 index 00000000000..1141aa687d7 --- /dev/null +++ b/test/refactoring/AddEquatableConformance/Outputs/basic/third.swift.expected @@ -0,0 +1,23 @@ +class TestAddEquatable { + var property = "test" + private var prop = "test2" + let pr = "test3" +} + +extension TestAddEquatable { + func test() -> Bool { + return true + } +} + +extension TestAddEquatable: Equatable { + static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool { + return lhs.property == rhs.property && + lhs.prop == rhs.prop && + lhs.pr == rhs.pr + } +} + + + + diff --git a/test/refactoring/AddEquatableConformance/basic.swift b/test/refactoring/AddEquatableConformance/basic.swift new file mode 100644 index 00000000000..5d585ffb17c --- /dev/null +++ b/test/refactoring/AddEquatableConformance/basic.swift @@ -0,0 +1,25 @@ +class TestAddEquatable { + var property = "test" + private var prop = "test2" + let pr = "test3" +} + +extension TestAddEquatable { + func test() -> Bool { + return true + } +} + +extension TestAddEquatable { +} + +// RUN: rm -rf %t.result && mkdir -p %t.result + +// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=1:16 > %t.result/first.swift +// RUN: diff -u %S/Outputs/basic/first.swift.expected %t.result/first.swift + +// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=7:13 > %t.result/second.swift +// RUN: diff -u %S/Outputs/basic/second.swift.expected %t.result/second.swift + +// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=13:13 > %t.result/third.swift +// RUN: diff -u %S/Outputs/basic/third.swift.expected %t.result/third.swift diff --git a/test/refactoring/RefactoringKind/basic.swift b/test/refactoring/RefactoringKind/basic.swift index 16bf14f604b..49bb1b5f155 100644 --- a/test/refactoring/RefactoringKind/basic.swift +++ b/test/refactoring/RefactoringKind/basic.swift @@ -296,6 +296,49 @@ struct S { } } +class TestAddEquatable { + var property = "test" + private var prop = "test2" + let pr = "test3" +} + +struct TestAddEquatableStruct { + var property = "test" + private var prop = "test2" + let pr = "test3" +} + +enum AddEquatableEnum { + case first + case second +} + +class TestAddEquatableConforming: Equatable { + var property = "test" + + public static func ==(lhs: TestAddEquatableConforming, + rhs: TestAddEquatableConforming) -> Bool { + return lhs.property == rhs.property + } +} + +struct TestAddEquatableStructConforming: Equatable { + var property = "test" +} + +extension TestAddEquatable { + func test() -> Bool { + return false + } +} + +extension TestAddEquatableStructConforming: Equatable { + public static func ==(lhs: TestAddEquatableConforming, + rhs: TestAddEquatableConforming) -> Bool { + return lhs.property == rhs.property + } +} + // RUN: %refactor -source-filename %s -pos=2:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1 // RUN: %refactor -source-filename %s -pos=3:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1 // RUN: %refactor -source-filename %s -pos=4:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1 @@ -397,6 +440,14 @@ struct S { // RUN: %refactor -source-filename %s -pos=291:3 -end-pos=291:18 | %FileCheck %s -check-prefix=CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY // RUN: %refactor -source-filename %s -pos=292:3 -end-pos=296:4 | %FileCheck %s -check-prefix=CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY +// RUN: %refactor -source-filename %s -pos=299:16 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE +// RUN: %refactor -source-filename %s -pos=305:12 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE +// RUN: %refactor -source-filename %s -pos=311:9 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED +// RUN: %refactor -source-filename %s -pos=316:11 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED +// RUN: %refactor -source-filename %s -pos=325:12 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED +// RUN: %refactor -source-filename %s -pos=329:15 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE +// RUN: %refactor -source-filename %s -pos=335:15 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED + // CHECK1: Action begins // CHECK1-NEXT: Extract Method // CHECK1-NEXT: Action ends @@ -454,3 +505,8 @@ struct S { // CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY-NOT: Convert To Computed Property // CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY: Action ends +// CHECK-ADD-EQUATABLE-CONFORMANCE: Add Equatable Conformance + +// CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED: Action begins +// CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED-NOT: Add Equatable Conformance +// CHECK-ADD-EQUATABLE-CONFORMANCE-NOT-INCLUDED: Action ends diff --git a/tools/swift-refactor/swift-refactor.cpp b/tools/swift-refactor/swift-refactor.cpp index e037aafd936..2bc77c2b7d6 100644 --- a/tools/swift-refactor/swift-refactor.cpp +++ b/tools/swift-refactor/swift-refactor.cpp @@ -72,6 +72,7 @@ Action(llvm::cl::desc("kind:"), llvm::cl::init(RefactoringKind::None), clEnumValN(RefactoringKind::ReplaceBodiesWithFatalError, "replace-bodies-with-fatalError", "Perform trailing closure refactoring"), clEnumValN(RefactoringKind::MemberwiseInitLocalRefactoring, "memberwise-init", "Generate member wise initializer"), + clEnumValN(RefactoringKind::AddEquatableConformance, "add-equatable-conformance", "Add Equatable conformance"), clEnumValN(RefactoringKind::ConvertToComputedProperty, "convert-to-computed-property", "Convert from field initialization to computed property"), clEnumValN(RefactoringKind::ConvertToSwitchStmt, "convert-to-switch-stmt", "Perform convert to switch statement")));