Skip to content

Commit

Permalink
liveness analysis handles calls to functions defined in same cell
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Mar 9, 2024
1 parent 05e1b0e commit fdfd8db
Showing 1 changed file with 49 additions and 2 deletions.
51 changes: 49 additions & 2 deletions core/ipyflow/analysis/live_refs.py
@@ -1,10 +1,21 @@
# -*- coding: utf-8 -*-
import ast
import builtins
import itertools
import logging
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
cast,
)

from ipyflow.analysis.mixins import (
SaveOffAttributesMixin,
Expand Down Expand Up @@ -53,6 +64,9 @@ def __init__(
self._skip_simple_names = False
self._is_lhs = False
self._include_killed_live = include_killed_live
self._func_ast_by_name: Dict[
str, Union[ast.FunctionDef, ast.AsyncFunctionDef]
] = {}

def __call__(
self, node: ast.AST
Expand Down Expand Up @@ -241,10 +255,19 @@ def visit_Assign_target(
else:
logger.warning("unsupported type for node %s" % target_node)

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
def visit_FunctionDef_or_AsyncFunctionDef(
self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
self.generic_visit(node.args.defaults)
self.generic_visit(node.decorator_list)
self.dead.add(SymbolRef(node).canonical())
self._func_ast_by_name[node.name] = node

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.visit_FunctionDef_or_AsyncFunctionDef(node)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
self.visit_FunctionDef_or_AsyncFunctionDef(node)

def visit_withitem(self, node: ast.withitem):
self.visit(node.context_expr)
Expand Down Expand Up @@ -310,6 +333,30 @@ def visit_Call(self, node: ast.Call) -> None:
self._add_attrsub_to_live_if_eligible(SymbolRef(node))
with self.attrsub_context():
self.visit(node.func)
if isinstance(node.func, ast.Name) and node.func.id in self._func_ast_by_name:
func_ast = self._func_ast_by_name[node.func.id]
call_scope = None # TODO: figure this out
call_dead = {
SymbolRef.from_string(arg.arg, scope=self._scope)
for arg in itertools.chain(
func_ast.args.args or [],
func_ast.args.kw_defaults or [],
[a for a in [func_ast.args.vararg, func_ast.args.kwarg] if a],
getattr(func_ast.args, "posonlyargs", None) or [],
getattr(func_ast.args, "kwonlyargs", None) or [],
)
}
call_live: Set[LiveSymbolRef] = set()
func_ast_by_name = dict(self._func_ast_by_name)
with self.push_attributes(
_scope=call_scope,
dead=call_dead,
live=call_live,
modified=set(),
_func_ast_by_name=func_ast_by_name,
):
self.generic_visit(func_ast.body)
self.live |= call_live

def visit_Attribute(self, node: ast.Attribute) -> None:
if not self._inside_attrsub:
Expand Down

0 comments on commit fdfd8db

Please sign in to comment.