Skip to content

Commit

Permalink
Demo on how to address mulitple index contents
Browse files Browse the repository at this point in the history
Summary:
This demonstrates how to query several independent IVF indexes with a trained index in common. This avoids to duplicate the coarse quantizer and metadata in memory.

On the Faiss side, it also implements a InvertedListIterator on top of the flat inverted lists, which can prove useful.

Reviewed By: junjieqi

Differential Revision: D56575887

fbshipit-source-id: cc3b26e952ee21f24b10169b5b614066600cf4b8
  • Loading branch information
mdouze authored and facebook-github-bot committed Apr 26, 2024
1 parent 5cbff67 commit a233bc9
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 26 deletions.
72 changes: 55 additions & 17 deletions faiss/invlists/InvertedLists.cpp
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#include <faiss/invlists/InvertedLists.h>

#include <cstdio>
Expand All @@ -24,18 +22,10 @@ InvertedListsIterator::~InvertedListsIterator() {}
******************************************/

InvertedLists::InvertedLists(size_t nlist, size_t code_size)
: nlist(nlist), code_size(code_size), use_iterator(false) {}
: nlist(nlist), code_size(code_size) {}

InvertedLists::~InvertedLists() {}

bool InvertedLists::is_empty(size_t list_no, void* inverted_list_context)
const {
return use_iterator ? !std::unique_ptr<InvertedListsIterator>(
get_iterator(list_no, inverted_list_context))
->is_available()
: list_size(list_no) == 0;
}

idx_t InvertedLists::get_single_id(size_t list_no, size_t offset) const {
assert(offset < list_size(list_no));
const idx_t* ids = get_ids(list_no);
Expand Down Expand Up @@ -78,12 +68,6 @@ void InvertedLists::reset() {
}
}

InvertedListsIterator* InvertedLists::get_iterator(
size_t /*list_no*/,
void* /*inverted_list_context*/) const {
FAISS_THROW_MSG("get_iterator is not supported");
}

void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) {
#pragma omp parallel for
for (idx_t i = 0; i < nlist; i++) {
Expand Down Expand Up @@ -233,6 +217,54 @@ size_t InvertedLists::compute_ntotal() const {
return tot;
}

bool InvertedLists::is_empty(size_t list_no, void* inverted_list_context)
const {
if (use_iterator) {
return !std::unique_ptr<InvertedListsIterator>(
get_iterator(list_no, inverted_list_context))
->is_available();
} else {
FAISS_THROW_IF_NOT(inverted_list_context == nullptr);
return list_size(list_no) == 0;
}
}

// implemnent iterator on top of get_codes / get_ids
namespace {

struct CodeArrayIterator : InvertedListsIterator {
size_t list_size;
size_t code_size;
InvertedLists::ScopedCodes codes;
InvertedLists::ScopedIds ids;
size_t idx = 0;

CodeArrayIterator(const InvertedLists* il, size_t list_no)
: list_size(il->list_size(list_no)),
code_size(il->code_size),
codes(il, list_no),
ids(il, list_no) {}

bool is_available() const override {
return idx < list_size;
}
void next() override {
idx++;
}
std::pair<idx_t, const uint8_t*> get_id_and_codes() override {
return {ids[idx], codes.get() + code_size * idx};
}
};

} // namespace

InvertedListsIterator* InvertedLists::get_iterator(
size_t list_no,
void* inverted_list_context) const {
FAISS_THROW_IF_NOT(inverted_list_context == nullptr);
return new CodeArrayIterator(this, list_no);
}

/*****************************************
* ArrayInvertedLists implementation
******************************************/
Expand Down Expand Up @@ -264,6 +296,12 @@ size_t ArrayInvertedLists::list_size(size_t list_no) const {
return ids[list_no].size();
}

bool ArrayInvertedLists::is_empty(size_t list_no, void* inverted_list_context)
const {
FAISS_THROW_IF_NOT(inverted_list_context == nullptr);
return ids[list_no].size() == 0;
}

const uint8_t* ArrayInvertedLists::get_codes(size_t list_no) const {
assert(list_no < nlist);
return codes[list_no].data();
Expand Down
27 changes: 18 additions & 9 deletions faiss/invlists/InvertedLists.h
Expand Up @@ -37,7 +37,9 @@ struct InvertedListsIterator {
struct InvertedLists {
size_t nlist; ///< number of possible key values
size_t code_size; ///< code size per vector in bytes
bool use_iterator;

/// request to use iterator rather than get_codes / get_ids
bool use_iterator = false;

InvertedLists(size_t nlist, size_t code_size);

Expand All @@ -50,17 +52,9 @@ struct InvertedLists {
/*************************
* Read only functions */

// check if the list is empty
bool is_empty(size_t list_no, void* inverted_list_context) const;

/// get the size of a list
virtual size_t list_size(size_t list_no) const = 0;

/// get iterable for lists that use_iterator
virtual InvertedListsIterator* get_iterator(
size_t list_no,
void* inverted_list_context) const;

/** get the codes for an inverted list
* must be released by release_codes
*
Expand Down Expand Up @@ -92,6 +86,18 @@ struct InvertedLists {
/// a list can be -1 hence the signed long
virtual void prefetch_lists(const idx_t* list_nos, int nlist) const;

/*****************************************
* Iterator interface (with context) */

/// check if the list is empty
virtual bool is_empty(size_t list_no, void* inverted_list_context = nullptr)
const;

/// get iterable for lists that use_iterator
virtual InvertedListsIterator* get_iterator(
size_t list_no,
void* inverted_list_context = nullptr) const;

/*************************
* writing functions */

Expand Down Expand Up @@ -262,6 +268,9 @@ struct ArrayInvertedLists : InvertedLists {
/// permute the inverted lists, map maps new_id to old_id
void permute_invlists(const idx_t* map);

bool is_empty(size_t list_no, void* inverted_list_context = nullptr)
const override;

~ArrayInvertedLists() override;
};

Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Expand Up @@ -33,6 +33,7 @@ set(FAISS_TEST_SRC
test_partitioning.cpp
test_fastscan_perf.cpp
test_disable_pq_sdc_tables.cpp
test_common_ivf_empty_index.cpp
)

add_executable(faiss_test ${FAISS_TEST_SRC})
Expand Down
144 changes: 144 additions & 0 deletions tests/test_common_ivf_empty_index.cpp
@@ -0,0 +1,144 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <gtest/gtest.h>

#include <omp.h>
#include <cstddef>
#include <memory>
#include <vector>

#include <faiss/IndexIVF.h>
#include <faiss/clone_index.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/index_factory.h>
#include <faiss/invlists/InvertedLists.h>
#include <faiss/utils/random.h>

/* This demonstrates how to query several independent IVF indexes with a trained
*index in common. This avoids to duplicate the coarse quantizer and metadata
*in memory.
**/

namespace {

int d = 64;

}; // namespace

std::vector<float> get_random_vectors(size_t n, int seed) {
std::vector<float> x(n * d);
faiss::rand_smooth_vectors(n, d, x.data(), seed);
seed++;
return x;
}

/** InvetedLists implementation that dispatches the search to an InvertedList
* object that is passed in at query time */

struct DispatchingInvertedLists : faiss::ReadOnlyInvertedLists {
DispatchingInvertedLists(size_t nlist, size_t code_size)
: faiss::ReadOnlyInvertedLists(nlist, code_size) {
use_iterator = true;
}

faiss::InvertedListsIterator* get_iterator(
size_t list_no,
void* inverted_list_context = nullptr) const override {
assert(inverted_list_context);
auto il =
static_cast<const faiss::InvertedLists*>(inverted_list_context);
return il->get_iterator(list_no);
}

using idx_t = faiss::idx_t;

size_t list_size(size_t list_no) const override {
FAISS_THROW_MSG("use iterator interface");
}
const uint8_t* get_codes(size_t list_no) const override {
FAISS_THROW_MSG("use iterator interface");
}
const idx_t* get_ids(size_t list_no) const override {
FAISS_THROW_MSG("use iterator interface");
}
};

TEST(COMMON, test_common_trained_index) {
int N = 3; // number of independent indexes
int nt = 500; // training vectors
int nb = 200; // nb database vectors per index
int nq = 10; // nb queries performed on each index
int k = 4; // restults requested per query

// construct and build an "empty index": a trained index that does not
// itself hold any data
std::unique_ptr<faiss::IndexIVF> empty_index(dynamic_cast<faiss::IndexIVF*>(
faiss::index_factory(d, "IVF32,PQ8np")));
auto xt = get_random_vectors(nt, 123);
empty_index->train(nt, xt.data());
empty_index->nprobe = 4;

// reference run: build one index for each set of db / queries and record
// results
std::vector<std::vector<faiss::idx_t>> ref_I(N);

for (int i = 0; i < N; i++) {
// clone the empty index
std::unique_ptr<faiss::Index> index(
faiss::clone_index(empty_index.get()));
auto xb = get_random_vectors(nb, 1234 + i);
auto xq = get_random_vectors(nq, 12345 + i);
// add vectors and perform a search
index->add(nb, xb.data());
std::vector<float> D(k * nq);
std::vector<faiss::idx_t> I(k * nq);
index->search(nq, xq.data(), k, D.data(), I.data());
// record result as reference
ref_I[i] = I;
}

// build a set of inverted lists for each independent index
std::vector<faiss::ArrayInvertedLists> sub_invlists;

for (int i = 0; i < N; i++) {
// swap in other inverted lists
sub_invlists.emplace_back(empty_index->nlist, empty_index->code_size);
faiss::InvertedLists* invlists = &sub_invlists.back();

// replace_invlists swaps in a new InvertedLists for an existing index
empty_index->replace_invlists(invlists, false);
empty_index->reset(); // reset id counter to 0
// populate inverted lists
auto xb = get_random_vectors(nb, 1234 + i);
empty_index->add(nb, xb.data());
}

// perform search dispatching to the sub-invlists. At search time, we don't
// use replace_invlists because that would wreak havoc in a multithreaded
// context
DispatchingInvertedLists di(empty_index->nlist, empty_index->code_size);
empty_index->replace_invlists(&di, false);

std::vector<std::vector<faiss::idx_t>> new_I(N);

// run searches in the independent indexes but with a common empty_index
#pragma omp parallel for
for (int i = 0; i < N; i++) {
auto xq = get_random_vectors(nq, 12345 + i);
std::vector<float> D(k * nq);
std::vector<faiss::idx_t> I(k * nq);

// here we set to what sub-index the queries should be directed
faiss::SearchParametersIVF params;
params.nprobe = empty_index->nprobe;
params.inverted_list_context = &sub_invlists[i];

empty_index->search(nq, xq.data(), k, D.data(), I.data(), &params);
new_I[i] = I;
}

// compare with reference reslt
for (int i = 0; i < N; i++) {
ASSERT_EQ(ref_I[i], new_I[i]);
}
}

0 comments on commit a233bc9

Please sign in to comment.