VW works by transforming a problem from one type to another through a set of reductions. This allows us to leverage hardened solutions to existing problems to solve new problems. In VW, this is implemented using a reduction stack, which is a chain of learner objects. Each learner represents one distinct learning algorithm.

Conceptually, there are two categories of learners: reduction learners and bottom learners. A reduction learner requires at least one learner below it in the stack. We call that learner its base, and the reduction will recursively call into it, using the base learner's output to compute its own result. A bottom learner does not require any further learners below it, and it directly returns a result. The bottom of the reduction stack must be a bottom learner, and all other learners in the stack must be reduction learners.

This document describes how a learner is implemented in the codebase. Learners are defined by a set of types, functions, and a couple of data fields. Learners are created using learner builder objects, which are templated to enforce type consistency. After a learner is created, the learner object itself is fully type-erased, so all learners are the exact same C++ type: class learner.

Types

A learner has several important data types:

  • DataT - The type of the data object of this learner. Each learner has its own data object to store internal state.
  • ExampleT - The type of example this reduction expects. Either example or multi_ex.
  • label_type_t - Used for two label types: the label type this learner expects incoming examples to carry (its input label type), and the label type it attaches to examples it passes down to its base (its output label type)
  • prediction_type_t - Used for two prediction types: the prediction type this learner expects its base to produce (its input prediction type), and the prediction type this learner itself produces (its output prediction type)

Note that DataT and ExampleT are template parameters for learner builders. They are type-erased in the learner builder so that the resulting learner object does not reference them. However, label_type_t and prediction_type_t are enums, and are used for data fields in the learner object.
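
For illustration, a hypothetical reduction's fully-typed pieces might look like the following sketch. All names here are invented for the example, not taken from the codebase.

#include <cstdint>
#include <vector>

// Hypothetical DataT: per-learner state, owned by the learner object.
struct my_reduction_data
{
  uint32_t num_classes = 1;       // illustrative internal state
  std::vector<float> statistics;  // illustrative internal state
};

// Hypothetical ExampleT choice: this reduction handles one example at a time,
// so its typed functions take VW::example. A reduction that consumes groups of
// lines (such as contextual bandit ADF) would use VW::multi_ex instead.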

Input and output types

The types label_type_t and prediction_type_t are used to define a contract between a reduction learner and its base. These properties must be satisfied in order for the reduction stack to work.

  • The output label type of a reduction should match the input label type of its base
  • The input prediction type of a reduction should match the output prediction type of its base

More details, including special cases for bottom learners, are provided on this page: Matching Label and Prediction Types Between Reductions

Data fields

This is an overview of important fields in the learner class.

  • std::string _name - Human-readable name for the learner
  • size_t weights - The number of weight vectors required by this learner. In effect, a single learner can reference several distinct models (see the sketch after this list)
  • size_t increment - Used together with the per-call offset index to address those different weight vectors
  • bool is_multiline - true if the expected ExampleT is multi_ex, otherwise false and ExampleT is example
  • Input and output prediction and label types:
    • prediction_type_t _output_pred_type
    • prediction_type_t _input_pred_type
    • label_type_t _output_label_type
    • label_type_t _input_label_type
  • std::shared_ptr<void> _learner_data - The data object for this learner. Note that here it has been type-erased from DataT to void.
  • std::shared_ptr<learner> _base_learner - The base of this learner. It points to the learner object immediately below this one in the reduction stack.
    • As a shared_ptr, this gives each reduction ownership of its base learner. Multiple learners are allowed to share the same base, but this is very uncommon.
    • For bottom learners, this will be nullptr because there is no learner below it.
    • Note that the reduction stack cannot be traversed from bottom to top. You can only go from top to bottom.
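
As a sketch of how weights and increment fit together, consider a hypothetical one-against-all style reduction that keeps one weight vector per class. Passing an index to the base's predict shifts the example's feature offset by that multiple of increment, so each index addresses a different model. All names are illustrative, and label handling is elided for brevity (the label/prediction contract described above still applies).

#include <limits>

// Hypothetical predict for a reduction with one weight vector per class.
void predict(oaa_like_data& data, VW::LEARNER::single_learner& base, VW::example& ec)
{
  uint32_t best = 1;
  float best_score = std::numeric_limits<float>::lowest();
  for (uint32_t i = 0; i < data.num_classes; i++)
  {
    base.predict(ec, i);  // i selects the i-th weight vector via increment
    if (ec.pred.scalar > best_score)
    {
      best_score = ec.pred.scalar;
      best = i + 1;  // multiclass predictions are 1-based
    }
  }
  ec.pred.multiclass = best;
}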

Functions

The logic of a learner is implemented by several std::function objects. For the overwhelming majority of reductions, only learn, predict, and finish_example are important.

Overview

Functions can be assigned to a learner only via learner builders. The learner builder takes function pointers to fully-typed functions (with DataT and ExampleT), and type-erases them. This is done by binding some arguments via lambda-capture so that the resulting function can be stored in the same generic std::function type for all learners.
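
The erasure mechanism itself is ordinary C++. The following stand-alone sketch (not the actual VW builder code) shows the idea: a fully-typed function pointer is captured in a lambda, and the lambda is stored in a generic std::function.

#include <functional>
#include <memory>

struct erased_learner
{
  std::shared_ptr<void> data;        // DataT erased to void
  std::function<void(void*)> learn;  // ExampleT erased to void*
};

template <typename DataT, typename ExampleT>
erased_learner erase(std::shared_ptr<DataT> data, void (*learn)(DataT&, ExampleT&))
{
  erased_learner result;
  result.data = data;
  // The lambda captures the typed pieces; callers only ever see void*.
  result.learn = [data, learn](void* ex) { learn(*data, *static_cast<ExampleT*>(ex)); };
  return result;
}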

Some functions of a learner are auto-recursive. Auto-recursion means that the corresponding function of each learner in the stack is invoked in sequence automatically, without any individual function having to explicitly call its base learner's function. This is handled by the implementation of the learner class itself.

Not all functions are auto-recursive. Those that are not must explicitly call the corresponding function of their base learner when needed.

Details for important functions in a learner are provided in the following sections. Note that all function signatures given here are those before type-erasure. When implementing a new learner, the functions you write should have the signatures given below, with DataT and ExampleT replaced by your specific data and example types. You will provide a function pointer to the learner builder, which expects fully-typed function pointers as inputs and stores type-erased std::function objects into the learner.
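
Putting this together, a reduction's setup function typically constructs its learner through the builder. The following is a sketch modeled on the VW 9.x builder API; exact names and signatures may differ between versions, and my_data, learn, and predict are hypothetical.

#include "vw/core/learner.h"
#include "vw/core/setup_base.h"

VW::LEARNER::base_learner* my_reduction_setup(VW::setup_base_i& stack_builder)
{
  // Build the rest of the stack below this learner first.
  auto* base = VW::LEARNER::as_singleline(stack_builder.setup_base_learner());
  auto data = VW::make_unique<my_data>();

  auto* l = VW::LEARNER::make_reduction_learner(std::move(data), base, learn, predict,
                stack_builder.get_setupfn_name(my_reduction_setup))
                .set_input_label_type(VW::label_type_t::multiclass)
                .set_output_label_type(VW::label_type_t::simple)
                .set_input_prediction_type(VW::prediction_type_t::scalar)
                .set_output_prediction_type(VW::prediction_type_t::multiclass)
                .build();
  return VW::LEARNER::make_base(*l);
}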

Init

void(DataT* data);

This is called once by the driver when it starts up. It does not auto-recurse; the definition in the top-most learner is used.

Learn/Predict/Update

void(DataT* data, BaseT* base_learner, ExampleT* example);

These three functions are perhaps the most important. They define the core learning process. update is not commonly used and by default it simply refers to learn.

These functions will not auto-recurse. However, in nearly all cases, you will want to use the result of the base learner. Thus, you are responsible for implementing a call to the appropriate function in the base learner.

Each example passed to these functions implicitly has a valid label object (of this learner's input label type) associated with it. Additionally, when ExampleT == VW::example there is an allocated, empty prediction object on the example, and when ExampleT == VW::multi_ex there is an allocated, empty prediction object on the zeroth example.

When the base learner is called, any examples passed to it MUST adhere to the contract described previously. This is a critical requirement that, if broken, causes serious and sometimes hard-to-find bugs. Your implementations of these functions are responsible for satisfying the contract.
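
As an example of satisfying the contract, here is a sketch of a learn function for a hypothetical reduction that consumes multiclass labels while its base expects simple labels. my_data and target_class are illustrative.

void learn(my_data& data, VW::LEARNER::single_learner& base, VW::example& ec)
{
  // Save this learner's own (input) label before handing the example down.
  const auto saved_label = ec.l.multi;

  // Contract: the example must carry the base's input label type when we call it.
  ec.l.simple.label = (saved_label.label == data.target_class) ? 1.f : -1.f;

  // learn does not auto-recurse, so we call the base explicitly.
  base.learn(ec);

  // Restore the original label before returning.
  ec.l.multi = saved_label;
}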

Multipredict

void(DataT* data, BaseT& base, ExampleT* ex, size_t count, size_t step, polyprediction* pred, bool finalize_predictions);

Multipredict makes several predictions using one example. Each successive prediction uses a larger offset, so it effectively uses a different weight vector per prediction. This is often used internally by reductions but rarely called externally.

Multipredict does not need to be defined. By default, the learner implementation automatically falls back to predict if multipredict is undefined; the fallback behaves roughly like the sketch after the list below.

  • pred is an array of count polyprediction objects.
  • step is the weight offset increment applied per prediction.
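
This sketch shows roughly what the fallback does; it is not the actual implementation and assumes scalar predictions.

void multipredict_fallback(VW::LEARNER::single_learner& base, VW::example& ec,
    size_t count, size_t step, VW::polyprediction* pred)
{
  for (size_t i = 0; i < count; i++)
  {
    base.predict(ec);                 // predict at the current offset
    pred[i].scalar = ec.pred.scalar;  // copy the i-th prediction out
    ec.ft_offset += step;             // advance to the next weight vector
  }
  ec.ft_offset -= count * step;  // restore the original offset
}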

Sensitivity

float(DataT* data, BaseT* base, ExampleT* example);

Computes how sensitive the model's prediction for this example is to an update; this is used, for example, by the active learning reduction. Does not auto-recurse.

Finish Example

void(vw& all, DataT* data, ExampleT* ex);

Finish example is called after learn/predict. It is where the reduction must calculate and report loss, as well as free any resources that were allocated for that example. Additionally, the example's label and prediction must be returned to a clean slate.

Does not auto-recurse.
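
A sketch of a typical finish_example for a simple-label, scalar-prediction reduction follows. The shared_data::update call reflects the usual reporting pattern, but verify the exact fields against your VW version; my_data is hypothetical.

#include <cfloat>

void finish_example(vw& all, my_data& /* data */, VW::example& ec)
{
  // Report this example's loss so the driver can print progressive statistics.
  const bool labeled = ec.l.simple.label != FLT_MAX;
  all.sd->update(ec.test_only, labeled, ec.loss, ec.weight, ec.get_num_features());

  // Hand the example back to VW, which prints output and recycles the example.
  VW::finish_example(all, ec);
}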

End Pass

void(DataT* data);

Called at the end of a learning pass. This function is auto-recursive.

End Examples

void(DataT* data);

Called once all of the examples have finished being parsed and processed by the reduction stack. This function is auto-recursive.

Finish

void(DataT* data);

Called as the learner is being destroyed. Note that if the reduction data type DataT has a destructor, it is called automatically, so this function is often unnecessary. This function is auto-recursive.

Save Load

void(DataT* data, io_buf* model_buffer, bool read, bool text);

This is how a reduction implements serialization and deserialization of its state in a model file. This function is auto-recursive.

  • read is true if a model file is being read, and false if the expectation is to write to the buffer.
  • text means that a human-readable model should be written instead of binary.
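
For illustration, here is a sketch of a save_load that serializes a single field in both directions. The helper bin_text_read_write_fixed follows the pattern used across the codebase, but treat its exact name and signature as an assumption and check io_buf.h (or the newer model_utils.h helpers) in your VW version; my_data and my_field are hypothetical.

#include <sstream>

void save_load(my_data& data, io_buf& model_buffer, bool read, bool text)
{
  if (model_buffer.num_files() == 0) { return; }  // no model file attached

  std::stringstream msg;
  msg << "my_field = " << data.my_field << "\n";
  // When read == true this fills data.my_field from the buffer; otherwise it
  // writes the field out, as readable text if text == true, else as binary.
  bin_text_read_write_fixed(model_buffer, reinterpret_cast<char*>(&data.my_field),
      sizeof(data.my_field), "", read, msg, text);
}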