Skip to content

Commit

Permalink
fix: VW Slim (#4674)
Browse files Browse the repository at this point in the history
* handle feature scale correctly in VW slim

* fix namespace copy bug

* update tests to work with current VW

* update VW slim CI job

* fix typo in script name

* fix CI errors

* Update build_vw_slim.yml

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Co-authored-by: Jacob Alber <jacob.alber@microsoft.com>
Co-authored-by: olgavrou <olgavrou@gmail.com>
Co-authored-by: Griffin Bassman <griffinbassman@gmail.com>
  • Loading branch information
5 people committed May 9, 2024
1 parent 4ef9bfc commit 5c77d72
Show file tree
Hide file tree
Showing 36 changed files with 356 additions and 505 deletions.
39 changes: 37 additions & 2 deletions .github/workflows/build_vw_slim.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,45 @@ jobs:
- uses: actions/checkout@v1
with:
submodules: recursive
- name: Build VW Slim

- name: Install dependencies
shell: bash
run: |
sudo apt update
sudo apt install -y xxd
- name: Configure VW Slim
shell: bash
run: ./.scripts/linux/build-slim.sh
run: |
rm -rf build
cmake -S . -B build -G Ninja \
-DBUILD_TESTING=On \
-DVW_FEAT_FLATBUFFERS=Off \
-DRAPIDJSON_SYS_DEP=Off \
-DFMT_SYS_DEP=Off \
-DSPDLOG_SYS_DEP=Off \
-DVW_ZLIB_SYS_DEP=Off \
-DVW_BOOST_MATH_SYS_DEP=Off
- name: Build VW and VW Slim
shell: bash
run: cmake --build build --target vw_cli_bin vw_slim vw_slim_test

- name: Test VW Slim
shell: bash
working-directory: build
run: ctest --output-on-failure --no-tests=error --tests-regex "VowpalWabbitSlim|ExploreSlim|CommandLineOptionsSlim" --parallel 2

- name: Generate test data with new VW executable
shell: bash
working-directory: vowpalwabbit/slim/test/data/
run: ./generate-data.sh ../../../../build/vowpalwabbit/cli/vw

- name: Build VW Slim again
shell: bash
run: cmake --build build --target vw_slim vw_slim_test

- name: Test VW Slim again
shell: bash
working-directory: build
run: ctest --output-on-failure --no-tests=error --tests-regex "VowpalWabbitSlim|ExploreSlim|CommandLineOptionsSlim" --parallel 2
19 changes: 0 additions & 19 deletions .scripts/linux/build-slim.sh

This file was deleted.

48 changes: 26 additions & 22 deletions vowpalwabbit/slim/include/vw/slim/vw_slim_predict.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,31 +44,32 @@ class namespace_copy_guard
VW::example_predict& _ex;
unsigned char _ns;
bool _remove_ns;
size_t _old_size;
};

class feature_offset_guard
{
public:
feature_offset_guard(VW::example_predict& ex, uint64_t ft_offset);
feature_offset_guard(VW::example_predict& ex, uint64_t ft_index_offset);
~feature_offset_guard();

private:
VW::example_predict& _ex;
uint64_t _old_ft_offset;
uint64_t _old_ft_index_offset;
};

class stride_shift_guard
class feature_scale_guard
{
public:
stride_shift_guard(VW::example_predict& ex, uint64_t shift);
~stride_shift_guard();
feature_scale_guard(VW::example_predict& ex, uint64_t ft_index_scale);
~feature_scale_guard();

private:
VW::example_predict& _ex;
uint64_t _shift;
uint64_t _ft_index_scale;
};

/**
/*
* @brief Vowpal Wabbit slim predictor. Supports: regression, multi-class classification and contextual bandits.
*/
template <typename W>
Expand Down Expand Up @@ -152,7 +153,7 @@ class vw_predict
}

// TODO: take --cb_type dr into account
uint64_t num_weights = 0;
uint64_t feature_scale = 0;

if (_command_line_arguments.find("--cb_explore_adf") != std::string::npos)
{
Expand All @@ -164,7 +165,7 @@ class vw_predict
_bag_size = static_cast<size_t>(bag_size);

_exploration = vw_predict_exploration::bag;
num_weights = _bag_size;
feature_scale = _bag_size;

// check for additional minimum epsilon greedy
_minimum_epsilon = 0.f;
Expand Down Expand Up @@ -212,10 +213,10 @@ class vw_predict
RETURN_ON_FAIL(mp.read("resume", gd_resume));
if (gd_resume) { return E_VW_PREDICT_ERR_GD_RESUME_NOT_SUPPORTED; }

// read sparse weights into dense
_stride_shift = (uint32_t)ceil_log_2(num_weights);
_feature_scale_bits = (uint32_t)ceil_log_2(feature_scale);

RETURN_ON_FAIL(mp.read_weights<W>(_weights, _num_bits, _stride_shift));
// stride shift always 0 bits
RETURN_ON_FAIL(mp.read_weights<W>(_weights, _num_bits, 0));

// TODO: check that permutations is not enabled (or parse it)

Expand Down Expand Up @@ -261,7 +262,7 @@ class vw_predict
// add constant feature
ns_copy_guard =
std::unique_ptr<namespace_copy_guard>(new namespace_copy_guard(ex, VW::details::CONSTANT_NAMESPACE));
ns_copy_guard->feature_push_back(1.f, (VW::details::CONSTANT << _stride_shift) + ex.ft_offset);
ns_copy_guard->feature_push_back(1.f, (VW::details::CONSTANT << _feature_scale_bits) + ex.ft_offset);
}

if (_contains_wildcard)
Expand Down Expand Up @@ -291,9 +292,9 @@ class vw_predict

out_scores.resize(num_actions);

VW::example_predict* action = actions;
for (size_t i = 0; i < num_actions; i++, action++)
for (size_t i = 0; i < num_actions; i++)
{
VW::example_predict* action = &actions[i];
std::vector<std::unique_ptr<namespace_copy_guard>> ns_copy_guards;

// shared feature copying
Expand Down Expand Up @@ -358,19 +359,21 @@ class vw_predict
{
std::vector<uint32_t> top_actions(num_actions);

// apply stride shifts
std::vector<std::unique_ptr<stride_shift_guard>> stride_shift_guards;
stride_shift_guards.push_back(
std::unique_ptr<stride_shift_guard>(new stride_shift_guard(shared, _stride_shift)));
// apply feature scale
uint64_t feature_scale = static_cast<uint64_t>(1) << _feature_scale_bits;
std::vector<std::unique_ptr<feature_scale_guard>> feature_scale_guards;
feature_scale_guards.push_back(
std::unique_ptr<feature_scale_guard>(new feature_scale_guard(shared, feature_scale)));
VW::example_predict* actions_end = actions + num_actions;
for (VW::example_predict* action = actions; action != actions_end; ++action)
{
stride_shift_guards.push_back(
std::unique_ptr<stride_shift_guard>(new stride_shift_guard(*action, _stride_shift)));
feature_scale_guards.push_back(
std::unique_ptr<feature_scale_guard>(new feature_scale_guard(*action, feature_scale)));
}

for (size_t i = 0; i < _bag_size; i++)
{
// apply feature offset
std::vector<std::unique_ptr<feature_offset_guard>> feature_offset_guards;
for (VW::example_predict* action = actions; action != actions_end; ++action)
{
Expand Down Expand Up @@ -487,7 +490,8 @@ class vw_predict
size_t _bag_size;
uint32_t _num_bits;

uint32_t _stride_shift;
// log2 of feature scale, rounded upwards to next integer
uint32_t _feature_scale_bits;
bool _model_loaded;
};
} // namespace vw_slim
36 changes: 23 additions & 13 deletions vowpalwabbit/slim/src/vw_slim_predict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,57 @@ namespace_copy_guard::namespace_copy_guard(VW::example_predict& ex, unsigned cha
{
_ex.indices.push_back(_ns);
_remove_ns = true;
_old_size = 0;
}
else
{
_remove_ns = false;
_old_size = _ex.feature_space[_ns].size();
}
else { _remove_ns = false; }
}

namespace_copy_guard::~namespace_copy_guard()
{
_ex.indices.pop_back();
if (_remove_ns) { _ex.feature_space[_ns].clear(); }
if (_remove_ns)
{
_ex.feature_space[_ns].clear();
_ex.indices.pop_back();
}
else { _ex.feature_space[_ns].truncate_to(_old_size); }
}

// Appends a single feature (value `v`, index `idx`) to the feature space of
// namespace `_ns` on the example this guard refers to.
void namespace_copy_guard::feature_push_back(VW::feature_value v, VW::feature_index idx)
{
  auto& target_features = _ex.feature_space[_ns];
  target_features.push_back(v, idx);
}

feature_offset_guard::feature_offset_guard(VW::example_predict& ex, uint64_t ft_offset)
: _ex(ex), _old_ft_offset(ex.ft_offset)
feature_offset_guard::feature_offset_guard(VW::example_predict& ex, uint64_t ft_index_offset)
: _ex(ex), _old_ft_index_offset(ex.ft_offset)
{
_ex.ft_offset = ft_offset;
_ex.ft_offset = ft_index_offset;
}

feature_offset_guard::~feature_offset_guard() { _ex.ft_offset = _old_ft_offset; }
feature_offset_guard::~feature_offset_guard() { _ex.ft_offset = _old_ft_index_offset; }

stride_shift_guard::stride_shift_guard(VW::example_predict& ex, uint64_t shift) : _ex(ex), _shift(shift)
feature_scale_guard::feature_scale_guard(VW::example_predict& ex, uint64_t ft_index_scale)
: _ex(ex), _ft_index_scale(ft_index_scale)
{
if (_shift > 0)
if (_ft_index_scale > 1)
{
for (auto ns : _ex.indices)
{
for (auto& f : _ex.feature_space[ns]) { f.index() <<= _shift; }
for (auto& f : _ex.feature_space[ns]) { f.index() *= _ft_index_scale; }
}
}
}

stride_shift_guard::~stride_shift_guard()
feature_scale_guard::~feature_scale_guard()
{
if (_shift > 0)
if (_ft_index_scale > 1)
{
for (auto ns : _ex.indices)
{
for (auto& f : _ex.feature_space[ns]) { f.index() >>= _shift; }
for (auto& f : _ex.feature_space[ns]) { f.index() /= _ft_index_scale; }
}
}
}
Expand Down

0 comments on commit 5c77d72

Please sign in to comment.