Skip to content

Commit

Permalink
added back checks in linear_regression_main.cpp and linear_regression…
Browse files Browse the repository at this point in the history
…_predict_maain.cpp
  • Loading branch information
lumi232 committed Feb 28, 2024
1 parent d534ede commit 40676a7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/mlpack/methods/linear_regression/linear_regression_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,18 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timer)
timer.Stop("load_test_points");

mat points = std::move(params.Get<mat>("test"));


// Ensure that test file data has the right number of features.
try
{
util::CheckSameDimensionality(points, lr->Parameters().n_elem - 1,
"Linear Regression Prediction", "test points");
}
catch (std::invalid_argument& e)
{
Log::Fatal << e.what() << std::endl;
}

// Perform the predictions using our model.
rowvec predictions;
timer.Start("prediction");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timer)

mat points = std::move(params.Get<mat>("test"));

// Ensure that test file data has the right number of features.
try
{
util::CheckSameDimensionality(points, lr->Parameters().n_elem - 1,
"Linear Regression Prediction", "test points");
}
catch (std::invalid_argument& e)
{
Log::Fatal << e.what() << std::endl;
}

// Perform the predictions using our model.
rowvec predictions;
timer.Start("prediction");
Expand Down

0 comments on commit 40676a7

Please sign in to comment.