Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AnyParameterProperties refactor #4412

Merged
merged 7 commits into from
Nov 15, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 18 additions & 13 deletions src/shogun/base/AnyParameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,11 @@ namespace shogun
class AnyParameterProperties
{
public:
static const int32_t HYPER = 1u << 0;
static const int32_t GRADIENT = 1u << 1;
static const int32_t MODEL = 1u << 2;

/** Default constructor where all parameter properties are false
*/
AnyParameterProperties()
: m_description("No description given"), m_attribute_mask(0)
: m_description("No description given"),
m_attribute_mask(ParameterProperties())
{
}
/** Constructor
Expand All @@ -74,15 +71,20 @@ namespace shogun
: m_description(description), m_model_selection(hyperparameter),
m_gradient(gradient)
{
m_attribute_mask = (hyperparameter << 0 & HYPER) |
(gradient << 1 & GRADIENT) |
(model << 2 & MODEL);
m_attribute_mask = ParameterProperties();
if (hyperparameter)
m_attribute_mask |= ParameterProperties::HYPER;
if (gradient)
m_attribute_mask |= ParameterProperties::GRADIENT;
if (model)
m_attribute_mask |= ParameterProperties::MODEL;
}
/** Mask constructor
* @param description parameter description
* @param attribute_mask mask encoding parameter properties
* */
AnyParameterProperties(std::string description, int32_t attribute_mask)
AnyParameterProperties(
std::string description, ParameterProperties attribute_mask)
: m_description(description)
{
m_attribute_mask = attribute_mask;
Expand All @@ -102,23 +104,26 @@ namespace shogun
EModelSelectionAvailability get_model_selection() const
{
return static_cast<EModelSelectionAvailability>(
(m_attribute_mask & HYPER) > 0);
static_cast<int32_t>(
m_attribute_mask & ParameterProperties::HYPER) > 0);
}
EGradientAvailability get_gradient() const
{
return static_cast<EGradientAvailability>(
(m_attribute_mask & GRADIENT) > 0);
static_cast<int32_t>(
m_attribute_mask & ParameterProperties::GRADIENT) > 0);
}
bool get_model() const
{
return static_cast<bool>(m_attribute_mask & MODEL);
return static_cast<bool>(
m_attribute_mask & ParameterProperties::MODEL);
}

private:
std::string m_description;
EModelSelectionAvailability m_model_selection;
EGradientAvailability m_gradient;
int32_t m_attribute_mask;
ParameterProperties m_attribute_mask;
};

class AnyParameter
Expand Down
12 changes: 4 additions & 8 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,11 +675,7 @@ class CSGObject
template <typename T>
void watch_param(
const std::string& name, T* value,
AnyParameterProperties properties = AnyParameterProperties(
"Unknown parameter",
AnyParameterProperties::HYPER |
AnyParameterProperties::GRADIENT |
AnyParameterProperties::MODEL))
AnyParameterProperties properties = AnyParameterProperties())
{
BaseTag tag(name);
create_parameter(tag, AnyParameter(make_any_ref(value), properties));
Expand Down Expand Up @@ -735,9 +731,9 @@ class CSGObject
BaseTag tag(name);
AnyParameterProperties properties(
"Dynamic parameter",
!AnyParameterProperties::HYPER |
!AnyParameterProperties::GRADIENT |
!AnyParameterProperties::MODEL);
ParameterProperties::HYPER |
ParameterProperties::GRADIENT |
ParameterProperties::MODEL);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the whitespaces seem weird here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, and should change it to default to false instead of all true.

std::function<T()> bind_method =
std::bind(method, dynamic_cast<const S*>(this));
create_parameter(tag, AnyParameter(make_any(bind_method), properties));
Expand Down