mirror of
https://github.com/apple/swift.git
synced 2025-12-21 12:14:44 +01:00
Merge pull request #29596 from Regno/feature/vlasov/SR-5740
[Source Tooling] Refactoring action to convert if statement to switch
This commit is contained in:
@@ -2243,6 +2243,272 @@ bool RefactoringActionConvertGuardExprToIfLetExpr::performChange() {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RefactoringActionConvertToSwitchStmt::
|
||||
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {
|
||||
|
||||
class ConditionalChecker : public ASTWalker {
|
||||
public:
|
||||
bool ParamsUseSameVars = true;
|
||||
bool ConditionUseOnlyAllowedFunctions = false;
|
||||
StringRef ExpectName;
|
||||
|
||||
Expr *walkToExprPost(Expr *E) {
|
||||
if (E->getKind() != ExprKind::DeclRef)
|
||||
return E;
|
||||
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
|
||||
if (D->getKind() == DeclKind::Var || D->getKind() == DeclKind::Param)
|
||||
ParamsUseSameVars = checkName(dyn_cast<VarDecl>(D));
|
||||
if (D->getKind() == DeclKind::Func)
|
||||
ConditionUseOnlyAllowedFunctions = checkName(dyn_cast<FuncDecl>(D));
|
||||
if (allCheckPassed())
|
||||
return E;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool allCheckPassed() {
|
||||
return ParamsUseSameVars && ConditionUseOnlyAllowedFunctions;
|
||||
}
|
||||
|
||||
private:
|
||||
bool checkName(VarDecl *VD) {
|
||||
auto Name = VD->getName().str();
|
||||
if (ExpectName.empty())
|
||||
ExpectName = Name;
|
||||
return Name == ExpectName;
|
||||
}
|
||||
|
||||
bool checkName(FuncDecl *FD) {
|
||||
auto Name = FD->getName().str();
|
||||
return Name == "~="
|
||||
|| Name == "=="
|
||||
|| Name == "__derived_enum_equals"
|
||||
|| Name == "__derived_struct_equals"
|
||||
|| Name == "||"
|
||||
|| Name == "...";
|
||||
}
|
||||
};
|
||||
|
||||
class SwitchConvertable {
|
||||
public:
|
||||
SwitchConvertable(ResolvedRangeInfo Info) {
|
||||
this->Info = Info;
|
||||
}
|
||||
|
||||
bool isApplicable() {
|
||||
if (Info.Kind != RangeKind::SingleStatement)
|
||||
return false;
|
||||
if (!findIfStmt())
|
||||
return false;
|
||||
return checkEachCondition();
|
||||
}
|
||||
|
||||
private:
|
||||
ResolvedRangeInfo Info;
|
||||
IfStmt *If = nullptr;
|
||||
ConditionalChecker checker;
|
||||
|
||||
bool findIfStmt() {
|
||||
if (Info.ContainedNodes.size() != 1)
|
||||
return false;
|
||||
if (auto S = Info.ContainedNodes.front().dyn_cast<Stmt*>())
|
||||
If = dyn_cast<IfStmt>(S);
|
||||
return If != nullptr;
|
||||
}
|
||||
|
||||
bool checkEachCondition() {
|
||||
checker = ConditionalChecker();
|
||||
do {
|
||||
if (!checkEachElement())
|
||||
return false;
|
||||
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool checkEachElement() {
|
||||
bool result = true;
|
||||
auto ConditionalList = If->getCond();
|
||||
for (auto Element : ConditionalList) {
|
||||
result &= check(Element);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool check(StmtConditionElement ConditionElement) {
|
||||
if (ConditionElement.getKind() == StmtConditionElement::CK_Availability)
|
||||
return false;
|
||||
if (ConditionElement.getKind() == StmtConditionElement::CK_PatternBinding)
|
||||
checker.ConditionUseOnlyAllowedFunctions = true;
|
||||
ConditionElement.walk(checker);
|
||||
return checker.allCheckPassed();
|
||||
}
|
||||
};
|
||||
return SwitchConvertable(Info).isApplicable();
|
||||
}
|
||||
|
||||
bool RefactoringActionConvertToSwitchStmt::performChange() {
|
||||
|
||||
class VarNameFinder : public ASTWalker {
|
||||
public:
|
||||
std::string VarName;
|
||||
|
||||
Expr *walkToExprPost(Expr *E) {
|
||||
if (E->getKind() != ExprKind::DeclRef)
|
||||
return E;
|
||||
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
|
||||
if (D->getKind() != DeclKind::Var && D->getKind() != DeclKind::Param)
|
||||
return E;
|
||||
VarName = dyn_cast<VarDecl>(D)->getName().str().str();
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
class ConditionalPatternFinder : public ASTWalker {
|
||||
public:
|
||||
ConditionalPatternFinder(SourceManager &SM) : SM(SM) {}
|
||||
|
||||
SmallString<64> ConditionalPattern = SmallString<64>();
|
||||
|
||||
Expr *walkToExprPost(Expr *E) {
|
||||
if (E->getKind() != ExprKind::Binary)
|
||||
return E;
|
||||
auto BE = dyn_cast<BinaryExpr>(E);
|
||||
if (isFunctionNameAllowed(BE))
|
||||
appendPattern(dyn_cast<BinaryExpr>(E)->getArg());
|
||||
return E;
|
||||
}
|
||||
|
||||
std::pair<bool, Pattern*> walkToPatternPre(Pattern *P) {
|
||||
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, P->getSourceRange()).str());
|
||||
if (P->getKind() == PatternKind::OptionalSome)
|
||||
ConditionalPattern.append("?");
|
||||
return { true, nullptr };
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
SourceManager &SM;
|
||||
|
||||
bool isFunctionNameAllowed(BinaryExpr *E) {
|
||||
auto FunctionBody = dyn_cast<DotSyntaxCallExpr>(E->getFn())->getFn();
|
||||
auto FunctionDeclaration = dyn_cast<DeclRefExpr>(FunctionBody)->getDecl();
|
||||
auto FunctionName = dyn_cast<FuncDecl>(FunctionDeclaration)->getName().str();
|
||||
return FunctionName == "~="
|
||||
|| FunctionName == "=="
|
||||
|| FunctionName == "__derived_enum_equals"
|
||||
|| FunctionName == "__derived_struct_equals";
|
||||
}
|
||||
|
||||
void appendPattern(TupleExpr *Tuple) {
|
||||
auto PatternArgument = Tuple->getElements().back();
|
||||
if (PatternArgument->getKind() == ExprKind::DeclRef)
|
||||
PatternArgument = Tuple->getElements().front();
|
||||
if (ConditionalPattern.size() > 0)
|
||||
ConditionalPattern.append(", ");
|
||||
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, PatternArgument->getSourceRange()).str());
|
||||
}
|
||||
};
|
||||
|
||||
class ConverterToSwitch {
|
||||
public:
|
||||
ConverterToSwitch(ResolvedRangeInfo Info, SourceManager &SM) : SM(SM) {
|
||||
this->Info = Info;
|
||||
}
|
||||
|
||||
void performConvert(SmallString<64> &Out) {
|
||||
If = findIf();
|
||||
OptionalLabel = If->getLabelInfo().Name.str().str();
|
||||
ControlExpression = findControlExpression();
|
||||
findPatternsAndBodies(PatternsAndBodies);
|
||||
DefaultStatements = findDefaultStatements();
|
||||
makeSwitchStatement(Out);
|
||||
}
|
||||
|
||||
private:
|
||||
ResolvedRangeInfo Info;
|
||||
SourceManager &SM;
|
||||
|
||||
IfStmt *If;
|
||||
IfStmt *PreviousIf;
|
||||
|
||||
std::string OptionalLabel;
|
||||
std::string ControlExpression;
|
||||
SmallVector<std::pair<std::string, std::string>, 16> PatternsAndBodies;
|
||||
std::string DefaultStatements;
|
||||
|
||||
IfStmt *findIf() {
|
||||
auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>();
|
||||
return dyn_cast<IfStmt>(S);
|
||||
}
|
||||
|
||||
std::string findControlExpression() {
|
||||
auto ConditionElement = If->getCond().front();
|
||||
auto Finder = VarNameFinder();
|
||||
ConditionElement.walk(Finder);
|
||||
return Finder.VarName;
|
||||
}
|
||||
|
||||
void findPatternsAndBodies(SmallVectorImpl<std::pair<std::string, std::string>> &Out) {
|
||||
do {
|
||||
auto pattern = findPattern();
|
||||
auto body = findBodyStatements();
|
||||
Out.push_back(std::make_pair(pattern, body));
|
||||
PreviousIf = If;
|
||||
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
|
||||
}
|
||||
|
||||
std::string findPattern() {
|
||||
auto ConditionElement = If->getCond().front();
|
||||
auto Finder = ConditionalPatternFinder(SM);
|
||||
ConditionElement.walk(Finder);
|
||||
return Finder.ConditionalPattern.str().str();
|
||||
}
|
||||
|
||||
std::string findBodyStatements() {
|
||||
return findBodyWithoutBraces(If->getThenStmt());
|
||||
}
|
||||
|
||||
std::string findDefaultStatements() {
|
||||
auto ElseBody = dyn_cast_or_null<BraceStmt>(PreviousIf->getElseStmt());
|
||||
if (!ElseBody)
|
||||
return getTokenText(tok::kw_break);
|
||||
return findBodyWithoutBraces(ElseBody);
|
||||
}
|
||||
|
||||
std::string findBodyWithoutBraces(Stmt *body) {
|
||||
auto BS = dyn_cast<BraceStmt>(body);
|
||||
if (!BS)
|
||||
return Lexer::getCharSourceRangeFromSourceRange(SM, body->getSourceRange()).str().str();
|
||||
if (BS->getElements().empty())
|
||||
return getTokenText(tok::kw_break);
|
||||
SourceRange BodyRange = BS->getElements().front().getSourceRange();
|
||||
BodyRange.widen(BS->getElements().back().getSourceRange());
|
||||
return Lexer::getCharSourceRangeFromSourceRange(SM, BodyRange).str().str();
|
||||
}
|
||||
|
||||
void makeSwitchStatement(SmallString<64> &Out) {
|
||||
StringRef Space = " ";
|
||||
StringRef NewLine = "\n";
|
||||
llvm::raw_svector_ostream OS(Out);
|
||||
if (OptionalLabel.size() > 0)
|
||||
OS << OptionalLabel << ":" << Space;
|
||||
OS << tok::kw_switch << Space << ControlExpression << Space << tok::l_brace << NewLine;
|
||||
for (auto &pair : PatternsAndBodies) {
|
||||
OS << tok::kw_case << Space << pair.first << tok::colon << NewLine;
|
||||
OS << pair.second << NewLine;
|
||||
}
|
||||
OS << tok::kw_default << tok::colon << NewLine;
|
||||
OS << DefaultStatements << NewLine;
|
||||
OS << tok::r_brace;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
SmallString<64> result;
|
||||
ConverterToSwitch(RangeInfo, SM).performConvert(result);
|
||||
EditConsumer.accept(SM, RangeInfo.ContentRange, result.str());
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Struct containing info about an IfStmt that can be converted into an IfExpr.
|
||||
struct ConvertToTernaryExprInfo {
|
||||
ConvertToTernaryExprInfo() {}
|
||||
|
||||
Reference in New Issue
Block a user