Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

coot 17: Refactor MakeAlias to make it compatible with bandicoot #3693

Merged
merged 21 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 38 additions & 23 deletions src/mlpack/core/math/make_alias.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,52 +19,66 @@ namespace mlpack {
* Reconstruct `m` as an alias around the memory `newMem`, with size `numRows` x
* `numCols`.
*/
template<typename MatType>
void MakeAlias(MatType& m,
typename MatType::elem_type* newMem,
const size_t numRows,
const size_t numCols,
template<typename InVecType, typename OutVecType>
void MakeAlias(OutVecType& v,
const InVecType& oldVec,
const size_t numElems,
const size_t offset = 0,
const bool strict = true,
const typename std::enable_if_t<!IsCube<MatType>::value>* = 0)
const typename std::enable_if_t<IsVector<OutVecType>::value>* = 0)
{
// We use placement new to reinitialize the object, since the copy and move
// assignment operators in Armadillo will end up copying memory instead of
// making an alias.
m.~MatType();
new (&m) MatType(newMem, numRows, numCols, false, strict);
typename InVecType::elem_type* newMem =
const_cast<typename InVecType::elem_type*>(oldVec.memptr()) + offset;
v.~OutVecType();
new (&v) OutVecType(newMem, numElems, false, strict);
}

/**
* Reconstruct `c` as an alias around the memory` newMem`, with size `numRows` x
* `numCols` x `numSlices`.
* Reconstruct `m` as an alias around the memory `newMem`, with size `numRows` x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch :)

* `numCols`.
*/
template<typename CubeType>
void MakeAlias(CubeType& c,
typename CubeType::elem_type* newMem,
template<typename InMatType, typename OutMatType>
void MakeAlias(OutMatType& m,
const InMatType& oldMat,
const size_t numRows,
const size_t numCols,
const size_t numSlices,
const size_t offset = 0,
const bool strict = true,
const typename std::enable_if_t<IsCube<CubeType>::value>* = 0)
const typename std::enable_if_t<IsMatrix<OutMatType>::value>* = 0)
{
// We use placement new to reinitialize the object, since the copy and move
// assignment operators in Armadillo will end up copying memory instead of
// making an alias.
c.~CubeType();
new (&c) CubeType(newMem, numRows, numCols, numSlices, false, strict);
typename InMatType::elem_type* newMem =
const_cast<typename InMatType::elem_type*>(oldMat.memptr()) + offset;
m.~OutMatType();
new (&m) OutMatType(newMem, numRows, numCols, false, strict);
}

/**
* Make `m` an alias of `in`, using the given size.
* Reconstruct `c` as an alias around the memory` newMem`, with size `numRows` x
* `numCols` x `numSlices`.
*/
template<typename eT>
void MakeAlias(arma::Mat<eT>& m,
const arma::Mat<eT>& in,
template<typename InCubeType, typename OutCubeType>
void MakeAlias(OutCubeType& c,
const InCubeType& oldCube,
const size_t numRows,
const size_t numCols,
const bool strict = true)
const size_t numSlices,
const size_t offset = 0,
const bool strict = true,
const typename std::enable_if_t<IsCube<OutCubeType>::value>* = 0)
{
MakeAlias(m, (eT*) in.memptr(), numRows, numCols, strict);
// We use placement new to reinitialize the object, since the copy and move
// assignment operators in Armadillo will end up copying memory instead of
// making an alias.
typename InCubeType::elem_type* newMem =
const_cast<typename InCubeType::elem_type*>(oldCube.memptr()) + offset;
c.~OutCubeType();
new (&c) OutCubeType(newMem, numRows, numCols, numSlices, false, strict);
}

/**
Expand All @@ -75,6 +89,7 @@ void MakeAlias(arma::SpMat<eT>& m,
const arma::SpMat<eT>& in,
const size_t /* numRows */,
const size_t /* numCols */,
const size_t /* offset */,
const bool /* strict */)
{
// We can't make aliases of sparse objects, so just copy it.
Expand Down
27 changes: 15 additions & 12 deletions src/mlpack/methods/ann/ffn_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,12 @@ void FFN<
const size_t effectiveBatchSize = std::min(batchSize,
size_t(predictors.n_cols) - i);

const MatType predictorAlias(
const_cast<typename MatType::elem_type*>(predictors.colptr(i)),
predictors.n_rows, effectiveBatchSize, false, true);
MatType resultAlias(results.colptr(i), results.n_rows,
effectiveBatchSize, false, true);
MatType predictorAlias, resultAlias;

MakeAlias(predictorAlias, predictors, predictors.n_rows,
effectiveBatchSize, i * predictors.n_rows);
MakeAlias(resultAlias, results, results.n_rows, effectiveBatchSize,
i * results.n_rows);

network.Forward(predictorAlias, resultAlias);
}
Expand Down Expand Up @@ -449,8 +450,10 @@ typename MatType::elem_type FFN<
// pass.
networkOutput.set_size(network.OutputSize(), batchSize);
MatType predictorsBatch, responsesBatch;
MakeAlias(predictorsBatch, predictors.colptr(begin), predictors.n_rows, batchSize);
MakeAlias(responsesBatch, responses.colptr(begin), responses.n_rows, batchSize);
MakeAlias(predictorsBatch, predictors, predictors.n_rows, batchSize,
begin * predictors.n_rows);
MakeAlias(responsesBatch, responses, responses.n_rows, batchSize,
begin * responses.n_rows);
network.Forward(predictorsBatch, networkOutput);

return outputLayer.Forward(networkOutput, responsesBatch) + network.Loss();
Expand Down Expand Up @@ -497,10 +500,10 @@ typename MatType::elem_type FFN<

// Alias the batches so we don't copy memory.
MatType predictorsBatch, responsesBatch;
MakeAlias(predictorsBatch, predictors.colptr(begin), predictors.n_rows,
batchSize);
MakeAlias(responsesBatch, responses.colptr(begin), responses.n_rows,
batchSize);
MakeAlias(predictorsBatch, predictors, predictors.n_rows,
batchSize, begin * predictors.n_rows);
MakeAlias(responsesBatch, responses, responses.n_rows,
batchSize, begin * responses.n_rows);

network.Forward(predictorsBatch, networkOutput);

Expand Down Expand Up @@ -596,7 +599,7 @@ void FFN<
"FFN::SetLayerMemory(): total layer weight size does not match parameter "
"size!");

network.SetWeights(parameters.memptr());
network.SetWeights(parameters);
layerMemoryIsSet = true;
}

Expand Down
2 changes: 1 addition & 1 deletion src/mlpack/methods/ann/layer/add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class AddType : public Layer<MatType>
void ComputeOutputDimensions();

//! Set the weights of the layer to use the given memory.
void SetWeights(typename MatType::elem_type* weightPtr);
void SetWeights(const MatType& weightsIn);

/**
* Serialize the layer.
Expand Down
4 changes: 2 additions & 2 deletions src/mlpack/methods/ann/layer/add_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ void AddType<MatType>::Gradient(
}

template<typename MatType>
void AddType<MatType>::SetWeights(typename MatType::elem_type* weightPtr)
void AddType<MatType>::SetWeights(const MatType& weightsIn)
{
// Set the weights to wrap the given memory.
MakeAlias(weights, weightPtr, 1, outSize);
MakeAlias(weights, weightsIn, 1, outSize);
}

template<typename MatType>
Expand Down
10 changes: 3 additions & 7 deletions src/mlpack/methods/ann/layer/batch_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,15 @@ class BatchNormType : public Layer<MatType>
/**
* Reset the layer parameters.
*/
void SetWeights(typename MatType::elem_type* weightsPtr);
void SetWeights(const MatType& weightsIn);

/**
* Initialize the weight matrix of the layer.
*
* @param W Weight matrix to initialize.
* @param elements Number of elements.
*/
void CustomInitialize(
MatType& W,
const size_t elements);
void CustomInitialize(MatType& W, const size_t elements);

/**
* Forward pass of the Batch Normalization layer. Transforms the input data
Expand Down Expand Up @@ -160,9 +158,7 @@ class BatchNormType : public Layer<MatType>
* @param error The calculated error
* @param gradient The calculated gradient.
*/
void Gradient(const MatType& input,
const MatType& error,
MatType& gradient);
void Gradient(const MatType& input, const MatType& error, MatType& gradient);

//! Get the parameters.
const MatType& Parameters() const { return weights; }
Expand Down
13 changes: 6 additions & 7 deletions src/mlpack/methods/ann/layer/batch_norm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,13 @@ BatchNormType<MatType>::operator=(
}

template<typename MatType>
void BatchNormType<MatType>::SetWeights(
typename MatType::elem_type* weightsPtr)
void BatchNormType<MatType>::SetWeights(const MatType& weightsIn)
{
MakeAlias(weights, weightsPtr, WeightSize(), 1);
MakeAlias(weights, weightsIn, WeightSize(), 1);
// Gamma acts as the scaling parameters for the normalized output.
MakeAlias(gamma, weightsPtr, size, 1);
MakeAlias(gamma, weightsIn, size, 1);
// Beta acts as the shifting parameters for the normalized output.
MakeAlias(beta, weightsPtr + gamma.n_elem, size, 1);
MakeAlias(beta, weightsIn, size, 1, gamma.n_elem);
}

template<typename MatType>
Expand All @@ -170,9 +169,9 @@ void BatchNormType<MatType>::CustomInitialize(
MatType gammaTemp;
MatType betaTemp;
// Gamma acts as the scaling parameters for the normalized output.
MakeAlias(gammaTemp, W.memptr(), size, 1);
MakeAlias(gammaTemp, W, size, 1);
// Beta acts as the shifting parameters for the normalized output.
MakeAlias(betaTemp, W.memptr() + gammaTemp.n_elem, size, 1);
MakeAlias(betaTemp, W, size, 1, gammaTemp.n_elem);

gammaTemp.fill(1.0);
betaTemp.fill(0.0);
Expand Down
53 changes: 10 additions & 43 deletions src/mlpack/methods/ann/layer/concat_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,12 @@ void ConcatType<MatType>::Forward(const MatType& input, MatType& output)
this->layerOutputs.size());
for (size_t i = 0; i < this->layerOutputs.size(); ++i)
{
MakeAlias(layerOutputAliases[i],
(typename MatType::elem_type*) this->layerOutputs[i].memptr(),
rows,
this->network[i]->OutputDimensions()[axis],
slices);
MakeAlias(layerOutputAliases[i], this->layerOutputs[i], rows,
this->network[i]->OutputDimensions()[axis], slices);
}

arma::Cube<typename MatType::elem_type> outputAlias;
MakeAlias(outputAlias,
(typename MatType::elem_type*) output.memptr(),
rows,
this->outputDimensions[axis],
slices);
MakeAlias(outputAlias, output, rows, this->outputDimensions[axis], slices);

// Now get the columns from each output.
size_t startCol = 0;
Expand Down Expand Up @@ -171,11 +164,7 @@ void ConcatType<MatType>::Backward(
slices *= this->outputDimensions[i];

arma::Cube<typename MatType::elem_type> gyTmp;
MakeAlias(gyTmp,
(typename MatType::elem_type*) gy.memptr(),
rows,
this->outputDimensions[axis],
slices);
MakeAlias(gyTmp, gy, rows, this->outputDimensions[axis], slices);
rcurtin marked this conversation as resolved.
Show resolved Hide resolved

size_t startCol = 0;
for (size_t i = 0; i < this->network.size(); ++i)
Expand Down Expand Up @@ -221,11 +210,7 @@ void ConcatType<MatType>::Backward(
slices *= this->outputDimensions[i];

arma::Cube<typename MatType::elem_type> gyTmp;
MakeAlias(gyTmp,
(typename MatType::elem_type*) gy.memptr(),
rows,
this->outputDimensions[axis],
slices);
MakeAlias(gyTmp, gy, rows, this->outputDimensions[axis], slices);

size_t startCol = 0;
for (size_t i = 0; i < index; ++i)
Expand All @@ -238,11 +223,7 @@ void ConcatType<MatType>::Backward(
// Reshape so that the batch size is the number of columns.
delta.reshape(delta.n_elem / gy.n_cols, gy.n_cols);

this->network[index]->Backward(
input,
this->layerOutputs[index],
delta,
g);
this->network[index]->Backward(input, this->layerOutputs[index], delta, g);
}

template<typename MatType>
Expand All @@ -263,11 +244,7 @@ void ConcatType<MatType>::Gradient(
slices *= this->outputDimensions[i];

arma::Cube<typename MatType::elem_type> errorTmp;
MakeAlias(errorTmp,
(typename MatType::elem_type*) error.memptr(),
rows,
this->outputDimensions[axis],
slices);
MakeAlias(errorTmp, error, rows, this->outputDimensions[axis], slices);

size_t startCol = 0;
size_t startParam = 0;
Expand All @@ -279,10 +256,7 @@ void ConcatType<MatType>::Gradient(
MatType err = errorTmp.cols(startCol, startCol + cols - 1);
err.reshape(err.n_elem / input.n_cols, input.n_cols);
MatType gradientAlias;
MakeAlias(gradientAlias,
(typename MatType::elem_type*) gradient.memptr() + startParam,
params,
1);
MakeAlias(gradientAlias, gradient, params, 1, startParam);
this->network[i]->Gradient(input, err, gradientAlias);

startCol += cols;
Expand All @@ -309,11 +283,7 @@ void ConcatType<MatType>::Gradient(
slices *= this->outputDimensions[i];

arma::Cube<typename MatType::elem_type> errorTmp;
MakeAlias(errorTmp,
(typename MatType::elem_type*) error.memptr(),
rows,
this->outputDimensions[axis],
slices);
MakeAlias(errorTmp, error, rows, this->outputDimensions[axis], slices);

size_t startCol = 0;
size_t startParam = 0;
Expand All @@ -329,10 +299,7 @@ void ConcatType<MatType>::Gradient(
MatType err = errorTmp.cols(startCol, startCol + cols - 1);
err.reshape(err.n_elem / input.n_cols, input.n_cols);
MatType gradientAlias;
MakeAlias(gradientAlias,
(typename MatType::elem_type*) gradient.memptr() + startParam,
params,
1);
MakeAlias(gradientAlias, gradient, params, 1, startParam);
this->network[index]->Gradient(input, err, gradientAlias);
}

Expand Down
2 changes: 1 addition & 1 deletion src/mlpack/methods/ann/layer/convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class ConvolutionType : public Layer<MatType>
/*
* Set the weight and bias term.
*/
void SetWeights(typename MatType::elem_type* weightsPtr);
void SetWeights(const MatType& weightsIn);

/**
* Ordinary feed forward pass of a neural network, evaluating the function
Expand Down