diff --git a/lib/AST/NameLookup.cpp b/lib/AST/NameLookup.cpp index 8b1387da730..424a2c5e131 100644 --- a/lib/AST/NameLookup.cpp +++ b/lib/AST/NameLookup.cpp @@ -17,8 +17,8 @@ #include "swift/AST/NameLookup.h" #include "swift/AST/AST.h" -#include "swift/AST/Diagnostics.h" -#include +#include "swift/AST/ASTVisitor.h" + using namespace swift; MemberLookup::MemberLookup(Type BaseTy, Identifier Name, Module &M) { @@ -229,9 +229,105 @@ Expr *MemberLookup::createResultAST(Expr *Base, SourceLoc DotLoc, NameLoc); } -/// lookupGlobalValue - Perform a value lookup within the current Module. -/// Unlike lookupValue, this does look through import declarations to resolve -/// the name. +struct FindLocalVal : public StmtVisitor { + SourceLoc Loc; + Identifier Name; + ValueDecl *MatchingValue; + + FindLocalVal(SourceLoc Loc, Identifier Name) + : Loc(Loc), Name(Name), MatchingValue(nullptr) {} + + bool IntersectsRange(SourceRange R) { + return R.Start.Value.getPointer() <= Loc.Value.getPointer() && + R.End.Value.getPointer() >= Loc.Value.getPointer(); + } + + void checkValueDecl(ValueDecl *D) { + if (D->getName() == Name) { + assert(!MatchingValue); + MatchingValue = D; + } + } + + void checkPattern(Pattern *Pat) { + switch (Pat->getKind()) { + case PatternKind::Tuple: + for (auto &field : cast(Pat)->getFields()) + checkPattern(field.getPattern()); + return; + case PatternKind::Paren: + return checkPattern(cast(Pat)->getSubPattern()); + case PatternKind::Typed: + return checkPattern(cast(Pat)->getSubPattern()); + case PatternKind::Named: + return checkValueDecl(cast(Pat)->getDecl()); + // Handle non-vars. + case PatternKind::Any: + return; + } + } + + void checkTranslationUnit(TranslationUnit *TU) { + for (Decl *D : TU->Decls) { + if (TopLevelCodeDecl *TLCD = dyn_cast(D)) { + if (Stmt *S = TLCD->getBody().dyn_cast()) + visit(S); + } + } + } + + void visitBreakStmt(BreakStmt *) {} + void visitContinueStmt(ContinueStmt *) {} + void visitSemiStmt(SemiStmt *) {} + void visitErrorStmt(ErrorStmt *) {} + void visitAssignStmt(AssignStmt *) {} + void visitReturnStmt(ReturnStmt *) {} + void visitIfStmt(IfStmt * S) { + visit(S->getThenStmt()); + if (S->getElseStmt()) + visit(S->getElseStmt()); + } + void visitWhileStmt (WhileStmt *S) { + visit(S->getBody()); + } + + void visitForStmt (ForStmt *S) { + if (!IntersectsRange(S->getSourceRange())) + return; + visit(S->getBody()); + if (MatchingValue) + return; + for (Decl *D : S->getInitializerVarDecls()) { + if (ValueDecl *VD = dyn_cast(D)) + checkValueDecl(VD); + } + } + void visitForEachStmt (ForEachStmt *S) { + if (!IntersectsRange(S->getSourceRange())) + return; + visit(S->getBody()); + if (MatchingValue) + return; + checkPattern(S->getPattern()); + } + void visitBraceStmt(BraceStmt *S) { + if (!IntersectsRange(S->getSourceRange())) + return; + for (auto elem : S->getElements()) { + if (Stmt *S = elem.dyn_cast()) + visit(S); + } + if (MatchingValue) + return; + for (auto elem : S->getElements()) { + if (Decl *D = elem.dyn_cast()) { + if (ValueDecl *VD = dyn_cast(D)) + checkValueDecl(VD); + } + } + } +}; + UnqualifiedLookup::UnqualifiedLookup(Identifier Name, DeclContext *DC, SourceLoc Loc) { typedef UnqualifiedLookupResult Result; @@ -248,6 +344,21 @@ UnqualifiedLookup::UnqualifiedLookup(Identifier Name, DeclContext *DC, ValueDecl *BaseDecl = 0; Type ExtendedType; if (FuncExpr *FE = dyn_cast(DC)) { + // Look for local variables; normally, the parser resolves these + // for us, but it can't do the right thing inside local types. + if (Loc.isValid()) { + FindLocalVal localVal(Loc, Name); + localVal.visit(FE->getBody()); + if (!localVal.MatchingValue) { + for (Pattern *P : FE->getParamPatterns()) + localVal.checkPattern(P); + } + if (localVal.MatchingValue) { + Results.push_back(Result::getLocalDecl(localVal.MatchingValue)); + return; + } + } + FuncDecl *FD = FE->getDecl(); if (FD && FD->getExtensionType() && !FD->isStatic()) { ExtendedType = FD->getExtensionType(); @@ -290,6 +401,22 @@ UnqualifiedLookup::UnqualifiedLookup(Identifier Name, DeclContext *DC, DC = DC->getParent(); } + if (Loc.isValid()) { + if (TranslationUnit *TU = dyn_cast(&M)) { + // Look for local variables in top-level code; normally, the parser + // resolves these for us, but it can't do the right thing for + // local types. + // FIXME: Need to suppress this for cases in name binding where we know + // it's a waste of time. + FindLocalVal localVal(Loc, Name); + localVal.checkTranslationUnit(TU); + if (localVal.MatchingValue) { + Results.push_back(Result::getLocalDecl(localVal.MatchingValue)); + return; + } + } + } + // Do a local lookup within the current module. llvm::SmallVector CurModuleResults; M.lookupValue(Module::AccessPathTy(), Name, NLKind::UnqualifiedLookup,