Skip to content

Commit

Permalink
even more memoization bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Nov 30, 2023
1 parent 89ef415 commit edda6b5
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 7 deletions.
13 changes: 11 additions & 2 deletions core/ipyflow/data_model/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
Generator,
Expand Down Expand Up @@ -88,6 +87,7 @@ def __init__(
) -> None:
self.cell_id: IdType = cell_id
self.cell_ctr: int = cell_ctr
self.error_in_exec: Optional[BaseException] = None
self.history: List[int] = [cell_ctr] if cell_ctr > -1 else []
self.executed_content: Optional[str] = None
self.current_content: str = content
Expand Down Expand Up @@ -126,6 +126,10 @@ def __init__(
def id(self) -> IdType:
return self.cell_id

@property
def is_error(self) -> bool:
return self.error_in_exec is not None

@property
def is_dirty(self) -> bool:
return self.current_content != self.executed_content
Expand Down Expand Up @@ -279,16 +283,21 @@ def get_reactive_ids_for_tag(cls, tag: str) -> Set[IdType]:
return cls._reactive_cells_by_tag.get(tag, set())

def _maybe_memoize_params(self) -> None:
if self.is_error:
return
inputs: Dict["Symbol", MemoizedInput] = {}
for _ in SlicingContext.iter_slicing_contexts():
for edges in self.parents.values():
for sym in edges:
if sym.timestamp.cell_num >= self.cell_ctr:
if sym.timestamp.cell_num == self.cell_ctr:
continue
elif sym.timestamp.cell_num > self.cell_ctr:
return
if sym not in inputs:
inputs[sym] = MemoizedInput(
sym,
sym.timestamp,
sym.memoize_timestamp,
sym.obj_id,
sym.make_memoize_comparable()[0],
)
Expand Down
4 changes: 4 additions & 0 deletions core/ipyflow/data_model/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ def timestamp_excluding_ns_descendents(self) -> Timestamp:
else:
return max(self._timestamp, self._override_timestamp)

@property
def memoize_timestamp(self) -> Optional[Timestamp]:
return self.last_updated_timestamp_by_obj_id.get(self.obj_id)

@property
def timestamp(self) -> Timestamp:
ts = self.timestamp_excluding_ns_descendents
Expand Down
3 changes: 2 additions & 1 deletion core/ipyflow/memoization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from typing import TYPE_CHECKING, Any, List, NamedTuple
from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional

if TYPE_CHECKING:
from ipyflow.data_model.symbol import Symbol
Expand All @@ -9,6 +9,7 @@
class MemoizedInput(NamedTuple):
symbol: "Symbol"
ts_at_execution: "Timestamp"
mem_ts_at_execution: Optional["Timestamp"]
obj_id_at_execution: int
comparable: Any

Expand Down
10 changes: 6 additions & 4 deletions core/ipyflow/shell/interactiveshell.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def _ipyflow_run_cell(
maybe_new_content = self.before_run_cell(
raw_cell, store_history=store_history, **kwargs
)
cell = Cell.from_counter(self.cell_counter())
if maybe_new_content is not None:
raw_cell = maybe_new_content

Expand All @@ -367,6 +368,7 @@ def _ipyflow_run_cell(
shell_futures=shell_futures,
**kwargs,
) # pragma: no cover
cell.error_in_exec = ret.error_in_exec
if is_already_recording_output:
outvar = (
raw_cell.strip().splitlines()[0][len("%%capture") :].strip()
Expand Down Expand Up @@ -437,12 +439,12 @@ def _get_content_for_memoized_run(self, cell: Cell) -> Optional[str]:
for content, inputs, outputs, ctr in prev_cell.memoized_executions:
if content != cell.executed_content:
continue
for sym, in_ts, obj_id, comparable in inputs:
for sym, in_ts, mem_ts, obj_id, comparable in inputs:
if sym.timestamp.cell_num == in_ts.cell_num:
continue
elif (
sym.obj_id == obj_id
and sym.last_updated_timestamp_by_obj_id.get(obj_id) == in_ts
elif sym.obj_id == obj_id and sym.memoize_timestamp in (
in_ts,
mem_ts or Timestamp.uninitialized(),
):
continue
elif comparable is Symbol.NULL:
Expand Down

0 comments on commit edda6b5

Please sign in to comment.