From fdfd8db90f93b3559039e89abc7a43f20f4fb3cd Mon Sep 17 00:00:00 2001 From: Stephen Macke Date: Sat, 9 Mar 2024 15:40:00 -0800 Subject: [PATCH] liveness analysis handles calls to functions defined in same cell --- core/ipyflow/analysis/live_refs.py | 51 ++++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/core/ipyflow/analysis/live_refs.py b/core/ipyflow/analysis/live_refs.py index ca29e915..c5f60fea 100644 --- a/core/ipyflow/analysis/live_refs.py +++ b/core/ipyflow/analysis/live_refs.py @@ -1,10 +1,21 @@ # -*- coding: utf-8 -*- import ast import builtins +import itertools import logging import sys from contextlib import contextmanager -from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, + cast, +) from ipyflow.analysis.mixins import ( SaveOffAttributesMixin, @@ -53,6 +64,9 @@ def __init__( self._skip_simple_names = False self._is_lhs = False self._include_killed_live = include_killed_live + self._func_ast_by_name: Dict[ + str, Union[ast.FunctionDef, ast.AsyncFunctionDef] + ] = {} def __call__( self, node: ast.AST @@ -241,10 +255,19 @@ def visit_Assign_target( else: logger.warning("unsupported type for node %s" % target_node) - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + def visit_FunctionDef_or_AsyncFunctionDef( + self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef] + ) -> None: self.generic_visit(node.args.defaults) self.generic_visit(node.decorator_list) self.dead.add(SymbolRef(node).canonical()) + self._func_ast_by_name[node.name] = node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + self.visit_FunctionDef_or_AsyncFunctionDef(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + self.visit_FunctionDef_or_AsyncFunctionDef(node) def visit_withitem(self, node: ast.withitem): self.visit(node.context_expr) @@ -310,6 +333,30 @@ def visit_Call(self, node: ast.Call) -> None: self._add_attrsub_to_live_if_eligible(SymbolRef(node)) with self.attrsub_context(): self.visit(node.func) + if isinstance(node.func, ast.Name) and node.func.id in self._func_ast_by_name: + func_ast = self._func_ast_by_name[node.func.id] + call_scope = None # TODO: figure this out + call_dead = { + SymbolRef.from_string(arg.arg, scope=self._scope) + for arg in itertools.chain( + func_ast.args.args or [], + func_ast.args.kw_defaults or [], + [a for a in [func_ast.args.vararg, func_ast.args.kwarg] if a], + getattr(func_ast.args, "posonlyargs", None) or [], + getattr(func_ast.args, "kwonlyargs", None) or [], + ) + } + call_live: Set[LiveSymbolRef] = set() + func_ast_by_name = dict(self._func_ast_by_name) + with self.push_attributes( + _scope=call_scope, + dead=call_dead, + live=call_live, + modified=set(), + _func_ast_by_name=func_ast_by_name, + ): + self.generic_visit(func_ast.body) + self.live |= call_live def visit_Attribute(self, node: ast.Attribute) -> None: if not self._inside_attrsub: