[refactoring] Rework "add codable implementation" refactoring

* Support extensions including conditional conformance
* Correct access modifiers
* More correct lookup for the synthesized declarations
* Avoid printing decls in nested types (rdar://98025945)
This commit is contained in:
Rintaro Ishizaki
2024-03-13 06:20:28 +09:00
parent 990c870754
commit 39606e6269
22 changed files with 658 additions and 159 deletions

View File

@@ -12,6 +12,7 @@
#include "RefactoringActions.h"
#include "Utils.h"
#include "swift/AST/ProtocolConformance.h"
using namespace swift::refactoring;
@@ -19,176 +20,201 @@ namespace {
class AddCodableContext {
/// Declaration context
DeclContext *DC;
IterableDeclContext *IDC;
/// Start location of declaration context brace
SourceLoc StartLoc;
AddCodableContext(NominalTypeDecl *nominal) : IDC(nominal){};
AddCodableContext(ExtensionDecl *extension) : IDC(extension){};
AddCodableContext(std::nullptr_t) : IDC(nullptr){};
/// Array of all conformed protocols
SmallVector<swift::ProtocolDecl *, 2> Protocols;
const NominalTypeDecl *getNominal() const {
switch (IDC->getIterableContextKind()) {
case IterableDeclContextKind::NominalTypeDecl:
return cast<NominalTypeDecl>(IDC);
case IterableDeclContextKind::ExtensionDecl:
return cast<ExtensionDecl>(IDC)->getExtendedNominal();
}
assert(false && "unhandled IterableDeclContextKind");
}
/// Range of internal members in declaration
DeclRange Range;
/// Get the left brace location of the type-or-extension decl.
SourceLoc getLeftBraceLoc() const {
switch (IDC->getIterableContextKind()) {
case IterableDeclContextKind::NominalTypeDecl:
return cast<NominalTypeDecl>(IDC)->getBraces().Start;
case IterableDeclContextKind::ExtensionDecl:
return cast<ExtensionDecl>(IDC)->getBraces().Start;
}
assert(false && "unhandled IterableDeclContextKind");
}
bool conformsToCodableProtocol() {
for (ProtocolDecl *Protocol : Protocols) {
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Encodable ||
Protocol->getKnownProtocolKind() == KnownProtocolKind::Decodable) {
return true;
/// Get the token location where the text should be inserted after.
SourceLoc getInsertStartLoc() const {
// Prefer the end of elements.
for (auto *member : llvm::reverse(IDC->getParsedMembers())) {
if (isa<AccessorDecl>(member) || isa<VarDecl>(member)) {
// These are part of 'PatternBindingDecl' but are hoisted in AST.
continue;
}
return member->getEndLoc();
}
// After the starting brace if empty.
return getLeftBraceLoc();
}
std::string getBaseIndent() const {
SourceManager &SM = IDC->getDecl()->getASTContext().SourceMgr;
SourceLoc startLoc = getInsertStartLoc();
StringRef extraIndent;
StringRef currentIndent =
Lexer::getIndentationForLine(SM, startLoc, &extraIndent);
if (startLoc == getLeftBraceLoc()) {
return (currentIndent + extraIndent).str();
} else {
return currentIndent.str();
}
}
void printInsertText(llvm::raw_ostream &OS) const {
auto &ctx = IDC->getDecl()->getASTContext();
PrintOptions Options = PrintOptions::printDeclarations();
Options.SynthesizeSugarOnTypes = true;
Options.FunctionDefinitions = true;
Options.VarInitializers = true;
Options.PrintExprs = true;
Options.TypeDefinitions = false;
Options.PrintSpaceBeforeInheritance = false;
Options.ExcludeAttrList.push_back(DeclAttrKind::HasInitialValue);
Options.PrintInternalAccessKeyword = false;
std::string baseIndent = getBaseIndent();
ExtraIndentStreamPrinter Printer(OS, baseIndent);
// The insertion starts at the end of the last token.
Printer.printNewline();
// Synthesized 'CodingKeys' are placed in the main nominal decl.
// Iterate members and look for synthesized enums that conforms to
// 'CodingKey' protocol.
auto *codingKeyProto = ctx.getProtocol(KnownProtocolKind::CodingKey);
for (auto *member : getNominal()->getMembers()) {
auto *enumD = dyn_cast<EnumDecl>(member);
if (!enumD || !enumD->isSynthesized())
continue;
llvm::SmallVector<ProtocolConformance *, 1> codingKeyConformance;
if (!enumD->lookupConformance(codingKeyProto, codingKeyConformance))
continue;
// Print the decl, but without the body.
Printer.printNewline();
enumD->print(Printer, Options);
// Manually print elements because CodingKey enums have some synthesized
// members for the protocol conformance e.g 'init(intValue:)'.
// We don't want to print them here.
Printer << " {";
Printer.printNewline();
Printer.setIndent(2);
for (auto *elementD : enumD->getAllElements()) {
elementD->print(Printer, Options);
Printer.printNewline();
}
Printer.setIndent(0);
Printer << "}";
Printer.printNewline();
}
// Look for synthesized witness decls and print them.
for (auto *conformance : IDC->getLocalConformances()) {
auto protocol = conformance->getProtocol();
auto kind = protocol->getKnownProtocolKind();
if (kind == KnownProtocolKind::Encodable ||
kind == KnownProtocolKind::Decodable) {
for (auto requirement : protocol->getProtocolRequirements()) {
auto witness = conformance->getWitnessDecl(requirement);
if (witness->isSynthesized()) {
Printer.printNewline();
witness->print(Printer, Options);
Printer.printNewline();
}
}
}
}
}
public:
static AddCodableContext getFromCursorInfo(ResolvedCursorInfoPtr Info);
bool isApplicable() const {
if (!IDC || !getNominal())
return false;
// Check if 'IDC' conforms to 'Encodable' and/or 'Decodable' and any of the
// requirements are synthesized.
for (auto *conformance : IDC->getLocalConformances()) {
auto protocol = conformance->getProtocol();
auto kind = protocol->getKnownProtocolKind();
if (kind == KnownProtocolKind::Encodable ||
kind == KnownProtocolKind::Decodable) {
// Check if any of the protocol requirements are synthesized.
for (auto requirement : protocol->getProtocolRequirements()) {
auto witness = conformance->getWitnessDecl(requirement);
if (!witness || witness->isSynthesized())
return true;
}
}
}
return false;
}
public:
AddCodableContext(NominalTypeDecl *Decl)
: DC(Decl), StartLoc(Decl->getBraces().Start),
Protocols(getAllProtocols(Decl)), Range(Decl->getMembers()){};
AddCodableContext(ExtensionDecl *Decl)
: DC(Decl), StartLoc(Decl->getBraces().Start),
Protocols(getAllProtocols(Decl->getExtendedNominal())),
Range(Decl->getMembers()){};
AddCodableContext() : DC(nullptr), Protocols(), Range(nullptr, nullptr){};
static AddCodableContext
getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info);
void printInsertionText(ResolvedCursorInfoPtr CursorInfo, SourceManager &SM,
llvm::raw_ostream &OS);
bool isValid() { return StartLoc.isValid() && conformsToCodableProtocol(); }
SourceLoc getInsertStartLoc();
};
SourceLoc AddCodableContext::getInsertStartLoc() {
SourceLoc MaxLoc = StartLoc;
for (auto Mem : Range) {
if (Mem->getEndLoc().getOpaquePointerValue() >
MaxLoc.getOpaquePointerValue()) {
MaxLoc = Mem->getEndLoc();
}
}
return MaxLoc;
}
/// Walks an AST and prints the synthesized Codable implementation.
class SynthesizedCodablePrinter : public ASTWalker {
private:
ASTPrinter &Printer;
public:
SynthesizedCodablePrinter(ASTPrinter &Printer) : Printer(Printer) {}
MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}
PreWalkAction walkToDeclPre(Decl *D) override {
auto *VD = dyn_cast<ValueDecl>(D);
if (!VD)
return Action::SkipNode();
if (!VD->isSynthesized()) {
return Action::Continue();
}
SmallString<32> Scratch;
auto name = VD->getName().getString(Scratch);
// Print all synthesized enums,
// since Codable can synthesize multiple enums (for associated values).
auto shouldPrint =
isa<EnumDecl>(VD) || name == "init(from:)" || name == "encode(to:)";
if (!shouldPrint) {
// Some other synthesized decl that we don't want to print.
return Action::SkipNode();
}
Printer.printNewline();
if (auto enumDecl = dyn_cast<EnumDecl>(D)) {
// Manually print enum here, since we don't want to print synthesized
// functions.
Printer << "enum " << enumDecl->getNameStr();
PrintOptions Options;
Options.PrintSpaceBeforeInheritance = false;
enumDecl->printInherited(Printer, Options);
Printer << " {";
for (Decl *EC : enumDecl->getAllElements()) {
Printer.printNewline();
Printer << " ";
EC->print(Printer, Options);
}
Printer.printNewline();
Printer << "}";
return Action::SkipNode();
}
PrintOptions Options;
Options.SynthesizeSugarOnTypes = true;
Options.FunctionDefinitions = true;
Options.VarInitializers = true;
Options.PrintExprs = true;
Options.TypeDefinitions = true;
Options.ExcludeAttrList.push_back(DeclAttrKind::HasInitialValue);
Printer.printNewline();
D->print(Printer, Options);
return Action::SkipNode();
void getInsertion(SourceLoc &insertLoc, std::string &insertText) const {
insertLoc = getInsertStartLoc();
llvm::raw_string_ostream OS(insertText);
printInsertText(OS);
}
};
void AddCodableContext::printInsertionText(ResolvedCursorInfoPtr CursorInfo,
SourceManager &SM,
llvm::raw_ostream &OS) {
StringRef ExtraIndent;
StringRef CurrentIndent =
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
std::string Indent;
if (getInsertStartLoc() == StartLoc) {
Indent = (CurrentIndent + ExtraIndent).str();
} else {
Indent = CurrentIndent.str();
}
ExtraIndentStreamPrinter Printer(OS, Indent);
Printer.printNewline();
SynthesizedCodablePrinter Walker(Printer);
DC->getAsDecl()->walk(Walker);
}
AddCodableContext
AddCodableContext::getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info) {
AddCodableContext::getFromCursorInfo(ResolvedCursorInfoPtr Info) {
auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(Info);
if (!ValueRefInfo) {
return AddCodableContext();
return nullptr;
}
if (!ValueRefInfo->isRef()) {
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD())) {
return AddCodableContext(NomDecl);
if (auto *ext = ValueRefInfo->getExtTyRef()) {
// For 'extension Outer.Inner: Codable {}', only 'Inner' part is valid.
if (ext->getExtendedNominal() == ValueRefInfo->getValueD()) {
return AddCodableContext(ext);
} else {
return nullptr;
}
}
// TODO: support extensions
// (would need to get synthesized nodes from the main decl,
// and only if it's in the same file?)
return AddCodableContext();
if (!ValueRefInfo->isRef()) {
if (auto *nominal = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD())) {
return AddCodableContext(nominal);
}
}
return nullptr;
}
} // namespace
bool RefactoringActionAddExplicitCodableImplementation::isApplicable(
ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
return AddCodableContext::getDeclarationContextFromInfo(Tok).isValid();
return AddCodableContext::getFromCursorInfo(Tok).isApplicable();
}
bool RefactoringActionAddExplicitCodableImplementation::performChange() {
auto Context = AddCodableContext::getDeclarationContextFromInfo(CursorInfo);
auto Context = AddCodableContext::getFromCursorInfo(CursorInfo);
assert(Context.isApplicable() &&
"Should not run performChange when refactoring is not applicable");
SmallString<64> Buffer;
llvm::raw_svector_ostream OS(Buffer);
Context.printInsertionText(CursorInfo, SM, OS);
SourceLoc insertLoc;
std::string insertText;
Context.getInsertion(insertLoc, insertText);
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), OS.str());
EditConsumer.insertAfter(SM, insertLoc, insertText);
return false;
}