-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3454 from mrdaybird/adapt_hard_tanh
Updated hard_tanh and added to layer_types
- Loading branch information
Showing
8 changed files
with
202 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
/** | ||
* @file methods/ann/layer/hard_tanh_impl.hpp | ||
* @author Dhawal Arora | ||
* @author Vaibhav Pathak | ||
* | ||
* Implementation and implementation of the HardTanH layer. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_METHODS_ANN_LAYER_HARD_TANH_IMPL_HPP | ||
#define MLPACK_METHODS_ANN_LAYER_HARD_TANH_IMPL_HPP | ||
|
||
// In case it hasn't yet been included. | ||
#include "hard_tanh.hpp" | ||
|
||
namespace mlpack { | ||
|
||
template<typename MatType> | ||
HardTanHType<MatType>::HardTanHType( | ||
const double maxValue, | ||
const double minValue) : | ||
Layer<MatType>(), | ||
maxValue(maxValue), | ||
minValue(minValue) | ||
{ | ||
// Nothing to do here. | ||
} | ||
|
||
template<typename MatType> | ||
HardTanHType<MatType>::HardTanHType(const HardTanHType& layer) : | ||
Layer<MatType>(layer), | ||
maxValue(layer.maxValue), | ||
minValue(layer.minValue) | ||
{ | ||
// Nothing to do here. | ||
} | ||
|
||
template<typename MatType> | ||
HardTanHType<MatType>::HardTanHType(HardTanHType&& layer) : | ||
Layer<MatType>(std::move(layer)), | ||
maxValue(std::move(layer.maxValue)), | ||
minValue(std::move(layer.minValue)) | ||
{ | ||
// Nothing to do here. | ||
} | ||
|
||
template<typename MatType> | ||
HardTanHType<MatType>& HardTanHType<MatType>::operator=(const HardTanHType& layer) | ||
{ | ||
if (&layer != this) | ||
{ | ||
Layer<MatType>::operator=(layer); | ||
maxValue = layer.maxValue; | ||
minValue = layer.minValue; | ||
} | ||
|
||
return *this; | ||
} | ||
|
||
template<typename MatType> | ||
HardTanHType<MatType>& HardTanHType<MatType>::operator=(HardTanHType&& layer) | ||
{ | ||
if (&layer != this) | ||
{ | ||
Layer<MatType>::operator=(std::move(layer)); | ||
maxValue = std::move(layer.maxValue); | ||
minValue = std::move(layer.minValue); | ||
} | ||
|
||
return *this; | ||
} | ||
template<typename MatType> | ||
void HardTanHType<MatType>::Forward( | ||
const MatType& input, MatType& output) | ||
{ | ||
#pragma omp parallel for | ||
for (size_t i = 0; i < input.n_elem; ++i) | ||
{ | ||
output(i) = (input(i) > maxValue ? maxValue : | ||
(input(i) < minValue ? minValue : input(i))); | ||
} | ||
} | ||
|
||
template<typename MatType> | ||
void HardTanHType<MatType>::Backward( | ||
const MatType& input, const MatType& gy, MatType& g) | ||
{ | ||
g = gy; | ||
|
||
#pragma omp parallel for | ||
for (size_t i = 0; i < input.n_elem; ++i) | ||
{ | ||
// input should not have any values greater than maxValue | ||
// and lesser than minValue | ||
if (input(i) <= minValue || input(i) >= maxValue) | ||
{ | ||
g(i) = 0; | ||
} | ||
} | ||
} | ||
|
||
template<typename MatType> | ||
template<typename Archive> | ||
void HardTanHType<MatType>::serialize( | ||
Archive& ar, | ||
const uint32_t /* version */) | ||
{ | ||
ar(cereal::base_class<Layer<MatType>>(this)); | ||
|
||
ar(CEREAL_NVP(maxValue)); | ||
ar(CEREAL_NVP(minValue)); | ||
} | ||
|
||
} // namespace mlpack | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
69 changes: 0 additions & 69 deletions
69
src/mlpack/methods/ann/layer/not_adapted/hard_tanh_impl.hpp
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
/* | ||
* @file tests/ann/layer/hard_tanh.cpp | ||
* @author Vaibhav Pathak | ||
* | ||
* Tests the hard_tanh layer | ||
* | ||
*/ | ||
|
||
#include <mlpack/core.hpp> | ||
#include <mlpack/methods/ann.hpp> | ||
|
||
#include "../../test_catch_tools.hpp" | ||
#include "../../catch.hpp" | ||
#include "../../serialization.hpp" | ||
#include "../ann_test_tools.hpp" | ||
|
||
using namespace mlpack; | ||
|
||
/** | ||
* Simple HardTanH module test | ||
*/ | ||
|
||
TEST_CASE("SimpleHardTanHTest", "[ANNLayerTest]") | ||
{ | ||
arma::mat output, gy, g; | ||
arma::mat input = {{-1.3743, -0.5565, 0.2742, -0.0151, -1.4871}, | ||
{1.5797, -4.2711, -2.2505, -1.7105, -1.2544}, | ||
{0.4023, 0.5676, 2.3100, 1.6658, -0.1907}, | ||
{0.1897, 0.9097, 0.1418, -1.5349, 0.1225}, | ||
{-0.1101, -3.3656, -5.4033, -2.2240, -3.3235}}; | ||
arma::mat actualOutput = {{-1.0000, -0.5565, 0.2742, -0.0151, -1.0000}, | ||
{1.0000, -1.0000, -1.0000, -1.0000, -1.0000}, | ||
{0.4023, 0.5676, 1.0000, 1.0000, -0.1907}, | ||
{0.1897, 0.9097, 0.1418, -1.0000, 0.1225}, | ||
{-0.1101, -1.0000, -1.0000, -1.0000, -1.0000}}; | ||
|
||
HardTanH module; | ||
|
||
output.set_size(5,5); | ||
// Test the Forward function | ||
module.Forward(input, output); | ||
REQUIRE(arma::accu(output - actualOutput) == Approx(0).epsilon(1e-4)); | ||
|
||
arma::mat delta = {{0 , 1.0, 1.0, 1.0, 0.0}, | ||
{0 , 0 , 0 , 0.0, 0.0}, | ||
{1.0, 1.0, 0 , 0.0, 1.0}, | ||
{1.0, 1.0, 1.0, 0.0, 1.0}, | ||
{1.0, 0 , 0.0, 0.0, 0.0}}; | ||
|
||
gy.set_size(5,5); | ||
gy.fill(1); | ||
g.set_size(5,5); | ||
|
||
//Test the Backward function | ||
module.Backward(output, gy, g); | ||
REQUIRE(arma::accu(g - delta) == Approx(0).epsilon(1e-4)); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters