Skip to content

Commit

Permalink
fix: Search bug fixes (#4673)
Browse files Browse the repository at this point in the history
* fix memory leaks and feature indexing issues

* use scope exit guards

* formatting

* update test reference outputs

* remove extra define

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Co-authored-by: olgavrou <olgavrou@gmail.com>
  • Loading branch information
3 people committed Apr 11, 2024
1 parent a0be017 commit 128fad3
Show file tree
Hide file tree
Showing 17 changed files with 179 additions and 198 deletions.
8 changes: 4 additions & 4 deletions test/train-sets/ref/search_dep_parser_cost_to_go.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ average since instance current true current predicted
loss last counter output prefix output prefix pass pol made hits gener beta
88.000000 88.000000 1 [43:1 5:2 5:2 5:2 1..] [0:8 1:1 2:1 3:1 4:..] 0 0 144 0 141 0.000000
46.500000 5.000000 2 [2:2 3:5 0:8 3:7 99..] [2:2 0:8 4:2 2:3 99..] 0 0 153 0 150 0.001409
30.750000 15.000000 4 [2:2 3:5 0:8 3:7 99..] [2:2 3:5 0:8 3:7 99..] 1 0 306 0 300 0.002906
17.125000 3.500000 8 [2:2 3:5 0:8 3:7 99..] [2:2 3:5 0:8 3:7 99..] 3 0 606 0 600 0.005893
29.250000 12.000000 4 [2:2 3:5 0:8 3:7 99..] [2:2 3:5 0:8 3:7 99..] 1 0 306 0 300 0.002906
16.375000 3.500000 8 [2:2 3:5 0:8 3:7 99..] [2:2 3:5 0:8 3:7 99..] 3 0 606 0 600 0.005893

finished run
number of examples per pass = 2
passes used = 6
weighted example sum = 12.000000
weighted label sum = 0.000000
average loss = 11.416667
total feature number = 270404
average loss = 10.916667
total feature number = 269977
4 changes: 2 additions & 2 deletions test/train-sets/ref/search_dep_parser_one_learner.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ Output pred = MULTICLASS
average since instance current true current predicted cur cur predic cache examples
loss last counter output prefix output prefix pass pol made hits gener beta
89.000000 89.000000 1 [43:1 5:2 5:2 5:2 1..] [20:9 20:9 20:9 20:..] 0 0 96 0 94 0.000930
47.000000 5.000000 2 [2:2 3:5 0:8 3:7 99..] [2:12 3:7 4:4 0:8 9..] 0 0 102 0 100 0.000990
46.500000 4.000000 2 [2:2 3:5 0:8 3:7 99..] [2:5 3:5 4:5 0:8 99..] 0 0 102 0 100 0.000990

finished run
number of examples = 2
weighted example sum = 2.000000
weighted label sum = 0.000000
average loss = 47.000000
average loss = 46.500000
total feature number = 28636
10 changes: 5 additions & 5 deletions test/train-sets/ref/search_wsj.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ Output pred = MULTICLASS
average since instance current true current predicted cur cur predic cache examples
loss last counter output prefix output prefix pass pol made hits gener beta
30.000000 30.000000 1 [1 2 3 1 4 5 6 7 8 ..] [1 1 1 1 1 1 1 1 1 ..] 0 0 37 0 37 0.000036
23.500000 17.000000 2 [11 2 3 11 11 11 15..] [1 2 1 1 4 1 2 1 1 ..] 0 0 64 0 64 0.000063
16.000000 8.500000 4 [3 4 6 3 ] [11 11 2 3 ] 0 0 97 0 97 0.000096
8.000000 0.000000 8 [3 4 6 3 ] [3 4 6 3 ] 1 0 194 0 194 0.000193
4.000000 0.000000 16 [3 4 6 3 ] [3 4 6 3 ] 3 0 388 0 388 0.000387
24.500000 19.000000 2 [11 2 3 11 11 11 15..] [1 2 1 1 4 1 12 9 1..] 0 0 64 0 64 0.000063
16.250000 8.000000 4 [3 4 6 3 ] [1 4 6 3 ] 0 0 97 0 97 0.000096
8.125000 0.000000 8 [3 4 6 3 ] [3 4 6 3 ] 1 0 194 0 194 0.000193
4.062500 0.000000 16 [3 4 6 3 ] [3 4 6 3 ] 3 0 388 0 388 0.000387

finished run
number of examples per pass = 4
passes used = 6
weighted example sum = 24.000000
weighted label sum = 0.000000
average loss = 2.666667
average loss = 2.708333
total feature number = 52110
4 changes: 2 additions & 2 deletions test/train-sets/ref/search_wsj2.dat.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ Output pred = MULTICLASS
average since instance current true current predicted cur cur predic cache examples
loss last counter output prefix output prefix pass pol made hits gener beta
30.000000 30.000000 1 [1 2 3 1 4 5 6 7 8 ..] [1 1 1 1 1 1 1 1 1 ..] 0 0 37 0 37 0.000036
24.000000 18.000000 2 [11 2 3 11 11 11 15..] [1 2 3 1 4 1 2 1 1 ..] 0 0 64 0 64 0.000063
16.750000 9.500000 4 [3 4 6 3 ] [11 11 11 11 ] 0 0 97 0 97 0.000096
24.500000 19.000000 2 [11 2 3 11 11 11 15..] [1 2 1 1 1 1 12 12 ..] 0 0 64 0 64 0.000063
16.750000 9.000000 4 [3 4 6 3 ] [1 4 6 6 ] 0 0 97 0 97 0.000096
8.375000 0.000000 8 [3 4 6 3 ] [3 4 6 3 ] 1 0 194 0 194 0.000193
4.187500 0.000000 16 [3 4 6 3 ] [3 4 6 3 ] 3 1 388 0 388 0.000387

Expand Down
12 changes: 1 addition & 11 deletions test/train-sets/ref/sequence_data.ldf.beam.test.predict
Original file line number Diff line number Diff line change
@@ -1,11 +1 @@
5 4 3 2 1 0.000242054
5 4 4 3 2 1.00016
5 4 3 3 2 1.00018
5 4 3 2 4 1.00018
5 4 5 4 3 1.00019
5 4 3 4 3 1.00026
5 4 2 1 4 1.60554
5 3 2 1 4 1.60557
5 4 1 4 3 1.60563
4 3 2 1 4 1.60563

5 4 3 2 1
2 changes: 1 addition & 1 deletion test/train-sets/ref/sequence_data.ldf.beam.test.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Input label = MULTICLASS
Output pred = MULTICLASS
average since instance current true current predicted cur cur predic cache examples
loss last counter output prefix output prefix pass pol made hits gener beta
0.000000 0.000000 1 [5 4 3 2 1 ] [5 4 3 2 1 0.00024..] 0 0 26 0 0 0.000000
0.000000 0.000000 1 [5 4 3 2 1 ] [5 4 3 2 1 ] 0 0 26 0 0 0.000000

finished run
number of examples = 1
Expand Down
12 changes: 1 addition & 11 deletions test/train-sets/ref/sequence_data.nonldf.beam.test.predict
Original file line number Diff line number Diff line change
@@ -1,11 +1 @@
5 4 3 2 1 8.34465e-07
5 4 3 5 4 1
5 4 3 2 4 1
5 5 4 3 2 1.02138
5 4 3 2 3 1.02816
5 4 3 2 2 1.03424
5 3 2 1 1 1.76761
5 4 3 1 1 1.79503
4 3 2 1 1 1.79576
5 2 1 1 1 2.53521

5 4 3 2 1
2 changes: 1 addition & 1 deletion test/train-sets/ref/sequence_data.nonldf.beam.test.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Input label = MULTICLASS
Output pred = MULTICLASS
average since instance current true current predicted cur cur predic cache examples
loss last counter output prefix output prefix pass pol made hits gener beta
0.000000 0.000000 1 [5 4 3 2 1 ] [5 4 3 2 1 8.34465..] 0 0 24 0 0 0.000000
0.000000 0.000000 1 [5 4 3 2 1 ] [5 4 3 2 1 ] 0 0 24 0 0 0.000000

finished run
number of examples = 1
Expand Down
12 changes: 6 additions & 6 deletions vowpalwabbit/core/include/vw/core/reductions/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ class search
public: // INTERFACE
// for managing task-specific data that you want on the heap:
template <class T>
void set_task_data(T* data)
void set_task_data(std::shared_ptr<T> data)
{
task_data = std::shared_ptr<T>(data);
task_data = std::move(data);
}
template <class T>
T* get_task_data()
Expand All @@ -98,9 +98,9 @@ class search

// for managing metatask-specific data
template <class T>
void set_metatask_data(T* data)
void set_metatask_data(std::shared_ptr<T> data)
{
metatask_data = std::shared_ptr<T>(data);
metatask_data = std::move(data);
}
template <class T>
T* get_metatask_data()
Expand Down Expand Up @@ -218,7 +218,7 @@ class search
BaseTask base_task(VW::multi_ex& ec) { return BaseTask(this, ec); }

// internal data that you don't get to see!
search_private* priv = nullptr;
std::shared_ptr<search_private> priv = nullptr;
std::shared_ptr<void> task_data = nullptr; // your task data!
std::shared_ptr<void> metatask_data = nullptr; // your metatask data!
const char* task_name = nullptr;
Expand All @@ -227,8 +227,8 @@ class search
VW::workspace& get_vw_pointer_unsafe(); // although you should rarely need this, sometimes you need a pointer to the
// vw data structure :(
void set_force_oracle(bool force); // if the library wants to force search to use the oracle, set this to true

search();
~search();
};

// for defining new tasks, you must fill out a search_task
Expand Down

0 comments on commit 128fad3

Please sign in to comment.