Skip to content

Commit

Permalink
AnyParameterProperties refactor (#4412)
Browse files Browse the repository at this point in the history
* combined parameter flags in a single mask
* added mask_attribute to keep track of parameter availabilities
* updated constructors and getters
* added some documentation
* changed properties default
* typesafe bitmasking
* refactored code to use enum class and bitmask operators
  • Loading branch information
Gil authored and karlnapf committed Nov 15, 2018
1 parent f6efd89 commit b45243c
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 20 deletions.
86 changes: 73 additions & 13 deletions src/shogun/base/AnyParameter.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Heiko Strathmann, Gil Hoben
*/

#ifndef __ANYPARAMETER_H__
#define __ANYPARAMETER_H__

#include <shogun/lib/any.h>
#include <shogun/lib/bitmask_operators.h>

#include <string>

Expand All @@ -22,48 +29,101 @@ namespace shogun
GRADIENT_AVAILABLE = 1
};

/** parameter properties */
enum class ParameterProperties
{
HYPER = 1u << 0,
GRADIENT = 1u << 1,
MODEL = 1u << 2
};

enableEnumClassBitmask(ParameterProperties);

/** @brief Class AnyParameterProperties keeps track of of parameter meta
* information, such as properties and descriptions The parameter properties
* can be either true or false. These properties describe if a parameter is
* for example a hyperparameter or if it has a gradient.
*/
class AnyParameterProperties
{
public:
/** Default constructor where all parameter properties are false
*/
AnyParameterProperties()
: m_description(), m_model_selection(MS_NOT_AVAILABLE),
m_gradient(GRADIENT_NOT_AVAILABLE)
: m_description("No description given"),
m_attribute_mask(ParameterProperties())
{
}
/** Constructor
* @param description parameter description
* @param hyperparameter set to true for parameters that determine
* how training is performed, e.g. regularisation parameters
* @param gradient set to true for parameters required for gradient
* updates
* @param model set to true for parameters used in inference, e.g.
* weights and bias
* */
AnyParameterProperties(
std::string description,
EModelSelectionAvailability model_selection = MS_NOT_AVAILABLE,
EGradientAvailability gradient = GRADIENT_NOT_AVAILABLE)
: m_description(description), m_model_selection(model_selection),
EModelSelectionAvailability hyperparameter = MS_NOT_AVAILABLE,
EGradientAvailability gradient = GRADIENT_NOT_AVAILABLE,
bool model = false)
: m_description(description), m_model_selection(hyperparameter),
m_gradient(gradient)
{
m_attribute_mask = ParameterProperties();
if (hyperparameter)
m_attribute_mask |= ParameterProperties::HYPER;
if (gradient)
m_attribute_mask |= ParameterProperties::GRADIENT;
if (model)
m_attribute_mask |= ParameterProperties::MODEL;
}
/** Mask constructor
* @param description parameter description
* @param attribute_mask mask encoding parameter properties
* */
AnyParameterProperties(
std::string description, ParameterProperties attribute_mask)
: m_description(description)
{
m_attribute_mask = attribute_mask;
}
/** Copy contructor */
AnyParameterProperties(const AnyParameterProperties& other)
: m_description(other.m_description),
m_model_selection(other.m_model_selection),
m_gradient(other.m_gradient)
m_gradient(other.m_gradient),
m_attribute_mask(other.m_attribute_mask)
{
}

std::string get_description() const
const std::string& get_description() const
{
return m_description;
}

EModelSelectionAvailability get_model_selection() const
{
return m_model_selection;
return static_cast<EModelSelectionAvailability>(
static_cast<int32_t>(
m_attribute_mask & ParameterProperties::HYPER) > 0);
}

EGradientAvailability get_gradient() const
{
return m_gradient;
return static_cast<EGradientAvailability>(
static_cast<int32_t>(
m_attribute_mask & ParameterProperties::GRADIENT) > 0);
}
bool get_model() const
{
return static_cast<bool>(
m_attribute_mask & ParameterProperties::MODEL);
}

private:
std::string m_description;
EModelSelectionAvailability m_model_selection;
EGradientAvailability m_gradient;
ParameterProperties m_attribute_mask;
};

class AnyParameter
Expand Down Expand Up @@ -116,6 +176,6 @@ namespace shogun
Any m_value;
AnyParameterProperties m_properties;
};
}
} // namespace shogun

#endif
14 changes: 7 additions & 7 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,7 @@ class CSGObject
template <typename T>
void watch_param(
const std::string& name, T* value,
AnyParameterProperties properties = AnyParameterProperties(
"Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE))
AnyParameterProperties properties = AnyParameterProperties())
{
BaseTag tag(name);
create_parameter(tag, AnyParameter(make_any_ref(value), properties));
Expand All @@ -693,8 +692,7 @@ class CSGObject
template <typename T, typename S>
void watch_param(
const std::string& name, T** value, S* len,
AnyParameterProperties properties = AnyParameterProperties(
"Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE))
AnyParameterProperties properties = AnyParameterProperties())
{
BaseTag tag(name);
create_parameter(
Expand All @@ -714,8 +712,7 @@ class CSGObject
template <typename T, typename S>
void watch_param(
const std::string& name, T** value, S* rows, S* cols,
AnyParameterProperties properties = AnyParameterProperties(
"Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE))
AnyParameterProperties properties = AnyParameterProperties())
{
BaseTag tag(name);
create_parameter(
Expand All @@ -733,7 +730,10 @@ class CSGObject
{
BaseTag tag(name);
AnyParameterProperties properties(
"Dynamic parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE);
"Dynamic parameter",
ParameterProperties::HYPER |
ParameterProperties::GRADIENT |
ParameterProperties::MODEL);
std::function<T()> bind_method =
std::bind(method, dynamic_cast<const S*>(this));
create_parameter(tag, AnyParameter(make_any(bind_method), properties));
Expand Down
110 changes: 110 additions & 0 deletions src/shogun/lib/bitmask_operators.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#ifndef JSS_BITMASK_HPP
#define JSS_BITMASK_HPP

// (C) Copyright 2015 Just Software Solutions Ltd
//
// Distributed under the Boost Software License, Version 1.0.
//
// Boost Software License - Version 1.0 - August 17th, 2003
//
// Permission is hereby granted, free of charge, to any person or
// organization obtaining a copy of the software and accompanying
// documentation covered by this license (the "Software") to use,
// reproduce, display, distribute, execute, and transmit the
// Software, and to prepare derivative works of the Software, and
// to permit third-parties to whom the Software is furnished to
// do so, all subject to the following:
//
// The copyright notices in the Software and this entire
// statement, including the above license grant, this restriction
// and the following disclaimer, must be included in all copies
// of the Software, in whole or in part, and all derivative works
// of the Software, unless such copies or derivative works are
// solely in the form of machine-executable object code generated
// by a source language processor.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY
// KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
// PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE
// LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN
// CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include<type_traits>

namespace shogun {

template<typename E>
struct enable_bitmask_operators {
static constexpr bool enable = false;
};

#define enableEnumClassBitmask(T) template<> \
struct enable_bitmask_operators<T> \
{ \
static constexpr bool enable = true; \
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
operator|(E lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
return static_cast<E>(
static_cast<underlying>(lhs) | static_cast<underlying>(rhs));
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
operator&(E lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
return static_cast<E>(
static_cast<underlying>(lhs) & static_cast<underlying>(rhs));
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
operator^(E lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
return static_cast<E>(
static_cast<underlying>(lhs) ^ static_cast<underlying>(rhs));
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
operator~(E lhs) {
typedef typename std::underlying_type<E>::type underlying;
return static_cast<E>(
~static_cast<underlying>(lhs));
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E &>::type
operator|=(E &lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
lhs = static_cast<E>(
static_cast<underlying>(lhs) | static_cast<underlying>(rhs));
return lhs;
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E &>::type
operator&=(E &lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
lhs = static_cast<E>(
static_cast<underlying>(lhs) & static_cast<underlying>(rhs));
return lhs;
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E &>::type
operator^=(E &lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
lhs = static_cast<E>(
static_cast<underlying>(lhs) ^ static_cast<underlying>(rhs));
return lhs;
}
}
#endif

0 comments on commit b45243c

Please sign in to comment.