Jack Gerrits edited this page Aug 10, 2022 · 9 revisions

⚠️ This is an experimental feature

Model merging takes several compatible VW models and merges them into a single model that approximately represents all of the models combined. This will likely never be as effective as a single model trained on all of the data sequentially. However, when training on all of the data sequentially is not feasible, the speedup from parallel computation can make a merged model that sees all of the data more effective than a single model trained on only a subset of the data.

When using model merging, it is important to pass --preserve_performance_counters when loading the models to be merged. However, when loading a merged model, the counters need to be reset before continuing to train on it. This can be done by writing and then reading the model without the --preserve_performance_counters option.
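As a sketch of the counter-reset round trip (file names here are hypothetical; -i and -f are VW's standard flags for loading and saving a model):

```shell
# Load the merged model and immediately write it back out WITHOUT
# --preserve_performance_counters; this round trip resets the counters.
vw --no_stdin -i merged.model -f merged_reset.model

# The reset model can then be used for further training:
vw -i merged_reset.model further_data.txt -f updated.model
```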

Availability

The API is exposed in multiple places:

API

The general shape of this API should be consistent across the several locations.

This API accepts a list of VW models, loaded as workspaces, in workspaces_to_merge, and returns a unique pointer to a VW::workspace which is the merged result.

There are two modes in which this API can be used:

  1. The models to be merged were trained from scratch
  2. The models to be merged were all trained from some common base model

In case one, base_workspace should be nullptr. In case two, base_workspace should be the common base model. This ensures that the differences of each model from the common base can be taken into account.

If logger is passed, it is used as the logger for the duration of the function, and it is also set as the logger for the produced merged model.

std::unique_ptr<VW::workspace> merge_models(const VW::workspace* base_workspace,
    const std::vector<const VW::workspace*>& workspaces_to_merge, VW::io::logger* logger = nullptr);

Details

Generally speaking, merging is a weighted average of all given models based on relative amount of data processed. Values which act as counters are accumulated instead of averaged.

In the case of the GD reduction, when save_resume is in use, the adaptive values are used to compute a per-parameter weighted average. For all other averaged values in a model, the number of examples seen by each model determines its weight in the average.

If a reduction defines a save_load function, this implies that the reduction has training state which is persisted. Therefore, a rule of thumb is that if a reduction defines save_load it must also define merge. A warning is emitted if any reduction in the stack has a save_load but no merge, and an error is emitted if the base reduction in the stack has no merge, since merging definitely cannot work in that case.

Internal signatures

The signature of the merge function depends on whether the reduction is a base reduction. Ideally, all merge functions would use the non-base reduction signature, but since base learners use the weights and other state in VW::workspace this is not currently feasible.

using ReductionDataT = void; // ...

// Base reduction
using merge_with_all_fn = void (*)(const std::vector<float>& per_model_weighting, const VW::workspace& base_workspace,
          const std::vector<const VW::workspace*>& all_workspaces, const ReductionDataT& base_data,
          const std::vector<ReductionDataT*>& all_data, VW::workspace& output_workspace, ReductionDataT& output_data)

// Non-base reduction
using merge_fn = void (*)(const std::vector<float>& per_model_weighting, const ReductionDataT& base_data,
          const std::vector<const ReductionDataT*>& all_data, ReductionDataT& output_data)

This is then set on the respective learner builder during construction.

The following is then exposed on the learner object:

void merge(const std::vector<float>& per_model_weighting, const VW::workspace& base_workspace,
    const std::vector<const VW::workspace*>& all_workspaces, const base_learner* base_workspaces_learner,
    const std::vector<const base_learner*>& all_learners, VW::workspace& output_workspace,
    base_learner* output_learner)