Skip to content

Commit

Permalink
add nested virtual symbol to track mutations without bumping containi…
Browse files Browse the repository at this point in the history
…ng symbol shallow timestamp
  • Loading branch information
smacke committed Jan 26, 2024
1 parent 8afbde8 commit e54c3e6
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 6 deletions.
34 changes: 32 additions & 2 deletions core/ipyflow/data_model/symbol.py
Expand Up @@ -100,6 +100,8 @@ class Symbol:

IMMUTABLE_TYPES = set(IMMUTABLE_PRIMITIVE_TYPES)

IPYFLOW_MUTATION_VIRTUAL_SYMBOL_NAME = "__ipyflow_mutation"

def __init__(
self,
name: SupportedIndexType,
Expand Down Expand Up @@ -505,6 +507,18 @@ def obj_len(self) -> Optional[int]:
def obj_type(self) -> Type[Any]:
return type(self.obj)

@property
def is_immutable(self) -> bool:
return self.obj_type in self.IMMUTABLE_TYPES

@property
def is_mutation_virtual_symbol(self) -> bool:
return self.name == self.IPYFLOW_MUTATION_VIRTUAL_SYMBOL_NAME

@property
def is_underscore(self) -> bool:
return self.name == "_" and self.containing_scope.is_global

@property
def is_obj_lazy_module(self) -> bool:
return self.obj_type is _LazyModule
Expand Down Expand Up @@ -905,7 +919,7 @@ def should_mark_waiting(self, updated_dep):
return True

def _is_underscore_or_simple_assign(self, new_deps: Set["Symbol"]) -> bool:
if self.name == "_":
if self.is_underscore:
# FIXME: distinguish between explicit assignment to _ from user and implicit assignment from kernel
return True
if not isinstance(self.stmt_node, (ast.Assign, ast.AnnAssign)):
Expand Down Expand Up @@ -934,7 +948,7 @@ def update_deps(
return
if overwrite and not self.is_globally_accessible:
self.watchpoints.clear()
if mutated and self.obj_type in self.IMMUTABLE_TYPES:
if mutated and self.is_immutable:
return
# if we get here, no longer implicit
self._implicit = False
Expand All @@ -961,6 +975,22 @@ def update_deps(
self.fresher_ancestor_timestamps.clear()
if mutated or isinstance(self.stmt_node, ast.AugAssign):
self.update_usage_info()
if (
(mutated or overwrite)
and Timestamp.current().is_initialized
and not self.is_immutable
and not self.is_mutation_virtual_symbol
and not self.is_anonymous
and self.containing_scope.is_global
and not self.is_underscore
and not self.is_implicit
and self.obj_type is not type
and not self.is_class
and self.namespace is not None
):
self.namespace.upsert_symbol_for_name(
self.IPYFLOW_MUTATION_VIRTUAL_SYMBOL_NAME, object(), propagate=False
)
propagate = propagate and (
mutated or deleted or not self._should_cancel_propagation(prev_obj)
)
Expand Down
2 changes: 1 addition & 1 deletion core/ipyflow/data_model/utils/update_protocol.py
Expand Up @@ -174,7 +174,7 @@ def _propagate_waiting_to_namespace_children(
if not skip_seen_check and sym in self.seen:
return
self.seen.add(sym)
self_ns = flow().namespaces.get(sym.obj_id, None)
self_ns = flow().namespaces.get(sym.obj_id)
if self_ns is None:
return
for ns_child in self_ns.all_symbols_this_indentation(exclude_class=True):
Expand Down
2 changes: 1 addition & 1 deletion core/test/test_nested_symbols.py
Expand Up @@ -128,7 +128,7 @@ def test_basic():
assert_detected("`y` depends on changed `d.y`")


def test_nested_readable_name():
def test_nested_readable_name_attr():
run_cell("d = DotDict()")
run_cell("d.x = DotDict()")
run_cell("d.x.a = 5")
Expand Down
2 changes: 1 addition & 1 deletion core/test/test_reactivity.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import logging
import sys
from test.utils import make_flow_fixture, skipif_known_failing
from test.utils import make_flow_fixture
from typing import Optional, Set, Tuple

from ipyflow.config import ExecutionMode
Expand Down
8 changes: 7 additions & 1 deletion core/test/test_updated_symbols.py
Expand Up @@ -2,6 +2,7 @@
import logging
from test.utils import make_flow_fixture

from ipyflow.data_model.symbol import Symbol
from ipyflow.singletons import flow

logging.basicConfig(level=logging.ERROR)
Expand Down Expand Up @@ -36,7 +37,12 @@ def test_simplest():
def test_dict_hierarchy():
run_cell("d = {}")
updated_sym_names = updated_symbol_names()
assert updated_sym_names == ["d"], "got %s" % updated_sym_names
assert updated_sym_names == [
"d",
f"d.{Symbol.IPYFLOW_MUTATION_VIRTUAL_SYMBOL_NAME}",
], (
"got %s" % updated_sym_names
)
run_cell('d["foo"] = {}')
assert updated_symbol_names() == sorted(["d[foo]", "d"])
run_cell('d["foo"]["bar"] = []')
Expand Down

0 comments on commit e54c3e6

Please sign in to comment.