Skip to content

Commit

Permalink
Implement reconstruct_n for GPU IVFFlat indexes (#3338)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3338

add reconstruct_n for GPU IVFFlat

Reviewed By: mdouze

Differential Revision: D55577561

fbshipit-source-id: 47f8b939611e2df7dbcd087129538145f627293c
  • Loading branch information
junjieqi authored and facebook-github-bot committed Apr 5, 2024
1 parent da9f292 commit cfc7fe5
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 3 deletions.
22 changes: 22 additions & 0 deletions faiss/gpu/GpuIndexIVFFlat.cu
Expand Up @@ -356,5 +356,27 @@ void GpuIndexIVFFlat::setIndex_(
}
}

void GpuIndexIVFFlat::reconstruct_n(idx_t i0, idx_t ni, float* out) const {
FAISS_ASSERT(index_);

if (ni == 0) {
// nothing to do
return;
}

FAISS_THROW_IF_NOT_FMT(
i0 < this->ntotal,
"start index (%zu) out of bounds (ntotal %zu)",
i0,
this->ntotal);
FAISS_THROW_IF_NOT_FMT(
i0 + ni - 1 < this->ntotal,
"max index requested (%zu) out of bounds (ntotal %zu)",
i0 + ni - 1,
this->ntotal);

index_->reconstruct_n(i0, ni, out);
}

} // namespace gpu
} // namespace faiss
2 changes: 2 additions & 0 deletions faiss/gpu/GpuIndexIVFFlat.h
Expand Up @@ -87,6 +87,8 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
/// Trains the coarse quantizer based on the given vector data
void train(idx_t n, const float* x) override;

void reconstruct_n(idx_t i0, idx_t n, float* out) const override;

protected:
/// Initialize appropriate index
void setIndex_(
Expand Down
4 changes: 4 additions & 0 deletions faiss/gpu/impl/IVFBase.cu
Expand Up @@ -340,6 +340,10 @@ void IVFBase::copyInvertedListsTo(InvertedLists* ivf) {
}
}

void IVFBase::reconstruct_n(idx_t i0, idx_t n, float* out) {
FAISS_THROW_MSG("not implemented");
}

void IVFBase::addEncodedVectorsToList_(
idx_t listId,
const void* codes,
Expand Down
13 changes: 11 additions & 2 deletions faiss/gpu/impl/IVFBase.cuh
Expand Up @@ -109,9 +109,18 @@ class IVFBase {
Tensor<idx_t, 2, true>& outIndices,
bool storePairs) = 0;

/* It is used to reconstruct a given number of vectors in an Inverted File
* (IVF) index
* @param i0 index of the first vector to reconstruct
* @param n number of vectors to reconstruct
* @param out This is a pointer to a buffer where the reconstructed
* vectors will be stored.
*/
virtual void reconstruct_n(idx_t i0, idx_t n, float* out);

protected:
/// Adds a set of codes and indices to a list, with the representation
/// coming from the CPU equivalent
/// Adds a set of codes and indices to a list, with the
/// representation coming from the CPU equivalent
virtual void addEncodedVectorsToList_(
idx_t listId,
// resident on the host
Expand Down
47 changes: 47 additions & 0 deletions faiss/gpu/impl/IVFFlat.cu
Expand Up @@ -283,6 +283,53 @@ void IVFFlat::searchPreassigned(
storePairs);
}

void IVFFlat::reconstruct_n(idx_t i0, idx_t ni, float* out) {
if (ni == 0) {
// nothing to do
return;
}

auto stream = resources_->getDefaultStreamCurrentDevice();

for (idx_t list_no = 0; list_no < numLists_; list_no++) {
size_t list_size = deviceListData_[list_no]->numVecs;

auto idlist = getListIndices(list_no);

for (idx_t offset = 0; offset < list_size; offset++) {
idx_t id = idlist[offset];
if (!(id >= i0 && id < i0 + ni)) {
continue;
}

// vector data in the non-interleaved format is laid out like:
// v0d0 v0d1 ... v0d(dim-1) v1d0 v1d1 ... v1d(dim-1)

// vector data in the interleaved format is laid out like:
// (v0d0 v1d0 ... v31d0) (v0d1 v1d1 ... v31d1)
// (v0d(dim - 1) ... v31d(dim-1))
// (v32d0 v33d0 ... v63d0) (... v63d(dim-1)) (v64d0 ...)

// where vectors are chunked into groups of 32, and each dimension
// for each of the 32 vectors is contiguous

auto vectorChunk = offset / 32;
auto vectorWithinChunk = offset % 32;

auto listDataPtr = (float*)deviceListData_[list_no]->data.data();
listDataPtr += vectorChunk * 32 * dim_ + vectorWithinChunk;

for (int d = 0; d < dim_; ++d) {
fromDevice<float>(
listDataPtr + 32 * d,
out + (id - i0) * dim_ + d,
1,
stream);
}
}
}
}

void IVFFlat::searchImpl_(
Tensor<float, 2, true>& queries,
Tensor<float, 2, true>& coarseDistances,
Expand Down
2 changes: 2 additions & 0 deletions faiss/gpu/impl/IVFFlat.cuh
Expand Up @@ -51,6 +51,8 @@ class IVFFlat : public IVFBase {
Tensor<idx_t, 2, true>& outIndices,
bool storePairs) override;

void reconstruct_n(idx_t i0, idx_t n, float* out) override;

protected:
/// Returns the number of bytes in which an IVF list containing numVecs
/// vectors is encoded on the device. Note that due to padding this is not
Expand Down
65 changes: 65 additions & 0 deletions faiss/gpu/test/TestGpuIndexIVFFlat.cpp
Expand Up @@ -842,6 +842,71 @@ TEST(TestGpuIndexIVFFlat, LongIVFList) {
#endif
}

TEST(TestGpuIndexIVFFlat, Reconstruct_n) {
Options opt;

std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);

faiss::IndexFlatL2 cpuQuantizer(opt.dim);
faiss::IndexIVFFlat cpuIndex(
&cpuQuantizer, opt.dim, opt.numCentroids, faiss::METRIC_L2);
cpuIndex.nprobe = opt.nprobe;
cpuIndex.train(opt.numTrain, trainVecs.data());
cpuIndex.add(opt.numAdd, addVecs.data());

faiss::gpu::StandardGpuResources res;
res.noTempMemory();

faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = opt.device;
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
config.use_raft = false;

faiss::gpu::GpuIndexIVFFlat gpuIndex(
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
gpuIndex.nprobe = opt.nprobe;

gpuIndex.train(opt.numTrain, trainVecs.data());
gpuIndex.add(opt.numAdd, addVecs.data());

std::vector<float> gpuVals(opt.numAdd * opt.dim);

gpuIndex.reconstruct_n(0, gpuIndex.ntotal, gpuVals.data());

std::vector<float> cpuVals(opt.numAdd * opt.dim);

cpuIndex.reconstruct_n(0, cpuIndex.ntotal, cpuVals.data());

EXPECT_EQ(gpuVals, cpuVals);

config.indicesOptions = faiss::gpu::INDICES_32_BIT;

faiss::gpu::GpuIndexIVFFlat gpuIndex1(
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
gpuIndex1.nprobe = opt.nprobe;

gpuIndex1.train(opt.numTrain, trainVecs.data());
gpuIndex1.add(opt.numAdd, addVecs.data());

gpuIndex1.reconstruct_n(0, gpuIndex1.ntotal, gpuVals.data());

EXPECT_EQ(gpuVals, cpuVals);

config.indicesOptions = faiss::gpu::INDICES_CPU;

faiss::gpu::GpuIndexIVFFlat gpuIndex2(
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
gpuIndex2.nprobe = opt.nprobe;

gpuIndex2.train(opt.numTrain, trainVecs.data());
gpuIndex2.add(opt.numAdd, addVecs.data());

gpuIndex2.reconstruct_n(0, gpuIndex2.ntotal, gpuVals.data());

EXPECT_EQ(gpuVals, cpuVals);
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);

Expand Down
1 change: 1 addition & 0 deletions faiss/gpu/test/test_gpu_basics.py
Expand Up @@ -11,6 +11,7 @@
import random
from common_faiss_tests import get_dataset_2


class ReferencedObject(unittest.TestCase):

d = 16
Expand Down
25 changes: 25 additions & 0 deletions faiss/gpu/test/test_gpu_index_ivfflat.py
@@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import faiss
import numpy as np


class TestGpuIndexIvfflat(unittest.TestCase):
def test_reconstruct_n(self):
index = faiss.index_factory(4, "IVF10,Flat")
x = np.random.RandomState(123).rand(10, 4).astype('float32')
index.train(x)
index.add(x)
res = faiss.StandardGpuResources()
res.noTempMemory()
config = faiss.GpuIndexIVFFlatConfig()
config.use_raft = False
index2 = faiss.GpuIndexIVFFlat(res, index, config)
recons = index2.reconstruct_n(0, 10)

np.testing.assert_array_equal(recons, x)
36 changes: 35 additions & 1 deletion faiss/gpu/test/torch_test_contrib_gpu.py
Expand Up @@ -108,7 +108,7 @@ def test_train_add_with_ids(self):
self.assertTrue(np.array_equal(I.reshape(10), ids_np[10:20]))

# tests reconstruct, reconstruct_n
def test_reconstruct(self):
def test_flat_reconstruct(self):
d = 32
res = faiss.StandardGpuResources()
res.noTempMemory()
Expand Down Expand Up @@ -157,6 +157,40 @@ def test_reconstruct(self):
index.reconstruct_n(50, 10, y)
self.assertTrue(torch.equal(xb[50:60], y))

def test_ivfflat_reconstruct(self):
d = 32
nlist = 5
res = faiss.StandardGpuResources()
res.noTempMemory()
config = faiss.GpuIndexIVFFlatConfig()
config.use_raft = False

index = faiss.GpuIndexIVFFlat(res, d, nlist, faiss.METRIC_L2, config)

xb = torch.rand(100, d, device=torch.device('cuda', 0), dtype=torch.float32)
index.train(xb)
index.add(xb)

# Test reconstruct_n with torch gpu (native return)
y = index.reconstruct_n(10, 10)
self.assertTrue(y.is_cuda)
self.assertTrue(torch.equal(xb[10:20], y))

# Test reconstruct with numpy output provided
y = np.empty((10, d), dtype='float32')
index.reconstruct_n(20, 10, y)
self.assertTrue(np.array_equal(xb.cpu().numpy()[20:30], y))

# Test reconstruct_n with torch cpu output provided
y = torch.empty(10, d, dtype=torch.float32)
index.reconstruct_n(40, 10, y)
self.assertTrue(torch.equal(xb[40:50].cpu(), y))

# Test reconstruct_n with torch gpu output provided
y = torch.empty(10, d, device=torch.device('cuda', 0), dtype=torch.float32)
index.reconstruct_n(50, 10, y)
self.assertTrue(torch.equal(xb[50:60], y))

# tests assign
def test_assign(self):
d = 32
Expand Down
2 changes: 2 additions & 0 deletions faiss/gpu/utils/DeviceVector.cuh
Expand Up @@ -169,6 +169,8 @@ class DeviceVector {
T out;
CUDA_VERIFY(cudaMemcpyAsync(
&out, data() + idx, sizeof(T), cudaMemcpyDeviceToHost, stream));

return out;
}

// Clean up after oversized allocations, while leaving some space to
Expand Down

0 comments on commit cfc7fe5

Please sign in to comment.