mirror of
https://github.com/apple/swift.git
synced 2025-12-14 20:36:38 +01:00
[refactoring] Rework "add codable implementation" refactoring
* Support extensions including conditional conformance * Correct access modifiers * More correct lookup for the synthesized declarations * Avoid printing decls in nested types (rdar://98025945)
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
|
||||
#include "RefactoringActions.h"
|
||||
#include "Utils.h"
|
||||
#include "swift/AST/ProtocolConformance.h"
|
||||
|
||||
using namespace swift::refactoring;
|
||||
|
||||
@@ -19,176 +20,201 @@ namespace {
|
||||
class AddCodableContext {
|
||||
|
||||
/// Declaration context
|
||||
DeclContext *DC;
|
||||
IterableDeclContext *IDC;
|
||||
|
||||
/// Start location of declaration context brace
|
||||
SourceLoc StartLoc;
|
||||
AddCodableContext(NominalTypeDecl *nominal) : IDC(nominal){};
|
||||
AddCodableContext(ExtensionDecl *extension) : IDC(extension){};
|
||||
AddCodableContext(std::nullptr_t) : IDC(nullptr){};
|
||||
|
||||
/// Array of all conformed protocols
|
||||
SmallVector<swift::ProtocolDecl *, 2> Protocols;
|
||||
const NominalTypeDecl *getNominal() const {
|
||||
switch (IDC->getIterableContextKind()) {
|
||||
case IterableDeclContextKind::NominalTypeDecl:
|
||||
return cast<NominalTypeDecl>(IDC);
|
||||
case IterableDeclContextKind::ExtensionDecl:
|
||||
return cast<ExtensionDecl>(IDC)->getExtendedNominal();
|
||||
}
|
||||
assert(false && "unhandled IterableDeclContextKind");
|
||||
}
|
||||
|
||||
/// Range of internal members in declaration
|
||||
DeclRange Range;
|
||||
/// Get the left brace location of the type-or-extension decl.
|
||||
SourceLoc getLeftBraceLoc() const {
|
||||
switch (IDC->getIterableContextKind()) {
|
||||
case IterableDeclContextKind::NominalTypeDecl:
|
||||
return cast<NominalTypeDecl>(IDC)->getBraces().Start;
|
||||
case IterableDeclContextKind::ExtensionDecl:
|
||||
return cast<ExtensionDecl>(IDC)->getBraces().Start;
|
||||
}
|
||||
assert(false && "unhandled IterableDeclContextKind");
|
||||
}
|
||||
|
||||
bool conformsToCodableProtocol() {
|
||||
for (ProtocolDecl *Protocol : Protocols) {
|
||||
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Encodable ||
|
||||
Protocol->getKnownProtocolKind() == KnownProtocolKind::Decodable) {
|
||||
return true;
|
||||
/// Get the token location where the text should be inserted after.
|
||||
SourceLoc getInsertStartLoc() const {
|
||||
// Prefer the end of elements.
|
||||
for (auto *member : llvm::reverse(IDC->getParsedMembers())) {
|
||||
if (isa<AccessorDecl>(member) || isa<VarDecl>(member)) {
|
||||
// These are part of 'PatternBindingDecl' but are hoisted in AST.
|
||||
continue;
|
||||
}
|
||||
return member->getEndLoc();
|
||||
}
|
||||
|
||||
// After the starting brace if empty.
|
||||
return getLeftBraceLoc();
|
||||
}
|
||||
|
||||
std::string getBaseIndent() const {
|
||||
SourceManager &SM = IDC->getDecl()->getASTContext().SourceMgr;
|
||||
SourceLoc startLoc = getInsertStartLoc();
|
||||
StringRef extraIndent;
|
||||
StringRef currentIndent =
|
||||
Lexer::getIndentationForLine(SM, startLoc, &extraIndent);
|
||||
if (startLoc == getLeftBraceLoc()) {
|
||||
return (currentIndent + extraIndent).str();
|
||||
} else {
|
||||
return currentIndent.str();
|
||||
}
|
||||
}
|
||||
|
||||
void printInsertText(llvm::raw_ostream &OS) const {
|
||||
auto &ctx = IDC->getDecl()->getASTContext();
|
||||
|
||||
PrintOptions Options = PrintOptions::printDeclarations();
|
||||
Options.SynthesizeSugarOnTypes = true;
|
||||
Options.FunctionDefinitions = true;
|
||||
Options.VarInitializers = true;
|
||||
Options.PrintExprs = true;
|
||||
Options.TypeDefinitions = false;
|
||||
Options.PrintSpaceBeforeInheritance = false;
|
||||
Options.ExcludeAttrList.push_back(DeclAttrKind::HasInitialValue);
|
||||
Options.PrintInternalAccessKeyword = false;
|
||||
|
||||
std::string baseIndent = getBaseIndent();
|
||||
ExtraIndentStreamPrinter Printer(OS, baseIndent);
|
||||
|
||||
// The insertion starts at the end of the last token.
|
||||
Printer.printNewline();
|
||||
|
||||
// Synthesized 'CodingKeys' are placed in the main nominal decl.
|
||||
// Iterate members and look for synthesized enums that conforms to
|
||||
// 'CodingKey' protocol.
|
||||
auto *codingKeyProto = ctx.getProtocol(KnownProtocolKind::CodingKey);
|
||||
for (auto *member : getNominal()->getMembers()) {
|
||||
auto *enumD = dyn_cast<EnumDecl>(member);
|
||||
if (!enumD || !enumD->isSynthesized())
|
||||
continue;
|
||||
llvm::SmallVector<ProtocolConformance *, 1> codingKeyConformance;
|
||||
if (!enumD->lookupConformance(codingKeyProto, codingKeyConformance))
|
||||
continue;
|
||||
|
||||
// Print the decl, but without the body.
|
||||
Printer.printNewline();
|
||||
enumD->print(Printer, Options);
|
||||
|
||||
// Manually print elements because CodingKey enums have some synthesized
|
||||
// members for the protocol conformance e.g 'init(intValue:)'.
|
||||
// We don't want to print them here.
|
||||
Printer << " {";
|
||||
Printer.printNewline();
|
||||
Printer.setIndent(2);
|
||||
for (auto *elementD : enumD->getAllElements()) {
|
||||
elementD->print(Printer, Options);
|
||||
Printer.printNewline();
|
||||
}
|
||||
Printer.setIndent(0);
|
||||
Printer << "}";
|
||||
Printer.printNewline();
|
||||
}
|
||||
|
||||
// Look for synthesized witness decls and print them.
|
||||
for (auto *conformance : IDC->getLocalConformances()) {
|
||||
auto protocol = conformance->getProtocol();
|
||||
auto kind = protocol->getKnownProtocolKind();
|
||||
if (kind == KnownProtocolKind::Encodable ||
|
||||
kind == KnownProtocolKind::Decodable) {
|
||||
for (auto requirement : protocol->getProtocolRequirements()) {
|
||||
auto witness = conformance->getWitnessDecl(requirement);
|
||||
if (witness->isSynthesized()) {
|
||||
Printer.printNewline();
|
||||
witness->print(Printer, Options);
|
||||
Printer.printNewline();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
static AddCodableContext getFromCursorInfo(ResolvedCursorInfoPtr Info);
|
||||
|
||||
bool isApplicable() const {
|
||||
if (!IDC || !getNominal())
|
||||
return false;
|
||||
|
||||
// Check if 'IDC' conforms to 'Encodable' and/or 'Decodable' and any of the
|
||||
// requirements are synthesized.
|
||||
for (auto *conformance : IDC->getLocalConformances()) {
|
||||
auto protocol = conformance->getProtocol();
|
||||
auto kind = protocol->getKnownProtocolKind();
|
||||
if (kind == KnownProtocolKind::Encodable ||
|
||||
kind == KnownProtocolKind::Decodable) {
|
||||
// Check if any of the protocol requirements are synthesized.
|
||||
for (auto requirement : protocol->getProtocolRequirements()) {
|
||||
auto witness = conformance->getWitnessDecl(requirement);
|
||||
if (!witness || witness->isSynthesized())
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
AddCodableContext(NominalTypeDecl *Decl)
|
||||
: DC(Decl), StartLoc(Decl->getBraces().Start),
|
||||
Protocols(getAllProtocols(Decl)), Range(Decl->getMembers()){};
|
||||
|
||||
AddCodableContext(ExtensionDecl *Decl)
|
||||
: DC(Decl), StartLoc(Decl->getBraces().Start),
|
||||
Protocols(getAllProtocols(Decl->getExtendedNominal())),
|
||||
Range(Decl->getMembers()){};
|
||||
|
||||
AddCodableContext() : DC(nullptr), Protocols(), Range(nullptr, nullptr){};
|
||||
|
||||
static AddCodableContext
|
||||
getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info);
|
||||
|
||||
void printInsertionText(ResolvedCursorInfoPtr CursorInfo, SourceManager &SM,
|
||||
llvm::raw_ostream &OS);
|
||||
|
||||
bool isValid() { return StartLoc.isValid() && conformsToCodableProtocol(); }
|
||||
|
||||
SourceLoc getInsertStartLoc();
|
||||
};
|
||||
|
||||
SourceLoc AddCodableContext::getInsertStartLoc() {
|
||||
SourceLoc MaxLoc = StartLoc;
|
||||
for (auto Mem : Range) {
|
||||
if (Mem->getEndLoc().getOpaquePointerValue() >
|
||||
MaxLoc.getOpaquePointerValue()) {
|
||||
MaxLoc = Mem->getEndLoc();
|
||||
}
|
||||
}
|
||||
return MaxLoc;
|
||||
}
|
||||
|
||||
/// Walks an AST and prints the synthesized Codable implementation.
|
||||
class SynthesizedCodablePrinter : public ASTWalker {
|
||||
private:
|
||||
ASTPrinter &Printer;
|
||||
|
||||
public:
|
||||
SynthesizedCodablePrinter(ASTPrinter &Printer) : Printer(Printer) {}
|
||||
|
||||
MacroWalking getMacroWalkingBehavior() const override {
|
||||
return MacroWalking::Arguments;
|
||||
}
|
||||
|
||||
PreWalkAction walkToDeclPre(Decl *D) override {
|
||||
auto *VD = dyn_cast<ValueDecl>(D);
|
||||
if (!VD)
|
||||
return Action::SkipNode();
|
||||
|
||||
if (!VD->isSynthesized()) {
|
||||
return Action::Continue();
|
||||
}
|
||||
SmallString<32> Scratch;
|
||||
auto name = VD->getName().getString(Scratch);
|
||||
// Print all synthesized enums,
|
||||
// since Codable can synthesize multiple enums (for associated values).
|
||||
auto shouldPrint =
|
||||
isa<EnumDecl>(VD) || name == "init(from:)" || name == "encode(to:)";
|
||||
if (!shouldPrint) {
|
||||
// Some other synthesized decl that we don't want to print.
|
||||
return Action::SkipNode();
|
||||
}
|
||||
|
||||
Printer.printNewline();
|
||||
|
||||
if (auto enumDecl = dyn_cast<EnumDecl>(D)) {
|
||||
// Manually print enum here, since we don't want to print synthesized
|
||||
// functions.
|
||||
Printer << "enum " << enumDecl->getNameStr();
|
||||
PrintOptions Options;
|
||||
Options.PrintSpaceBeforeInheritance = false;
|
||||
enumDecl->printInherited(Printer, Options);
|
||||
Printer << " {";
|
||||
for (Decl *EC : enumDecl->getAllElements()) {
|
||||
Printer.printNewline();
|
||||
Printer << " ";
|
||||
EC->print(Printer, Options);
|
||||
}
|
||||
Printer.printNewline();
|
||||
Printer << "}";
|
||||
return Action::SkipNode();
|
||||
}
|
||||
|
||||
PrintOptions Options;
|
||||
Options.SynthesizeSugarOnTypes = true;
|
||||
Options.FunctionDefinitions = true;
|
||||
Options.VarInitializers = true;
|
||||
Options.PrintExprs = true;
|
||||
Options.TypeDefinitions = true;
|
||||
Options.ExcludeAttrList.push_back(DeclAttrKind::HasInitialValue);
|
||||
|
||||
Printer.printNewline();
|
||||
D->print(Printer, Options);
|
||||
|
||||
return Action::SkipNode();
|
||||
void getInsertion(SourceLoc &insertLoc, std::string &insertText) const {
|
||||
insertLoc = getInsertStartLoc();
|
||||
llvm::raw_string_ostream OS(insertText);
|
||||
printInsertText(OS);
|
||||
}
|
||||
};
|
||||
|
||||
void AddCodableContext::printInsertionText(ResolvedCursorInfoPtr CursorInfo,
|
||||
SourceManager &SM,
|
||||
llvm::raw_ostream &OS) {
|
||||
StringRef ExtraIndent;
|
||||
StringRef CurrentIndent =
|
||||
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
|
||||
std::string Indent;
|
||||
if (getInsertStartLoc() == StartLoc) {
|
||||
Indent = (CurrentIndent + ExtraIndent).str();
|
||||
} else {
|
||||
Indent = CurrentIndent.str();
|
||||
}
|
||||
|
||||
ExtraIndentStreamPrinter Printer(OS, Indent);
|
||||
Printer.printNewline();
|
||||
SynthesizedCodablePrinter Walker(Printer);
|
||||
DC->getAsDecl()->walk(Walker);
|
||||
}
|
||||
|
||||
AddCodableContext
|
||||
AddCodableContext::getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info) {
|
||||
AddCodableContext::getFromCursorInfo(ResolvedCursorInfoPtr Info) {
|
||||
auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(Info);
|
||||
if (!ValueRefInfo) {
|
||||
return AddCodableContext();
|
||||
return nullptr;
|
||||
}
|
||||
if (!ValueRefInfo->isRef()) {
|
||||
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD())) {
|
||||
return AddCodableContext(NomDecl);
|
||||
|
||||
if (auto *ext = ValueRefInfo->getExtTyRef()) {
|
||||
// For 'extension Outer.Inner: Codable {}', only 'Inner' part is valid.
|
||||
if (ext->getExtendedNominal() == ValueRefInfo->getValueD()) {
|
||||
return AddCodableContext(ext);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
// TODO: support extensions
|
||||
// (would need to get synthesized nodes from the main decl,
|
||||
// and only if it's in the same file?)
|
||||
return AddCodableContext();
|
||||
|
||||
if (!ValueRefInfo->isRef()) {
|
||||
if (auto *nominal = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD())) {
|
||||
return AddCodableContext(nominal);
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool RefactoringActionAddExplicitCodableImplementation::isApplicable(
|
||||
ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
|
||||
return AddCodableContext::getDeclarationContextFromInfo(Tok).isValid();
|
||||
return AddCodableContext::getFromCursorInfo(Tok).isApplicable();
|
||||
}
|
||||
|
||||
bool RefactoringActionAddExplicitCodableImplementation::performChange() {
|
||||
auto Context = AddCodableContext::getDeclarationContextFromInfo(CursorInfo);
|
||||
auto Context = AddCodableContext::getFromCursorInfo(CursorInfo);
|
||||
assert(Context.isApplicable() &&
|
||||
"Should not run performChange when refactoring is not applicable");
|
||||
|
||||
SmallString<64> Buffer;
|
||||
llvm::raw_svector_ostream OS(Buffer);
|
||||
Context.printInsertionText(CursorInfo, SM, OS);
|
||||
SourceLoc insertLoc;
|
||||
std::string insertText;
|
||||
Context.getInsertion(insertLoc, insertText);
|
||||
|
||||
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), OS.str());
|
||||
EditConsumer.insertAfter(SM, insertLoc, insertText);
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user