Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AnyParameterProperties refactor #4412

Merged
merged 7 commits into from
Nov 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
86 changes: 73 additions & 13 deletions src/shogun/base/AnyParameter.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Heiko Strathmann, Gil Hoben
*/

#ifndef __ANYPARAMETER_H__
#define __ANYPARAMETER_H__

#include <shogun/lib/any.h>
#include <shogun/lib/bitmask_operators.h>

#include <string>

Expand All @@ -22,48 +29,101 @@ namespace shogun
GRADIENT_AVAILABLE = 1
};

/** parameter properties */
enum class ParameterProperties
{
HYPER = 1u << 0,
GRADIENT = 1u << 1,
MODEL = 1u << 2
};

enableEnumClassBitmask(ParameterProperties);

/** @brief Class AnyParameterProperties keeps track of of parameter meta
* information, such as properties and descriptions The parameter properties
* can be either true or false. These properties describe if a parameter is
* for example a hyperparameter or if it has a gradient.
*/
class AnyParameterProperties
{
public:
/** Default constructor where all parameter properties are false
*/
AnyParameterProperties()
: m_description(), m_model_selection(MS_NOT_AVAILABLE),
m_gradient(GRADIENT_NOT_AVAILABLE)
: m_description("No description given"),
m_attribute_mask(ParameterProperties())
{
}
/** Constructor
* @param description parameter description
* @param hyperparameter set to true for parameters that determine
* how training is performed, e.g. regularisation parameters
* @param gradient set to true for parameters required for gradient
* updates
* @param model set to true for parameters used in inference, e.g.
* weights and bias
* */
AnyParameterProperties(
std::string description,
EModelSelectionAvailability model_selection = MS_NOT_AVAILABLE,
EGradientAvailability gradient = GRADIENT_NOT_AVAILABLE)
: m_description(description), m_model_selection(model_selection),
EModelSelectionAvailability hyperparameter = MS_NOT_AVAILABLE,
EGradientAvailability gradient = GRADIENT_NOT_AVAILABLE,
bool model = false)
: m_description(description), m_model_selection(hyperparameter),
m_gradient(gradient)
{
m_attribute_mask = ParameterProperties();
if (hyperparameter)
m_attribute_mask |= ParameterProperties::HYPER;
if (gradient)
m_attribute_mask |= ParameterProperties::GRADIENT;
if (model)
m_attribute_mask |= ParameterProperties::MODEL;
}
/** Mask constructor
* @param description parameter description
* @param attribute_mask mask encoding parameter properties
* */
AnyParameterProperties(
std::string description, ParameterProperties attribute_mask)
: m_description(description)
{
m_attribute_mask = attribute_mask;
}
/** Copy contructor */
AnyParameterProperties(const AnyParameterProperties& other)
: m_description(other.m_description),
m_model_selection(other.m_model_selection),
m_gradient(other.m_gradient)
m_gradient(other.m_gradient),
m_attribute_mask(other.m_attribute_mask)
{
}

std::string get_description() const
const std::string& get_description() const
{
return m_description;
}

EModelSelectionAvailability get_model_selection() const
{
return m_model_selection;
return static_cast<EModelSelectionAvailability>(
static_cast<int32_t>(
m_attribute_mask & ParameterProperties::HYPER) > 0);
}

EGradientAvailability get_gradient() const
{
return m_gradient;
return static_cast<EGradientAvailability>(
static_cast<int32_t>(
m_attribute_mask & ParameterProperties::GRADIENT) > 0);
}
bool get_model() const
{
return static_cast<bool>(
m_attribute_mask & ParameterProperties::MODEL);
}

private:
std::string m_description;
EModelSelectionAvailability m_model_selection;
EGradientAvailability m_gradient;
ParameterProperties m_attribute_mask;
};

class AnyParameter
Expand Down Expand Up @@ -116,6 +176,6 @@ namespace shogun
Any m_value;
AnyParameterProperties m_properties;
};
}
} // namespace shogun

#endif
14 changes: 7 additions & 7 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,7 @@ class CSGObject
template <typename T>
void watch_param(
const std::string& name, T* value,
AnyParameterProperties properties = AnyParameterProperties(
"Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE))
AnyParameterProperties properties = AnyParameterProperties())
{
BaseTag tag(name);
create_parameter(tag, AnyParameter(make_any_ref(value), properties));
Expand All @@ -693,8 +692,7 @@ class CSGObject
template <typename T, typename S>
void watch_param(
const std::string& name, T** value, S* len,
AnyParameterProperties properties = AnyParameterProperties(
"Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE))
AnyParameterProperties properties = AnyParameterProperties())
{
BaseTag tag(name);
create_parameter(
Expand All @@ -714,8 +712,7 @@ class CSGObject
template <typename T, typename S>
void watch_param(
const std::string& name, T** value, S* rows, S* cols,
AnyParameterProperties properties = AnyParameterProperties(
"Unknown parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE))
AnyParameterProperties properties = AnyParameterProperties())
{
BaseTag tag(name);
create_parameter(
Expand All @@ -733,7 +730,10 @@ class CSGObject
{
BaseTag tag(name);
AnyParameterProperties properties(
"Dynamic parameter", MS_NOT_AVAILABLE, GRADIENT_NOT_AVAILABLE);
"Dynamic parameter",
ParameterProperties::HYPER |
ParameterProperties::GRADIENT |
ParameterProperties::MODEL);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the whitespaces seem weird here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, and should change it to default to false instead of all true.

std::function<T()> bind_method =
std::bind(method, dynamic_cast<const S*>(this));
create_parameter(tag, AnyParameter(make_any(bind_method), properties));
Expand Down
110 changes: 110 additions & 0 deletions src/shogun/lib/bitmask_operators.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#ifndef JSS_BITMASK_HPP
#define JSS_BITMASK_HPP

// (C) Copyright 2015 Just Software Solutions Ltd
//
// Distributed under the Boost Software License, Version 1.0.
//
// Boost Software License - Version 1.0 - August 17th, 2003
//
// Permission is hereby granted, free of charge, to any person or
// organization obtaining a copy of the software and accompanying
// documentation covered by this license (the "Software") to use,
// reproduce, display, distribute, execute, and transmit the
// Software, and to prepare derivative works of the Software, and
// to permit third-parties to whom the Software is furnished to
// do so, all subject to the following:
//
// The copyright notices in the Software and this entire
// statement, including the above license grant, this restriction
// and the following disclaimer, must be included in all copies
// of the Software, in whole or in part, and all derivative works
// of the Software, unless such copies or derivative works are
// solely in the form of machine-executable object code generated
// by a source language processor.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY
// KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
// PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE
// LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN
// CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could put "modified by G.F" here if you want

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok! I wasn't sure how it works with the boost software license!


#include<type_traits>

namespace shogun {

template<typename E>
struct enable_bitmask_operators {
static constexpr bool enable = false;
};

#define enableEnumClassBitmask(T) template<> \
struct enable_bitmask_operators<T> \
{ \
static constexpr bool enable = true; \
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
operator|(E lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
return static_cast<E>(
static_cast<underlying>(lhs) | static_cast<underlying>(rhs));
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
operator&(E lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
return static_cast<E>(
static_cast<underlying>(lhs) & static_cast<underlying>(rhs));
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
operator^(E lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
return static_cast<E>(
static_cast<underlying>(lhs) ^ static_cast<underlying>(rhs));
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E>::type
operator~(E lhs) {
typedef typename std::underlying_type<E>::type underlying;
return static_cast<E>(
~static_cast<underlying>(lhs));
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E &>::type
operator|=(E &lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
lhs = static_cast<E>(
static_cast<underlying>(lhs) | static_cast<underlying>(rhs));
return lhs;
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E &>::type
operator&=(E &lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
lhs = static_cast<E>(
static_cast<underlying>(lhs) & static_cast<underlying>(rhs));
return lhs;
}

template<typename E>
typename std::enable_if<enable_bitmask_operators<E>::enable, E &>::type
operator^=(E &lhs, E rhs) {
typedef typename std::underlying_type<E>::type underlying;
lhs = static_cast<E>(
static_cast<underlying>(lhs) ^ static_cast<underlying>(rhs));
return lhs;
}
}
#endif