
Enabling meta learning in Shogun

Gil edited this page Dec 3, 2018 · 10 revisions

This page logs the work that has been done, and the work still planned, to enable meta learning in Shogun (a collaboration with the ATI).

First part:

Update the Shogun codebase to use a single method to register parameters that can then be observed using the Reactive model.

  • Refactor AnyParameter.h to have a clean differentiation between model parameters (weights, bias,...), hyperparameters (k in KMeans, regularisation parameters,...) and gradient parameters (#4412).
  • Refactor SGObject to add all parameters with the new API (#4417)
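To make the three parameter categories concrete, here is a minimal sketch of how a parameter could be tagged with a kind at registration time and later selected by that kind. Everything in it (ParamKind, has_kind, ParamRegistry, names_of) is a hypothetical name made up for illustration; it is not the API from #4412 or #4417, only one plausible shape for it.

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical property flags; the names are illustrative, not Shogun's API.
enum class ParamKind : std::uint8_t
{
    NONE = 0,
    MODEL = 1 << 0,    // e.g. weights, bias
    HYPER = 1 << 1,    // e.g. k in KMeans, regularisation parameters
    GRADIENT = 1 << 2  // parameters that take part in gradient computations
};

// Combine flags so a parameter can belong to more than one category.
constexpr ParamKind operator|(ParamKind a, ParamKind b)
{
    return static_cast<ParamKind>(
        static_cast<std::uint8_t>(a) | static_cast<std::uint8_t>(b));
}

// True if `flags` contains at least one bit of `wanted`.
constexpr bool has_kind(ParamKind flags, ParamKind wanted)
{
    return (static_cast<std::uint8_t>(flags) &
            static_cast<std::uint8_t>(wanted)) != 0;
}

// Toy registry standing in for the per-object parameter map (assumption).
class ParamRegistry
{
public:
    // Single entry point: every parameter is registered with its kind.
    void add(const std::string& name, ParamKind kind)
    {
        m_params[name] = kind;
    }

    // Names of all parameters whose flags include `wanted`,
    // returned in the map's sorted key order.
    std::vector<std::string> names_of(ParamKind wanted) const
    {
        std::vector<std::string> result;
        for (const auto& entry : m_params)
            if (has_kind(entry.second, wanted))
                result.push_back(entry.first);
        return result;
    }

private:
    std::map<std::string, ParamKind> m_params;
};
```

With a registry like this, an observer could ask for names_of(ParamKind::MODEL) to write out only the weights and bias, or names_of(ParamKind::HYPER) to track hyperparameters. Using bit flags rather than a plain enum is a design choice of this sketch: it lets one parameter be both a hyperparameter and a gradient parameter.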

Second part:

Now that we have a clean interface to add parameters, we can start working on the parameter registration.

  • Observe parameters using SGObject. This will allow the user to filter specific parameters, e.g. write out just the model parameters, such as weights, or keep track of hyperparameters during a grid search.

  • Write a custom function to cast Any to its original type. This will allow us to simplify the API for writing out Any with a logger (#4426).

    • Currently we use macros to do this casting. This works well, but it is verbose, and macros can be difficult to debug.
    • Instead of macros we will use lambdas, relying on two C++14 features: generic lambdas and init-captures (used here to capture variables by reference). All the messy details can now be hidden away!

Macro version:

#define CHECK_TYPE(type)                                                       \
	else if (                                                              \
	    value.first.get_value().type_info().hash_code() ==                 \
	    typeid(type).hash_code())                                          \
	{                                                                      \
		summaryValue->set_simple_value(                                \
		    any_cast<type>(value.first.get_value()));                  \
	}
	if (value.first.get_value().type_info().hash_code() ==
	    typeid(int8_t).hash_code())
	{
		summaryValue->set_simple_value(
		    any_cast<int8_t>(value.first.get_value()));
	}
	CHECK_TYPE(uint8_t)
	CHECK_TYPE(int16_t)
	CHECK_TYPE(uint16_t)

[and so on...]

	CHECK_TYPE(char)
	else
	{
		SG_ERROR(
		    "Unsupported type %s", value.first.get_value().type_info().name());
	}

Lambda version:

auto write_summary = [&summaryValue=summaryValue, &value](auto type) {
    summaryValue->set_simple_value(
        any_cast<decltype(type)>(value.first.get_value()));
};
sg_for_each_type(value.first.get_value(), sg_all_types, write_summary);
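For readers who want to see what such a helper might do under the hood, below is a self-contained C++14 sketch of the same idea: walk a compile-time list of candidate types, find the one stored in a type-erased value, and call a generic lambda with a value of that type as a tag. ToyAny, TypeList, for_each_type and to_double are hypothetical stand-ins written for this page; they are not Shogun's Any, sg_all_types or sg_for_each_type, whose actual implementation may differ.

```cpp
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <typeindex>
#include <typeinfo>
#include <utility>

// Minimal stand-in for a type-erased value (NOT Shogun's Any).
class ToyAny
{
public:
    template <typename T>
    explicit ToyAny(T value)
        : m_type(typeid(T)), m_data(std::make_shared<T>(std::move(value)))
    {
    }

    // True if the stored value has exactly type T.
    template <typename T>
    bool holds() const
    {
        return m_type == std::type_index(typeid(T));
    }

    // Checked access to the stored value; throws on a type mismatch.
    template <typename T>
    const T& as() const
    {
        if (!holds<T>())
            throw std::bad_cast();
        return *static_cast<const T*>(m_data.get());
    }

private:
    std::type_index m_type;
    std::shared_ptr<const void> m_data;
};

// Compile-time list of candidate types.
template <typename... Ts>
struct TypeList
{
};

// Base case: no candidate left, so the stored type is unsupported.
template <typename F>
bool for_each_type(const ToyAny&, TypeList<>, F&&)
{
    return false;
}

// Recursive case: test the first candidate, then recurse on the rest.
template <typename F, typename T, typename... Rest>
bool for_each_type(const ToyAny& any, TypeList<T, Rest...>, F&& f)
{
    if (any.holds<T>())
    {
        f(T{}); // a value-initialised T, passed purely as a type tag
        return true;
    }
    return for_each_type(any, TypeList<Rest...>{}, std::forward<F>(f));
}

// Example use: convert any supported numeric payload to double.
double to_double(const ToyAny& value)
{
    double result = 0.0;
    // Generic lambda: decltype(type) recovers the matched type.
    auto convert = [&result, &value](auto type) {
        result = static_cast<double>(value.as<decltype(type)>());
    };
    using Supported = TypeList<std::int8_t, std::int32_t, float, double>;
    if (!for_each_type(value, Supported{}, convert))
        throw std::invalid_argument("unsupported type");
    return result;
}
```

In this sketch the list of supported types lives in one place (the Supported alias), so adding a type means extending one list rather than adding another macro invocation. The sketch compares std::type_index values instead of hash codes, since equal hash codes alone do not guarantee equal types.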