Skip to content

Commit

Permalink
fix depth visuzaliation pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasbtnfr committed Mar 21, 2023
1 parent f17c0df commit d428c18
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
19 changes: 10 additions & 9 deletions skmine/periodic/cycles.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
"""Periodic pattern mining with a MDL criterion"""
# Authors: Rémi Adon <remi.adon@gmail.com>
# Esther Galbrun <esther.galbrun@inria.fr>
# Cyril Regan <cyril.regan@loria.fr>
# Thomas Betton <thomas.betton@irisa.fr>
#
# License: BSD 3 clause

import copy
import json
import warnings

Expand All @@ -18,6 +12,13 @@
from .pattern_collection import PatternCollection
from .run_mine import mine_seqs

# Authors: Rémi Adon <remi.adon@gmail.com>
# Esther Galbrun <esther.galbrun@inria.fr>
# Cyril Regan <cyril.regan@loria.fr>
# Thomas Betton <thomas.betton@irisa.fr>
#
# License: BSD 3 clause

INDEX_TYPES = (
pd.DatetimeIndex,
pd.RangeIndex,
Expand Down Expand Up @@ -427,10 +428,10 @@ def draw_pattern(self, pattern_id):
-------
"""
pattern = self.cycles.loc[pattern_id]["pattern_json_tree"]
pattern = copy.deepcopy(self.cycles.loc[pattern_id]["pattern_json_tree"])
# map each event id to its real textual name
for nid in pattern.keys():
if isinstance(nid, int):
if "event" in pattern[nid].keys():
pattern[nid]["event"] = list(self.data_details.map_ev_num.keys())[int(pattern[nid]["event"])]
return draw_pattern(self.cycles.loc[pattern_id]["pattern_json_tree"])
return draw_pattern(pattern)
2 changes: 1 addition & 1 deletion skmine/periodic/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def draw_pattern_rec(graph, pattern, id_to_pr_event=None, id=0, id_parent=-1, di
graph.edge(str(id_parent), str(id), dir="none")

if distance != (-1, -1):
graph.edge(str(distance[0]), str(id), label=str(distance[1]), style="dotted")
graph.edge(str(distance[0]), str(id), label=str(distance[1]), style="dotted", constraint="false")

return graph

Expand Down
13 changes: 13 additions & 0 deletions skmine/periodic/tests/test_cycles.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,19 @@ def test_import_patterns(patterns_json):
assert pcm.miners_.patterns[0][2] == patterns_json["patterns"][0]["E"]


def test_import_export_patterns(data):
pcm1 = PeriodicPatternMiner()
pcm1.fit(data)
res1 = pcm1.discover()
pcm1.export_patterns()

pcm2 = PeriodicPatternMiner()
pcm2.import_patterns()
res2 = pcm2.discover()

assert_frame_equal(res1, res2)


@pytest.fixture
def expected_reconstruct():
expected_reconstruct = pd.DataFrame({
Expand Down

0 comments on commit d428c18

Please sign in to comment.