Skip to content

Commit

Permalink
Merge pull request #3473 from mrdaybird/fix_prelu
Browse files Browse the repository at this point in the history
Fix PReLU (Resolves #3466)
  • Loading branch information
rcurtin committed Apr 28, 2023
2 parents a7cc156 + 5bd3259 commit 417ba39
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 7 deletions.
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
### mlpack ?.?.?
###### ????-??-??

* Fix PReLU, add integration test to it (#3473).

### mlpack 4.1.0
###### 2023-04-26

Expand Down
9 changes: 3 additions & 6 deletions src/mlpack/methods/ann/layer/parametric_relu_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,9 @@ void PReLUType<MatType>::Forward(
const MatType& input, MatType& output)
{
output = input;
if (this->training)
{
#pragma omp for
for (size_t i = 0; i < input.n_elem; ++i)
output(i) *= (input(i) >= 0) ? 1 : alpha(0);
}
#pragma omp for
for (size_t i = 0; i < input.n_elem; ++i)
output(i) *= (input(i) >= 0) ? 1 : alpha(0);
}

template<typename MatType>
Expand Down
39 changes: 38 additions & 1 deletion src/mlpack/tests/ann/layer/parametric_relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ TEST_CASE("PReLUFORWARDTest", "[ANNLayerTest]")
{5.5, -4.7, 2.1},
{0.2, 0.1, -0.5}};
PReLU module(0.01);
module.Training() = true;
arma::mat moduleParams(module.WeightSize(), 1);
module.CustomInitialize(moduleParams, module.WeightSize());
module.SetWeights((double*) moduleParams.memptr());
Expand Down Expand Up @@ -94,3 +93,41 @@ TEST_CASE("PReLUGRADIENTTest", "[ANNLayerTest]")
REQUIRE(0.0103 - arma::accu(predGradient) ==
Approx(0.0).margin(1e-4));
}

double ComputeMSRE(arma::mat input, arma::mat target)
{
return std::pow(arma::accu(arma::pow(input - target, 2)) / target.n_cols, 0.5);
}

TEST_CASE("PReLUIntegrationTest", "[ANNLayerTest]")
{
arma::mat data;
data::Load("boston_housing_price.csv", data);
arma::mat labels;
data::Load("boston_housing_price_responses.csv", labels);

arma::mat trainData, testData, trainLabels, testLabels;
data::Split(data, labels, trainData, testData, trainLabels, testLabels, 0.2);

FFN<L1Loss> model;
model.Add<Linear>(10);
model.Add<PReLU>(0.01);
model.Add<Linear>(3);
model.Add<PReLU>(0.01);
model.Add<Linear>(1);

int epochs = 50;
ens::RMSProp optimizer(0.01, 32, 0.99, 1e-8, epochs * trainData.n_cols);
model.Train(trainData, trainLabels, optimizer);

arma::mat predictions;
model.Predict(trainData, predictions);
double msreTrain = ComputeMSRE(predictions, trainLabels);
model.Predict(testData, predictions);
double msreTest = ComputeMSRE(predictions, testLabels);

double relativeMSRE = std::abs((msreTest - msreTrain) / msreTrain);

REQUIRE(relativeMSRE <= 0.25);
}

0 comments on commit 417ba39

Please sign in to comment.