Skip to content

Commit

Permalink
make parent / children getters raw and add API helpers for querying t…
Browse files Browse the repository at this point in the history
…hem from the notebook
  • Loading branch information
smacke committed Jan 27, 2024
1 parent e54c3e6 commit 24b996e
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 80 deletions.
32 changes: 16 additions & 16 deletions core/ipyflow/data_model/cell.py
Expand Up @@ -109,10 +109,10 @@ def __init__(
self.override_live_refs: Optional[List[str]] = None
self.override_dead_refs: Optional[List[str]] = None
self.reactive_tags: Set[str] = set()
self.dynamic_parents: Dict[IdType, Set["Symbol"]] = {}
self.dynamic_children: Dict[IdType, Set["Symbol"]] = {}
self.static_parents: Dict[IdType, Set["Symbol"]] = {}
self.static_children: Dict[IdType, Set["Symbol"]] = {}
self.raw_dynamic_parents: Dict[IdType, Set["Symbol"]] = {}
self.raw_dynamic_children: Dict[IdType, Set["Symbol"]] = {}
self.raw_static_parents: Dict[IdType, Set["Symbol"]] = {}
self.raw_static_children: Dict[IdType, Set["Symbol"]] = {}
self.used_symbols: Set["Symbol"] = set()
self.static_removed_symbols: Set["Symbol"] = set()
self.static_writes: Set["Symbol"] = set()
Expand Down Expand Up @@ -184,7 +184,7 @@ def position(self) -> int:
@property
def directional_parents(self) -> Mapping[IdType, FrozenSet["Symbol"]]:
# trick to catch some mutations at typecheck time w/out runtime overhead
parents = self.parents
parents = self.raw_parents
if flow().mut_settings.flow_order == FlowDirection.IN_ORDER:
parents = {
sid: syms
Expand All @@ -195,7 +195,7 @@ def directional_parents(self) -> Mapping[IdType, FrozenSet["Symbol"]]:

@property
def directional_children(self) -> Mapping[IdType, FrozenSet["Symbol"]]:
children = self.children
children = self.raw_children
if flow().mut_settings.flow_order == FlowDirection.IN_ORDER:
children = {
cell_id: syms
Expand Down Expand Up @@ -251,7 +251,7 @@ def __str__(self):
return self.executed_content

def __repr__(self):
return f"<{self.__class__.__name__}[id={self.cell_id},ctr={self.cell_ctr}]>"
return f"<{self.__class__.__name__}[ctr={self.cell_ctr},id={self.cell_id}]>"

def __hash__(self):
return hash((self.cell_id, self.cell_ctr))
Expand All @@ -276,17 +276,17 @@ def update_id(self, new_id: IdType, update_edges: bool = True) -> None:
reactive_cells.discard(old_id)
reactive_cells.add(new_id)
for _ in flow().mut_settings.iter_slicing_contexts():
for pid in self.parents.keys():
for pid in self.raw_parents.keys():
parent = self.from_id(pid)
parent.children = {
parent.raw_children = {
(new_id if cid == old_id else cid): syms
for cid, syms in parent.children.items()
for cid, syms in parent.raw_children.items()
}
for cid in self.children.keys():
for cid in self.raw_children.keys():
child = self.from_id(cid)
child.parents = {
child.raw_parents = {
(new_id if pid == old_id else pid): syms
for pid, syms in child.parents.items()
for pid, syms in child.raw_parents.items()
}

def add_used_cell_counter(self, sym: "Symbol", ctr: int) -> None:
Expand Down Expand Up @@ -319,7 +319,7 @@ def _maybe_memoize_params(self) -> None:
return
inputs: Dict["Symbol", MemoizedInput] = {}
for _ in flow().mut_settings.iter_slicing_contexts():
for edges in self.parents.values():
for edges in self.raw_parents.values():
for sym in edges:
if sym in inputs:
continue
Expand Down Expand Up @@ -397,8 +397,8 @@ def create_and_track(
)
if prev_cell is not None:
cell.history = prev_cell.history + cell.history
cell.static_children = prev_cell.static_children
cell.dynamic_children = prev_cell.dynamic_children
cell.raw_static_children = prev_cell.raw_static_children
cell.raw_dynamic_children = prev_cell.raw_dynamic_children
for tag in prev_cell.tags:
cls._cells_by_tag[tag].discard(prev_cell)
for tag in prev_cell.reactive_tags:
Expand Down
15 changes: 9 additions & 6 deletions core/ipyflow/data_model/statement.py
Expand Up @@ -67,10 +67,10 @@ def __init__(
self.lambda_call_point_deps_done_once = False
self.node_id_for_last_call: Optional[int] = None
self._stmt_contains_cascading_reactive_rval: Optional[bool] = None
self.dynamic_parents: Dict[IdType, Set[Symbol]] = {}
self.dynamic_children: Dict[IdType, Set[Symbol]] = {}
self.static_parents: Dict[IdType, Set[Symbol]] = {}
self.static_children: Dict[IdType, Set[Symbol]] = {}
self.raw_dynamic_parents: Dict[IdType, Set[Symbol]] = {}
self.raw_dynamic_children: Dict[IdType, Set[Symbol]] = {}
self.raw_static_parents: Dict[IdType, Set[Symbol]] = {}
self.raw_static_children: Dict[IdType, Set[Symbol]] = {}

@classmethod
def current(cls) -> "Statement":
Expand Down Expand Up @@ -139,9 +139,9 @@ def create_and_track(
cls._stmts_by_ts[stmt.timestamp] = [stmt]
cls._stmts_by_id[stmt.stmt_id] = [stmt]
for _ in SlicingContext.iter_slicing_contexts():
for cid in list(prev.children.keys()):
for cid in list(prev.raw_children.keys()):
cls.from_id(cid).replace_parent_edges(prev, stmt)
for pid in list(prev.parents.keys()):
for pid in list(prev.raw_parents.keys()):
cls.from_id(pid).replace_child_edges(prev, stmt)
else:
cls._stmts_by_ts.setdefault(stmt.timestamp, []).append(stmt)
Expand Down Expand Up @@ -216,6 +216,9 @@ def __str__(self):
def __repr__(self):
return f"<{self.__class__.__name__}[ts={self.timestamp},text={repr(self.text[:self._TEXT_REPR_MAX_LENGTH])}]>"

def __hash__(self):
return hash(self.stmt_node)

def slice(
self,
blacken: bool = True,
Expand Down
6 changes: 3 additions & 3 deletions core/ipyflow/flow.py
Expand Up @@ -1014,17 +1014,17 @@ def _resync_symbols(self, symbols: Iterable[Symbol]):
def _remove_dangling_parent_edges(self, dangling: Set[Symbol]) -> None:
for _ in SlicingContext.iter_slicing_contexts():
for cell in cells().iterate_over_notebook_in_counter_order():
for pid in list(cell.parents.keys()):
for pid in list(cell.raw_parents.keys()):
cell.remove_parent_edges(pid, dangling)
cell = cells().at_counter(self.cell_counter())
prev_cell = cell.prev_cell
if prev_cell is None:
return
for _ in SlicingContext.iter_slicing_contexts():
for prev_pid, sym_edges in list(prev_cell.parents.items()):
for prev_pid, sym_edges in list(prev_cell.raw_parents.items()):
# remove anything not in the current parent set
cell.remove_parent_edges(
prev_pid, sym_edges - cell.parents.get(prev_pid, set())
prev_pid, sym_edges - cell.raw_parents.get(prev_pid, set())
)

@property
Expand Down
2 changes: 1 addition & 1 deletion core/ipyflow/frontend.py
Expand Up @@ -261,7 +261,7 @@ def _compute_stale_parent_makers(self) -> None:
cells_so_far_that_update_symbol: Dict[Symbol, Set[Cell]] = {}
for cell in cells().iterate_over_notebook_in_position_order():
for _ in flow_.mut_settings.iter_slicing_contexts():
for pid, syms in cell.parents.items():
for pid, syms in cell.raw_parents.items():
for sym in syms:
for executed_cell in cells_so_far_that_update_symbol.get(
sym, []
Expand Down
12 changes: 6 additions & 6 deletions core/ipyflow/shell/interactiveshell.py
Expand Up @@ -420,8 +420,8 @@ async def _ipyflow_run_cell(
if should_trace:
self.after_run_cell(raw_cell)
elif cell.prev_cell is not None:
cell.static_parents = cell.prev_cell.static_parents
cell.dynamic_parents = cell.prev_cell.dynamic_parents
cell.raw_static_parents = cell.prev_cell.raw_static_parents
cell.raw_dynamic_parents = cell.prev_cell.raw_dynamic_parents
except Exception as e:
if settings.is_dev_mode:
logger.exception("exception occurred")
Expand Down Expand Up @@ -626,14 +626,14 @@ def _handle_memoization(self) -> None:
prev_cell = Cell.at_counter(cell.skipped_due_to_memoization_ctr)
assert prev_cell is not None
for _ in singletons.flow().mut_settings.iter_slicing_contexts():
for parent, syms in list(cell.parents.items()):
for parent, syms in list(cell.raw_parents.items()):
cell.remove_parent_edges(parent, syms)
for parent, syms in prev_cell.parents.items():
for parent, syms in prev_cell.raw_parents.items():
cell.add_parent_edges(parent, syms)
for stmt, prev_stmt in zip(cell.statements(), prev_cell.statements()):
for parent, syms in list(stmt.parents.items()):
for parent, syms in list(stmt.raw_parents.items()):
stmt.remove_parent_edges(parent, syms)
for parent, syms in prev_stmt.parents.items():
for parent, syms in prev_stmt.raw_parents.items():
stmt.add_parent_edges(parent, syms)
elif cell.is_memoized:
cell._maybe_memoize_params()
Expand Down
102 changes: 67 additions & 35 deletions core/ipyflow/slicing/mixin.py
Expand Up @@ -196,10 +196,42 @@ class SliceableMixin(Protocol):
#############
# subclasses must implement the following:

dynamic_parents: Dict[IdType, Set["Symbol"]]
dynamic_children: Dict[IdType, Set["Symbol"]]
static_parents: Dict[IdType, Set["Symbol"]]
static_children: Dict[IdType, Set["Symbol"]]
raw_dynamic_parents: Dict[IdType, Set["Symbol"]]
raw_dynamic_children: Dict[IdType, Set["Symbol"]]
raw_static_parents: Dict[IdType, Set["Symbol"]]
raw_static_children: Dict[IdType, Set["Symbol"]]

@property
def dynamic_parents(self) -> Dict["SliceableMixin", Set["Symbol"]]:
return {
self.from_id(pid): syms for pid, syms in self.raw_dynamic_parents.items()
}

@property
def dynamic_children(self) -> Dict["SliceableMixin", Set["Symbol"]]:
return {
self.from_id(cid): syms for cid, syms in self.raw_dynamic_children.items()
}

@property
def static_parents(self) -> Dict["SliceableMixin", Set["Symbol"]]:
return {
self.from_id(pid): syms for pid, syms in self.raw_static_parents.items()
}

@property
def static_children(self) -> Dict["SliceableMixin", Set["Symbol"]]:
return {
self.from_id(cid): syms for cid, syms in self.raw_static_children.items()
}

@property
def parents(self) -> Dict["SliceableMixin", Set["Symbol"]]:
return {self.from_id(pid): syms for pid, syms in self.raw_parents.items()}

@property
def children(self) -> Dict["SliceableMixin", Set["Symbol"]]:
return {self.from_id(cid): syms for cid, syms in self.raw_children.items()}

@classmethod
def current(cls) -> "SliceableMixin":
Expand Down Expand Up @@ -262,18 +294,18 @@ def add_parent_edges(self, parent_ref: SliceRefType, syms: Set["Symbol"]) -> Non
return
parent = self._from_ref(parent_ref)
pid = parent.id
if pid in self.children:
if pid in self.raw_children:
return
if pid == self.id:
# in this case, inherit the previous parents, if any
if self.prev is not None:
for prev_pid, prev_syms in self.prev.parents.items():
for prev_pid, prev_syms in self.prev.raw_parents.items():
common = syms & prev_syms
if common:
self.parents.setdefault(prev_pid, set()).update(common)
self.raw_parents.setdefault(prev_pid, set()).update(common)
return
self.parents.setdefault(pid, set()).update(syms)
parent.children.setdefault(self.id, set()).update(syms)
self.raw_parents.setdefault(pid, set()).update(syms)
parent.raw_children.setdefault(self.id, set()).update(syms)

def add_parent_edge(self, parent_ref: SliceRefType, sym: "Symbol") -> None:
self.add_parent_edges(parent_ref, {sym})
Expand All @@ -285,7 +317,7 @@ def remove_parent_edges(
return
parent = self._from_ref(parent_ref)
pid = parent.id
for edges, eid in ((self.parents, pid), (parent.children, self.id)):
for edges, eid in ((self.raw_parents, pid), (parent.raw_children, self.id)):
sym_edges = edges.get(eid, set())
if not sym_edges:
continue
Expand All @@ -301,76 +333,76 @@ def replace_parent_edges(
) -> None:
prev_parent = self._from_ref(prev_parent_ref)
new_parent = self._from_ref(new_parent_ref)
syms = self.parents.pop(prev_parent.id)
prev_parent.children.pop(self.id)
self.parents.setdefault(new_parent.id, set()).update(syms)
new_parent.children.setdefault(self.id, set()).update(syms)
syms = self.raw_parents.pop(prev_parent.id)
prev_parent.raw_children.pop(self.id)
self.raw_parents.setdefault(new_parent.id, set()).update(syms)
new_parent.raw_children.setdefault(self.id, set()).update(syms)

def replace_child_edges(
self, prev_child_ref: SliceRefType, new_child_ref: SliceRefType
) -> None:
prev_child = self._from_ref(prev_child_ref)
new_child = self._from_ref(new_child_ref)
syms = self.children.pop(prev_child.id)
prev_child.parents.pop(self.id)
self.children.setdefault(new_child.id, set()).update(syms)
new_child.parents.setdefault(self.id, set()).update(syms)
syms = self.raw_children.pop(prev_child.id)
prev_child.raw_parents.pop(self.id)
self.raw_children.setdefault(new_child.id, set()).update(syms)
new_child.raw_parents.setdefault(self.id, set()).update(syms)

@property
def parents(self) -> Dict[IdType, Set["Symbol"]]:
def raw_parents(self) -> Dict[IdType, Set["Symbol"]]:
ctx = slicing_ctx_var.get()
if ctx == SlicingContext.DYNAMIC:
return self.dynamic_parents
return self.raw_dynamic_parents
elif ctx == SlicingContext.STATIC:
return self.static_parents
return self.raw_static_parents
flow_ = flow()
# TODO: rather than asserting test context,
# assert that we're being called from the notebook
assert not flow_.is_test
settings = flow_.mut_settings
parents: Dict[IdType, Set["Symbol"]] = {}
for _ in settings.iter_slicing_contexts():
for pid, syms in self.parents.items():
for pid, syms in self.raw_parents.items():
parents.setdefault(pid, set()).update(syms)
return parents

@parents.setter
def parents(self, new_parents: Dict[IdType, Set["Symbol"]]) -> None:
@raw_parents.setter
def raw_parents(self, new_parents: Dict[IdType, Set["Symbol"]]) -> None:
ctx = slicing_ctx_var.get()
assert ctx is not None
if ctx == SlicingContext.DYNAMIC:
self.dynamic_parents = new_parents
self.raw_dynamic_parents = new_parents
elif ctx == SlicingContext.STATIC:
self.static_parents = new_parents
self.raw_static_parents = new_parents
else:
assert False

@property
def children(self) -> Dict[IdType, Set["Symbol"]]:
def raw_children(self) -> Dict[IdType, Set["Symbol"]]:
ctx = slicing_ctx_var.get()
if ctx == SlicingContext.DYNAMIC:
return self.dynamic_children
return self.raw_dynamic_children
elif ctx == SlicingContext.STATIC:
return self.static_children
return self.raw_static_children
flow_ = flow()
# TODO: rather than asserting test context,
# assert that we're being called from the notebook
assert not flow_.is_test
settings = flow_.mut_settings
children: Dict[IdType, Set["Symbol"]] = {}
for _ in settings.iter_slicing_contexts():
for pid, syms in self.children.items():
for pid, syms in self.raw_children.items():
children.setdefault(pid, set()).update(syms)
return children

@children.setter
def children(self, new_children: Dict[IdType, Set["Symbol"]]) -> None:
@raw_children.setter
def raw_children(self, new_children: Dict[IdType, Set["Symbol"]]) -> None:
ctx = slicing_ctx_var.get()
assert ctx is not None
if ctx == SlicingContext.DYNAMIC:
self.dynamic_children = new_children
self.raw_dynamic_children = new_children
elif ctx == SlicingContext.STATIC:
self.static_children = new_children
self.raw_static_children = new_children
else:
assert False

Expand All @@ -379,7 +411,7 @@ def _make_slice_helper(self, closure: Set["SliceableMixin"]) -> None:
return
closure.add(self)
for _ in flow().mut_settings.iter_slicing_contexts():
for pid in self.parents.keys():
for pid in self.raw_parents.keys():
parent = self.from_id(pid)
while parent.timestamp > self.timestamp:
if getattr(parent, "override", False):
Expand Down

0 comments on commit 24b996e

Please sign in to comment.