Skip to content

Commit

Permalink
memoize: bugfix for functions and handling for classes
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Nov 26, 2023
1 parent 47a5a93 commit c6ce049
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
6 changes: 3 additions & 3 deletions core/ipyflow/data_model/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ def make_memoize_comparable_for_obj(
if size > cls._MAX_MEMOIZE_COMPARABLE_SIZE:
return cls.NULL, None, -1
comparable.append(inner_comp)
return type(obj)(comparable), cls._equal, size
return comparable, cls._equal, size
else:
# hacks to check if they are arrays or dataframes without explicitly importing these
module = getattr(type(obj), "__module__", "")
Expand All @@ -1262,8 +1262,8 @@ def make_memoize_comparable_for_obj(
def make_memoize_comparable(
self,
) -> Tuple[Any, Optional[Callable[[Any, Any], bool]]]:
if isinstance(self.stmt_node, ast.FunctionDef):
return astunparse.unparse(self.stmt_node)
if isinstance(self.stmt_node, (ast.ClassDef, ast.FunctionDef)):
return astunparse.unparse(self.stmt_node), self._equal
obj, eq, size = self.make_memoize_comparable_for_obj(self.obj)
if size > self._MAX_MEMOIZE_COMPARABLE_SIZE:
return self.NULL, None
Expand Down
27 changes: 27 additions & 0 deletions core/test/test_memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,30 @@ def test_sets():
assert flow().global_scope["y"].obj == {0, 1}
assert second.is_memoized
assert second.skipped_due_to_memoization_ctr > 0


def test_functions():
first = cells(run_cell("def foo(): return 42"))
second = cells(run_cell("%%memoize\ny = foo() + 1", cell_id="second"))
assert shell().user_ns["y"] == 43
assert flow().global_scope["y"].obj == 43
assert second.is_memoized
assert second.skipped_due_to_memoization_ctr == -1
run_cell("def foo(): return 42", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = foo() + 1", cell_id="second"))
assert shell().user_ns["y"] == 43
assert flow().global_scope["y"].obj == 43
assert second.is_memoized
assert second.skipped_due_to_memoization_ctr > 0
run_cell("def foo(): return 44", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = foo() + 1", cell_id="second"))
assert shell().user_ns["y"] == 45
assert flow().global_scope["y"].obj == 45
assert second.is_memoized
assert second.skipped_due_to_memoization_ctr == -1
run_cell("def foo(): return 42", cell_id=first.id)
second = cells(run_cell("%%memoize\ny = foo() + 1", cell_id="second"))
assert shell().user_ns["y"] == 43
assert flow().global_scope["y"].obj == 43
assert second.is_memoized
assert second.skipped_due_to_memoization_ctr > 0

0 comments on commit c6ce049

Please sign in to comment.