Skip to content

Commit

Permalink
analysis improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Mar 29, 2024
1 parent 6e1b268 commit 2b34b84
Showing 1 changed file with 17 additions and 36 deletions.
53 changes: 17 additions & 36 deletions core/ipyflow/analysis/live_refs.py
Expand Up @@ -231,22 +231,16 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
def visit_AugAssign(self, node: ast.AugAssign) -> None:
self.visit_Assign_impl([], node.value, aug_assign_target=node.target)

def visit_Import(self, node: ast.Import) -> None:
self.visit_Import_or_ImportFrom(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.visit_Import_or_ImportFrom(node)

def visit_Import_or_ImportFrom(
self, node: Union[ast.Import, ast.ImportFrom]
) -> None:
def visit_import(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
targets = []
for name in node.names:
if name.name == "*":
continue
targets.append(ast.Name(id=name.asname or name.name, ctx=ast.Store()))
self.visit_Assign_impl(targets, value=None)

visit_Import = visit_ImportFrom = visit_import

def visit_Assign_target(
self,
target_node: Union[
Expand All @@ -266,19 +260,15 @@ def visit_Assign_target(
else:
logger.warning("unsupported type for node %s" % target_node)

def visit_FunctionDef_or_AsyncFunctionDef(
def generic_visit_function_def(
self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
self.generic_visit(node.args.defaults)
self.generic_visit(node.decorator_list)
self.dead.add(SymbolRef(node).canonical())
self._func_ast_by_name[node.name] = node

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.visit_FunctionDef_or_AsyncFunctionDef(node)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
self.visit_FunctionDef_or_AsyncFunctionDef(node)
visit_FunctionDef = visit_AsyncFunctionDef = generic_visit_function_def

def visit_withitem(self, node: ast.withitem):
self.visit(node.context_expr)
Expand All @@ -305,34 +295,33 @@ def visit_Name(self, node: ast.Name) -> None:
)
)

def visit_Tuple_or_List(self, node: Union[ast.List, ast.Tuple]) -> None:
def visit_container_literal(
self, node: Union[ast.List, ast.Set, ast.Tuple]
) -> None:
with self.attrsub_context(False):
for elt in node.elts:
self.visit(elt)

def visit_List(self, node: ast.List) -> None:
self.visit_Tuple_or_List(node)

def visit_Tuple(self, node: ast.Tuple) -> None:
self.visit_Tuple_or_List(node)
visit_List = visit_Set = visit_Tuple = visit_container_literal

def visit_Dict(self, node: ast.Dict) -> None:
with self.attrsub_context(False):
self.generic_visit(node.keys)
self.generic_visit(node.values)

def visit_For(self, node: ast.For) -> None:
def generic_visit_for_loop(self, node: Union[ast.For, ast.AsyncFor]) -> None:
# Case "for a,b in something: "
self.visit(node.iter)
with self.kill_context():
self.visit(node.target)
for line in node.body:
self.visit(line)

visit_For = visit_AsyncFor = generic_visit_for_loop

def visit_ClassDef(self, node: ast.ClassDef) -> None:
self.generic_visit(node.bases)
self.generic_visit(node.decorator_list)
self.generic_visit(node.body)
self.dead.add(SymbolRef(node).canonical())

def visit_Call(self, node: ast.Call) -> None:
Expand Down Expand Up @@ -391,19 +380,7 @@ def visit_Subscript(self, node: ast.Subscript) -> None:
def visit_Delete(self, node: ast.Delete) -> None:
pass

def visit_GeneratorExp(self, node) -> None:
self.visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(node)

def visit_DictComp(self, node) -> None:
self.visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(node)

def visit_ListComp(self, node) -> None:
self.visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(node)

def visit_SetComp(self, node) -> None:
self.visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(node)

def visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(self, node) -> None:
def generic_visit_comprehension(self, node) -> None:
with self.killed_context([gen.target for gen in node.generators]):
if isinstance(node, ast.DictComp):
self.visit(node.key)
Expand All @@ -414,6 +391,10 @@ def visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(self, node) -> None:
self.visit(gen.iter)
self.visit(gen.ifs)

visit_DictComp = (
visit_ListComp
) = visit_SetComp = visit_GeneratorExp = generic_visit_comprehension

def visit_Lambda(self, node: ast.Lambda) -> None:
with self.kill_context():
self.visit(node.args)
Expand Down

0 comments on commit 2b34b84

Please sign in to comment.