Skip to content

Commit

Permalink
fixing some memoization bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Nov 25, 2023
1 parent 721dbb5 commit 7a3b9d7
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 21 deletions.
2 changes: 1 addition & 1 deletion core/ipyflow/data_model/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
self._extra_stmt: Optional[ast.stmt] = None
self._placeholder_id = placeholder_id
self.is_memoized = is_memoized
self.skipped_due_to_memoization = False
self.skipped_due_to_memoization_ctr = -1
self.memoized_params: List[
Tuple[Dict["Symbol", Tuple[int, Any]], Dict["Symbol", Any], int]
] = []
Expand Down
4 changes: 2 additions & 2 deletions core/ipyflow/data_model/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def transfer_symbols_to(self, new_ns: "Namespace") -> None:
new_ns.obj, dsym.name, is_subscript=False
)
except AttributeError:
inner_obj = None
continue
except TypeError:
break
dsym.update_obj_ref(inner_obj)
Expand All @@ -380,7 +380,7 @@ def transfer_symbols_to(self, new_ns: "Namespace") -> None:
new_ns.obj, dsym.name, is_subscript=True
)
except (IndexError, KeyError):
inner_obj = None
continue
except TypeError:
break
dsym.update_obj_ref(inner_obj)
Expand Down
5 changes: 3 additions & 2 deletions core/ipyflow/data_model/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,9 +924,10 @@ def update_deps(
prev_cell = None
prev_cell_ctr = -1 if prev_cell is None else prev_cell.cell_ctr
if overwrite:
flow_ = flow()
self._cascading_reactive_cell_num = -1
flow().updated_reactive_symbols.discard(self)
flow().updated_deep_reactive_symbols.discard(self)
flow_.updated_reactive_symbols.discard(self)
flow_.updated_deep_reactive_symbols.discard(self)
if is_cascading_reactive is not None:
is_cascading_reactive = is_cascading_reactive or any(
dsym.is_cascading_reactive_at_counter(prev_cell_ctr)
Expand Down
9 changes: 5 additions & 4 deletions core/ipyflow/shell/interactiveshell.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,14 +495,14 @@ def before_run_cell(
elif comparable_obj is Symbol.NULL:
break
comparable_param, eq = param.make_memoize_comparable()
if not eq(comparable_param, comparable_obj):
if eq is None or not eq(comparable_param, comparable_obj):
break
else:
identical_result_ctr = ctr
memoized_outputs = outputs
break
if identical_result_ctr is not None:
cell.skipped_due_to_memoization = True
cell.skipped_due_to_memoization_ctr = identical_result_ctr
print_purple(
"Detected identical symbol usages to previous run; skipping due to memoization..."
)
Expand Down Expand Up @@ -542,10 +542,11 @@ def _handle_memoization(self) -> None:
cell = Cell.current_cell()
prev_cell = cell.prev_cell
if prev_cell is not None:
if cell.executed_content == prev_cell.executed_content:
if cell.executed_content == prev_cell.executed_content and cell.is_memoized:
cell.memoized_params = prev_cell.memoized_params
prev_cell.memoized_params = []
if cell.skipped_due_to_memoization:
if cell.skipped_due_to_memoization_ctr > 0:
prev_cell = Cell.at_counter(cell.skipped_due_to_memoization_ctr)
assert prev_cell is not None
cell.static_parents = prev_cell.static_parents
cell.dynamic_parents = prev_cell.dynamic_parents
Expand Down
24 changes: 12 additions & 12 deletions core/test/test_memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,23 @@ def test_ints():
assert shell().user_ns["y"] == 1
assert flow().global_scope["y"].obj == 1
assert second.is_memoized
assert not second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr == -1
run_cell("x = 0", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = x + 1", cell_id=second.id))
assert second.is_memoized
assert second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr > 0
assert shell().user_ns["y"] == 1
assert flow().global_scope["y"].obj == 1
run_cell("x = 1", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = x + 1", cell_id=second.id))
assert shell().user_ns["y"] == 2
assert flow().global_scope["y"].obj == 2
assert second.is_memoized
assert not second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr == -1
run_cell("x = 0", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = x + 1", cell_id=second.id))
assert second.is_memoized
assert second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr > 0
assert shell().user_ns["y"] == 1
assert flow().global_scope["y"].obj == 1

Expand All @@ -54,21 +54,21 @@ def test_strings():
assert shell().user_ns["y"] == "hello world"
assert flow().global_scope["y"].obj == "hello world"
assert second.is_memoized
assert not second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr == -1
run_cell("x = 'hello'", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = x + ' world'", cell_id=second.id))
assert second.is_memoized
assert second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr > 0
run_cell("x = 'hi'", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = x + ' world'", cell_id=second.id))
assert shell().user_ns["y"] == "hi world"
assert flow().global_scope["y"].obj == "hi world"
assert second.is_memoized
assert not second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr == -1
run_cell("x = 'hello'", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = x + ' world'", cell_id=second.id))
assert second.is_memoized
assert second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr > 0
assert shell().user_ns["y"] == "hello world"
assert flow().global_scope["y"].obj == "hello world"

Expand All @@ -79,22 +79,22 @@ def test_sets():
assert shell().user_ns["y"] == {0, 1}
assert flow().global_scope["y"].obj == {0, 1}
assert second.is_memoized
assert not second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr == -1
run_cell("x = {0}", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = x | {1}", cell_id="second"))
assert shell().user_ns["y"] == {0, 1}
assert flow().global_scope["y"].obj == {0, 1}
assert second.is_memoized
assert second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr > 0
run_cell("x = {2}", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = x | {1}", cell_id="second"))
assert shell().user_ns["y"] == {1, 2}
assert flow().global_scope["y"].obj == {1, 2}
assert second.is_memoized
assert not second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr == -1
run_cell("x = {0}", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = x | {1}", cell_id="second"))
assert shell().user_ns["y"] == {0, 1}
assert flow().global_scope["y"].obj == {0, 1}
assert second.is_memoized
assert second.skipped_due_to_memoization
assert second.skipped_due_to_memoization_ctr > 0
1 change: 1 addition & 0 deletions core/test/test_staleness_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2230,6 +2230,7 @@ def test_dict_2():
)


@skipif_known_failing
def test_default_dict():
run_cell("from collections import defaultdict")
run_cell("d = defaultdict(dict); d[0][0] = 0")
Expand Down

0 comments on commit 7a3b9d7

Please sign in to comment.