Skip to content

Commit

Permalink
fix more memoization bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Nov 27, 2023
1 parent 297e5bf commit 0ca6fc2
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
8 changes: 6 additions & 2 deletions core/ipyflow/data_model/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ def __init__(
self.is_memoized = is_memoized
self.skipped_due_to_memoization_ctr = -1
self.memoized_params: List[
Tuple[Dict["Symbol", Tuple[int, Any]], Dict["Symbol", Any], int]
Tuple[
Dict["Symbol", Tuple[int, Any]],
Dict["Symbol", Tuple[Any, Timestamp]],
int,
]
] = []
self._force_tracking = force_tracking

Expand Down Expand Up @@ -295,7 +299,7 @@ def _maybe_memoize_params(self) -> None:
for sym in flow().updated_symbols:
if not sym.is_user_accessible or not sym.containing_scope.is_global:
continue
outputs[sym] = sym.obj
outputs[sym] = (sym.obj, sym.timestamp_excluding_ns_descendents)
self.memoized_params.append((memoized_params, outputs, self.cell_ctr))

@classmethod
Expand Down
15 changes: 11 additions & 4 deletions core/ipyflow/data_model/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ def __str__(self) -> str:
def __hash__(self) -> int:
return hash(id(self))

def __lt__(self, other) -> bool:
return id(self) < id(other)

def add_tag(self, tag_value: str) -> None:
self._tags.add(tag_value)

Expand Down Expand Up @@ -1279,12 +1282,16 @@ def make_memoize_comparable_for_obj(
def make_memoize_comparable(
self, seen_ids: Optional[Set[int]] = None
) -> Tuple[Any, Optional[Callable[[Any, Any], bool]]]:
if isinstance(self.stmt_node, (ast.ClassDef, ast.FunctionDef)):
# TODO: should additionally include deps such as super / kw defaults
# maybe suffices just to look at deps?
return astunparse.unparse(self.stmt_node), self._equal
if seen_ids is None:
seen_ids = set()
if isinstance(self.stmt_node, (ast.ClassDef, ast.FunctionDef)):
comps = [astunparse.unparse(self.stmt_node)]
for sym in sorted(self.parents.keys()):
par_comp, eq = sym.make_memoize_comparable(seen_ids=seen_ids)
if par_comp is self.NULL or eq is not self._equal:
return self.NULL, None
comps.append(par_comp)
return comps, self._equal
obj, eq, size = self.make_memoize_comparable_for_obj(self.obj, seen_ids)
if size > self._MAX_MEMOIZE_COMPARABLE_SIZE:
return self.NULL, None
Expand Down
7 changes: 5 additions & 2 deletions core/ipyflow/shell/interactiveshell.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from ipyflow import singletons
from ipyflow.config import Interface
from ipyflow.data_model.cell import Cell
from ipyflow.data_model.statement import Statement
from ipyflow.data_model.symbol import Symbol
from ipyflow.data_model.timestamp import Timestamp
from ipyflow.flow import NotebookFlow
from ipyflow.tracing.flow_ast_rewriter import DataflowAstRewriter
from ipyflow.tracing.ipyflow_tracer import (
Expand Down Expand Up @@ -506,11 +508,12 @@ def before_run_cell(
print_purple(
"Detected identical symbol usages to previous run; reusing memoized result..."
)
for sym, obj in memoized_outputs.items():
for sym, (obj, mem_ts) in memoized_outputs.items():
if sym.obj is not obj:
self.user_ns[sym.name] = obj
sym.update_obj_ref(obj)
sym.refresh()
new_updated_ts = Timestamp(self.cell_counter(), mem_ts.stmt_num)
sym.refresh(timestamp=new_updated_ts)
return f"Out.get({identical_result_ctr})"

# Stage 1: Precheck.
Expand Down

0 comments on commit 0ca6fc2

Please sign in to comment.