Skip to content

Commit

Permalink
seeding fix for xvalmmd
Browse files Browse the repository at this point in the history
  • Loading branch information
vigsterkr committed Oct 25, 2019
1 parent ce24d32 commit 7022a56
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,11 @@ struct CrossValidationMMD : PermutationMMD
SGVector<int64_t> dummy_labels_x(m_n_x);
SGVector<int64_t> dummy_labels_y(m_n_y);

auto instance_x=new CCrossValidationSplitting(new CBinaryLabels(dummy_labels_x), m_num_folds);
auto instance_y=new CCrossValidationSplitting(new CBinaryLabels(dummy_labels_y), m_num_folds);
random::seed(instance_x, prng);
random::seed(instance_y, prng);

m_kfold_x=unique_ptr<CCrossValidationSplitting>(instance_x);
m_kfold_y=unique_ptr<CCrossValidationSplitting>(instance_y);
m_kfold_x=std::make_unique<CCrossValidationSplitting>(new CBinaryLabels(dummy_labels_x), m_num_folds);
m_kfold_y=std::make_unique<CCrossValidationSplitting>(new CBinaryLabels(dummy_labels_y), m_num_folds);
random::seed(m_kfold_x.get(), prng);
random::seed(m_kfold_y.get(), prng);

m_stack=unique_ptr<CSubsetStack>(new CSubsetStack());

const index_t size=m_n_x+m_n_y;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ TEST(CrossValidationMMD, biased_full)
kfold_p->put("seed", seed);
kfold_q->put("seed", seed);

std::mt19937_64 permPRNG(seed);
auto permutation_mmd=PermutationMMD();
permutation_mmd.m_stype=stype;
permutation_mmd.m_num_null_samples=num_null_samples;
Expand Down Expand Up @@ -134,7 +135,7 @@ TEST(CrossValidationMMD, biased_full)
(feats_p->create_merged_copy(feats_q));

kernel->init(current_merged_feats, current_merged_feats);
auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix<float32_t>(), prng);
auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix<float32_t>(), permPRNG);

EXPECT_EQ(cv.m_rejections(current_run*num_folds+current_fold, k), p_value<alpha);

Expand Down Expand Up @@ -206,6 +207,7 @@ TEST(CrossValidationMMD, unbiased_full)
kfold_p->put("seed", seed);
kfold_q->put("seed", seed);

std::mt19937_64 permPRNG(seed);
auto permutation_mmd=PermutationMMD();
permutation_mmd.m_stype=stype;
permutation_mmd.m_num_null_samples=num_null_samples;
Expand Down Expand Up @@ -233,7 +235,7 @@ TEST(CrossValidationMMD, unbiased_full)
(feats_p->create_merged_copy(feats_q));

kernel->init(current_merged_feats, current_merged_feats);
auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix<float32_t>(), prng);
auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix<float32_t>(), permPRNG);

EXPECT_EQ(cv.m_rejections(current_run*num_folds+current_fold, k), p_value<alpha);

Expand Down Expand Up @@ -306,6 +308,7 @@ TEST(CrossValidationMMD, unbiased_incomplete)
kfold_p->put("seed", seed);
kfold_q->put("seed", seed);

std::mt19937_64 permPRNG(seed);
auto permutation_mmd=PermutationMMD();
permutation_mmd.m_stype=stype;
permutation_mmd.m_num_null_samples=num_null_samples;
Expand Down Expand Up @@ -333,7 +336,7 @@ TEST(CrossValidationMMD, unbiased_incomplete)
(feats_p->create_merged_copy(feats_q));

kernel->init(current_merged_feats, current_merged_feats);
auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix<float32_t>(), prng);
auto p_value=permutation_mmd.p_value(kernel->get_kernel_matrix<float32_t>(), permPRNG);

EXPECT_EQ(cv.m_rejections(current_run*num_folds+current_fold, k), p_value<alpha);

Expand Down

0 comments on commit 7022a56

Please sign in to comment.