From 2b34b84abb217e97dcc5bfccb92e7bf68f419a5b Mon Sep 17 00:00:00 2001 From: Stephen Macke Date: Fri, 29 Mar 2024 15:12:27 -0700 Subject: [PATCH] analysis improvements --- core/ipyflow/analysis/live_refs.py | 53 ++++++++++-------------------- 1 file changed, 17 insertions(+), 36 deletions(-) diff --git a/core/ipyflow/analysis/live_refs.py b/core/ipyflow/analysis/live_refs.py index 86fd2410..91340713 100644 --- a/core/ipyflow/analysis/live_refs.py +++ b/core/ipyflow/analysis/live_refs.py @@ -231,15 +231,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None: def visit_AugAssign(self, node: ast.AugAssign) -> None: self.visit_Assign_impl([], node.value, aug_assign_target=node.target) - def visit_Import(self, node: ast.Import) -> None: - self.visit_Import_or_ImportFrom(node) - - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: - self.visit_Import_or_ImportFrom(node) - - def visit_Import_or_ImportFrom( - self, node: Union[ast.Import, ast.ImportFrom] - ) -> None: + def visit_import(self, node: Union[ast.Import, ast.ImportFrom]) -> None: targets = [] for name in node.names: if name.name == "*": @@ -247,6 +239,8 @@ def visit_Import_or_ImportFrom( targets.append(ast.Name(id=name.asname or name.name, ctx=ast.Store())) self.visit_Assign_impl(targets, value=None) + visit_Import = visit_ImportFrom = visit_import + def visit_Assign_target( self, target_node: Union[ @@ -266,7 +260,7 @@ def visit_Assign_target( else: logger.warning("unsupported type for node %s" % target_node) - def visit_FunctionDef_or_AsyncFunctionDef( + def generic_visit_function_def( self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef] ) -> None: self.generic_visit(node.args.defaults) @@ -274,11 +268,7 @@ def visit_FunctionDef_or_AsyncFunctionDef( self.dead.add(SymbolRef(node).canonical()) self._func_ast_by_name[node.name] = node - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - self.visit_FunctionDef_or_AsyncFunctionDef(node) - - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: - self.visit_FunctionDef_or_AsyncFunctionDef(node) + visit_FunctionDef = visit_AsyncFunctionDef = generic_visit_function_def def visit_withitem(self, node: ast.withitem): self.visit(node.context_expr) @@ -305,23 +295,21 @@ def visit_Name(self, node: ast.Name) -> None: ) ) - def visit_Tuple_or_List(self, node: Union[ast.List, ast.Tuple]) -> None: + def visit_container_literal( + self, node: Union[ast.List, ast.Set, ast.Tuple] + ) -> None: with self.attrsub_context(False): for elt in node.elts: self.visit(elt) - def visit_List(self, node: ast.List) -> None: - self.visit_Tuple_or_List(node) - - def visit_Tuple(self, node: ast.Tuple) -> None: - self.visit_Tuple_or_List(node) + visit_List = visit_Set = visit_Tuple = visit_container_literal def visit_Dict(self, node: ast.Dict) -> None: with self.attrsub_context(False): self.generic_visit(node.keys) self.generic_visit(node.values) - def visit_For(self, node: ast.For) -> None: + def generic_visit_for_loop(self, node: Union[ast.For, ast.AsyncFor]) -> None: # Case "for a,b in something: " self.visit(node.iter) with self.kill_context(): @@ -329,10 +317,11 @@ def visit_For(self, node: ast.For) -> None: for line in node.body: self.visit(line) + visit_For = visit_AsyncFor = generic_visit_for_loop + def visit_ClassDef(self, node: ast.ClassDef) -> None: self.generic_visit(node.bases) self.generic_visit(node.decorator_list) - self.generic_visit(node.body) self.dead.add(SymbolRef(node).canonical()) def visit_Call(self, node: ast.Call) -> None: @@ -391,19 +380,7 @@ def visit_Subscript(self, node: ast.Subscript) -> None: def visit_Delete(self, node: ast.Delete) -> None: pass - def visit_GeneratorExp(self, node) -> None: - self.visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(node) - - def visit_DictComp(self, node) -> None: - self.visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(node) - - def visit_ListComp(self, node) -> None: - self.visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(node) - - def visit_SetComp(self, node) -> None: - self.visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(node) - - def visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(self, node) -> None: + def generic_visit_comprehension(self, node) -> None: with self.killed_context([gen.target for gen in node.generators]): if isinstance(node, ast.DictComp): self.visit(node.key) @@ -414,6 +391,10 @@ def visit_GeneratorExp_or_DictComp_or_ListComp_or_SetComp(self, node) -> None: self.visit(gen.iter) self.visit(gen.ifs) + visit_DictComp = ( + visit_ListComp + ) = visit_SetComp = visit_GeneratorExp = generic_visit_comprehension + def visit_Lambda(self, node: ast.Lambda) -> None: with self.kill_context(): self.visit(node.args)