Skip to content

Commit

Permalink
track code ranges in symbol references
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Feb 6, 2024
1 parent 2a77e63 commit 26f91d4
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 17 deletions.
21 changes: 14 additions & 7 deletions core/ipyflow/analysis/live_refs.py
Expand Up @@ -12,7 +12,7 @@
VisitListsMixin,
)
from ipyflow.analysis.resolved_symbols import ResolvedSymbol
from ipyflow.analysis.symbol_ref import Atom, LiveSymbolRef, SymbolRef
from ipyflow.analysis.symbol_ref import Atom, LiveSymbolRef, SymbolRef, visit_stack
from ipyflow.config import FlowDirection
from ipyflow.data_model.timestamp import Timestamp
from ipyflow.singletons import flow, tracer
Expand Down Expand Up @@ -91,7 +91,7 @@ def attrsub_context(self, inside=True):
return self.push_attributes(_inside_attrsub=inside, _skip_simple_names=inside)

def _add_attrsub_to_live_if_eligible(self, ref: SymbolRef) -> None:
is_killed = ref.nonreactive() in self.dead
is_killed = ref.canonical() in self.dead
if is_killed and not self._include_killed_live:
return
if len(ref.chain) == 0:
Expand Down Expand Up @@ -229,7 +229,7 @@ def visit_Assign_target(
],
) -> None:
if isinstance(target_node, (ast.Name, ast.Attribute, ast.Subscript)):
self.dead.add(SymbolRef(target_node, scope=self._scope).nonreactive())
self.dead.add(SymbolRef(target_node, scope=self._scope).canonical())
if isinstance(target_node, ast.Subscript):
with self.live_context():
self.visit(target_node.slice)
Expand All @@ -244,7 +244,7 @@ def visit_Assign_target(
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.generic_visit(node.args.defaults)
self.generic_visit(node.decorator_list)
self.dead.add(SymbolRef(node).nonreactive())
self.dead.add(SymbolRef(node).canonical())

def visit_withitem(self, node: ast.withitem):
self.visit(node.context_expr)
Expand All @@ -255,7 +255,7 @@ def visit_withitem(self, node: ast.withitem):
def visit_Name(self, node: ast.Name) -> None:
ref = SymbolRef(node, scope=self._scope)
if self._in_kill_context:
self.dead.add(ref.nonreactive())
self.dead.add(ref.canonical())
elif not self._skip_simple_names:
is_killed = ref in self.dead
if is_killed and not self._include_killed_live:
Expand Down Expand Up @@ -299,7 +299,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
self.generic_visit(node.bases)
self.generic_visit(node.decorator_list)
self.generic_visit(node.body)
self.dead.add(SymbolRef(node).nonreactive())
self.dead.add(SymbolRef(node).canonical())

def visit_Call(self, node: ast.Call) -> None:
with self.attrsub_context(False):
Expand Down Expand Up @@ -358,7 +358,7 @@ def visit_Lambda(self, node: ast.Lambda) -> None:
def visit_arg(self, node) -> None:
ref = SymbolRef(node.arg, scope=self._scope)
if self._in_kill_context:
self.dead.add(ref.nonreactive())
self.dead.add(ref.canonical())
elif not self._skip_simple_names:
is_killed = ref in self.dead
if is_killed and not self._include_killed_live:
Expand All @@ -378,6 +378,13 @@ def visit_Module(self, node: ast.Module) -> None:
self.visit(child)
self._module_stmt_counter += 1

def visit(self, node):
    """Visit ``node`` while recording it on the shared ``visit_stack``.

    The node is pushed before dispatching to the parent class's ``visit``
    and popped again afterwards, so during the visit the top of
    ``visit_stack`` is the node currently being processed. The pop happens
    in a ``finally`` clause, so the stack stays balanced even when a
    visitor method raises.

    Returns whatever the parent ``visit`` returns.
    """
    visit_stack.append(node)
    try:
        outcome = super().visit(node)
    finally:
        # Always unwind, whether the visit succeeded or raised.
        visit_stack.pop()
    return outcome


def get_symbols_for_references(
symbol_refs: Iterable[SymbolRef],
Expand Down
43 changes: 37 additions & 6 deletions core/ipyflow/analysis/symbol_ref.py
Expand Up @@ -19,7 +19,7 @@
from ipyflow.singletons import flow, tracer
from ipyflow.types import SubscriptIndices, SupportedIndexType
from ipyflow.utils import CommonEqualityMixin
from ipyflow.utils.ast_utils import subscript_to_slice
from ipyflow.utils.ast_utils import AstRange, subscript_to_slice

if TYPE_CHECKING:
from ipyflow.data_model.symbol import Scope, Symbol
Expand Down Expand Up @@ -312,13 +312,17 @@ def generic_visit(self, node) -> None:
return


visit_stack: List[ast.AST] = []


class SymbolRef:
_cached_symbol_ref_visitor = SymbolRefVisitor()

def __init__(
self,
symbols: Union[ast.AST, Atom, Sequence[Atom]],
scope: Optional["Scope"] = None,
ast_range: Optional[AstRange] = None,
) -> None:
# FIXME: each symbol should distinguish between attribute and subscript
# FIXME: bumped in priority 2021/09/07
Expand All @@ -336,19 +340,25 @@ def __init__(
ast.ImportFrom,
),
):
ast_range = ast_range or AstRange.from_ast_node(
symbols if hasattr(symbols, "lineno") else visit_stack[-1]
)
symbols = self._cached_symbol_ref_visitor(symbols, scope=scope).chain
elif isinstance(symbols, ast.AST): # pragma: no cover
raise TypeError("unexpected type for %s" % symbols)
elif isinstance(symbols, Atom):
symbols = [symbols]
self.chain: Tuple[Atom, ...] = tuple(symbols)
self.scope = scope
self.scope: Optional["Scope"] = scope
self.ast_range: Optional[AstRange] = ast_range

@classmethod
def from_string(
cls, symbol_str: str, scope: Optional["Scope"] = None
) -> "SymbolRef":
return cls(ast.parse(symbol_str, mode="eval").body, scope=scope)
ret = cls(ast.parse(symbol_str, mode="eval").body, scope=scope)
ret.ast_range = None
return ret

def to_symbol(self, scope: Optional["Scope"] = None) -> Optional["Symbol"]:
for resolved in self.gen_resolved_symbols(
Expand All @@ -370,21 +380,42 @@ def resolve(cls, symbol_str: str) -> Optional["Symbol"]:
return cls.from_string(symbol_str).to_symbol()

def __hash__(self) -> int:
# intentionally omit self.scope
return hash(self.chain)

def __eq__(self, other) -> bool:
# intentionally omit self.scope
return isinstance(other, SymbolRef) and self.chain == other.chain
if not isinstance(other, SymbolRef):
return False
if (
self.ast_range is not None
and other.ast_range is not None
and self.ast_range != other.ast_range
):
# goal: equality checks should compare against ast_range when it is set to ensure that
# different ranges get different SymbolRefs in sets and dicts, but containment checks
# that don't set the range (and therefore don't care about it) don't use it.
return False
if (
self.scope is not None
and other.scope is not None
and self.scope is not other.scope
):
# same for scope
return False
return self.chain == other.chain

def __repr__(self) -> str:
    """Return the ``repr`` of the atom chain tuple; scope and range are not shown."""
    return repr(self.chain)

def __str__(self) -> str:
    """Return the same text as ``repr(self)``."""
    return repr(self)

def nonreactive(self) -> "SymbolRef":
def canonical(self) -> "SymbolRef":
return self.__class__(
[atom.nonreactive() for atom in self.chain], scope=self.scope
[atom.nonreactive() for atom in self.chain],
scope=None,
ast_range=None,
)

def gen_resolved_symbols(
Expand Down
2 changes: 1 addition & 1 deletion core/ipyflow/data_model/scope.py
Expand Up @@ -214,7 +214,7 @@ def _compute_is_static_write_for_assign(self, sym: Symbol) -> bool:
return False
try:
return (
SymbolRef(sym.symbol_node, scope=self).nonreactive()
SymbolRef(sym.symbol_node, scope=self).canonical()
in compute_live_dead_symbol_refs(sym.stmt_node, self)[1]
)
except TypeError:
Expand Down
19 changes: 19 additions & 0 deletions core/ipyflow/utils/ast_utils.py
@@ -1,2 +1,21 @@
# -*- coding: utf-8 -*-
import ast
from typing import NamedTuple, Optional

from pyccolo._fast.misc_ast_utils import subscript_to_slice # noqa: F401


class AstRange(NamedTuple):
    """Source span of an AST node: start/end line and start/end column.

    The values are copied verbatim from the node's position attributes
    (``lineno``, ``end_lineno``, ``col_offset``, ``end_col_offset``). The
    two end fields are ``None`` when the node does not carry them.
    """

    lineno: int
    end_lineno: Optional[int]
    col_offset: int
    end_col_offset: Optional[int]

    @classmethod
    def from_ast_node(cls, node: ast.AST) -> "AstRange":
        """Build a range from the position attributes of ``node``.

        Raises:
            AttributeError: if ``node`` has no ``lineno`` or ``col_offset``.
        """
        # Start positions are required; a node without them is an error.
        start_line = node.lineno
        start_col = node.col_offset
        # End positions are optional and default to None when absent.
        end_line = getattr(node, "end_lineno", None)
        end_col = getattr(node, "end_col_offset", None)
        return cls(start_line, end_line, start_col, end_col)
21 changes: 18 additions & 3 deletions core/test/test_liveness_analysis.py
Expand Up @@ -24,17 +24,23 @@ def _simplify_symbol_refs(symbols: Set[SymbolRef]) -> Set[str]:
return simplified


def compute_live_dead_symbol_refs(
def compute_live_dead_symbol_refs_raw(
code: Union[str, ast.AST]
) -> Tuple[Set[str], Set[str]]:
) -> Tuple[Set[SymbolRef], Set[SymbolRef]]:
if isinstance(code, str):
code = textwrap.dedent(code)
live, dead, *_ = compute_live_dead_symbol_refs_with_stmts(code)
live = {ref.ref for ref in live}
live, dead = _simplify_symbol_refs(live), _simplify_symbol_refs(dead)
return live, dead


def compute_live_dead_symbol_refs(
    code: Union[str, ast.AST]
) -> Tuple[Set[str], Set[str]]:
    """Return the live and dead symbol refs of ``code`` in simplified form.

    Thin wrapper over ``compute_live_dead_symbol_refs_raw``: each of the
    two raw sets is passed through ``_simplify_symbol_refs`` before being
    returned as a ``(live, dead)`` pair.
    """
    raw_live, raw_dead = compute_live_dead_symbol_refs_raw(code)
    simplified_live = _simplify_symbol_refs(raw_live)
    simplified_dead = _simplify_symbol_refs(raw_dead)
    return simplified_live, simplified_dead


def test_simple():
live, dead = compute_live_dead_symbol_refs(
"""
Expand Down Expand Up @@ -103,3 +109,12 @@ def test_walrus():
)
assert live == {"x"}
assert dead == {"y", "z"}, "got %s" % dead

def test_positions_simple():
    """A lone name reference carries the source range it was parsed from."""
    live_refs, _dead_refs = compute_live_dead_symbol_refs_raw("foo")
    assert len(live_refs) == 1
    foo_ref = next(iter(live_refs))
    span = foo_ref.ast_range
    # "foo" occupies line 1, columns 0 through 3.
    assert span.lineno == 1
    assert span.end_lineno == 1
    assert span.col_offset == 0
    assert span.end_col_offset == 3

0 comments on commit 26f91d4

Please sign in to comment.