tiling bfKnn (#2865)
Summary:
Pull Request resolved: #2865

Introduces a tiling version of `bfKnn` called `bfKnn_tiling`, which can break up both the queries and the database vectors into tiles bounded by queriesMemoryLimit and vectorsMemoryLimit bytes, respectively.
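For illustration, a minimal sketch of how the tiling path can be driven from Python through the `knn_gpu` wrapper updated in this diff. The dataset sizes and memory limits below are made up; passing 0 for either limit falls back to the untiled path. Assumes faiss was built with GPU support and at least one GPU is visible.

    import numpy as np
    import faiss

    d, nb, nq, k = 64, 100_000, 10_000, 10
    xb = np.random.rand(nb, d).astype('float32')  # database vectors (CPU)
    xq = np.random.rand(nq, d).astype('float32')  # query vectors (CPU)

    res = faiss.StandardGpuResources()

    # Cap GPU memory at ~16 MiB per database tile and ~8 MiB per query tile
    # (the query limit also covers the distance/label outputs).
    D, I = faiss.knn_gpu(
        res, xq, xb, k,
        vectorsMemoryLimit=16 * 1024 * 1024,
        queriesMemoryLimit=8 * 1024 * 1024,
    )
    # D: (nq, k) distances, I: (nq, k) labels, same as the untiled call.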

Reviewed By: wickedfoo

Differential Revision: D45944524

fbshipit-source-id: f9cd4c14dbf2d43def773124f19e92d25c86fc5a
algoriddle authored and facebook-github-bot committed May 23, 2023
1 parent 5c221ed commit 1c1879b
Showing 5 changed files with 211 additions and 4 deletions.
150 changes: 150 additions & 0 deletions faiss/gpu/GpuDistance.cu
@@ -24,6 +24,7 @@
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/Heap.h>
#include <faiss/gpu/impl/Distance.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/CopyUtils.cuh>
@@ -368,6 +369,155 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
}
}

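// Tiles the database side of the computation: runs bfKnn on consecutive
// shards of the vectors and merges each shard's top-k results into CPU
// result heaps. The comparator C selects min- or max-heap ordering
// depending on whether the metric is a similarity or a distance.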
template <class C>
void bfKnn_shard_database(
GpuResourcesProvider* prov,
const GpuDistanceParams& args,
size_t shard_size,
size_t distance_size) {
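// Result heaps spanning the whole database. Even when the caller does
// not want distances back, a scratch distance buffer is needed to merge
// the per-shard results.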
std::vector<typename C::T> heaps_distances;
if (args.ignoreOutDistances) {
heaps_distances.resize(args.numQueries * args.k, 0);
}
HeapArray<C> heaps = {
(size_t)args.numQueries,
(size_t)args.k,
(typename C::TI*)args.outIndices,
args.ignoreOutDistances ? heaps_distances.data()
: args.outDistances};
heaps.heapify();
std::vector<typename C::TI> labels(args.numQueries * args.k);
std::vector<typename C::T> distances(args.numQueries * args.k);
GpuDistanceParams args_batch = args;
args_batch.outDistances = distances.data();
args_batch.ignoreOutDistances = false;
args_batch.outIndices = labels.data();
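// Process the database in shards of at most shard_size vectors; each
// sub-call writes into the scratch buffers set up above rather than the
// caller's output.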
for (idx_t i = 0; i < args.numVectors; i += shard_size) {
args_batch.numVectors = min(shard_size, args.numVectors - i);
args_batch.vectors =
(char*)args.vectors + distance_size * args.dims * i;
args_batch.vectorNorms =
args.vectorNorms ? args.vectorNorms + i : nullptr;
bfKnn(prov, args_batch);
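// Shift the shard-local labels to global database indices, then merge
// this shard's top-k into the running result heaps.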
for (auto& label : labels) {
label += i;
}
heaps.addn_with_ids(args.k, distances.data(), labels.data(), args.k);
}
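// Sort each query's heap so results come back in ranked order.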
heaps.reorder();
}

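// Handles a single query shard: validates the arguments and, if the
// database does not fit within vectorsMemoryLimit, shards it and merges
// the partial results; otherwise falls back to a single bfKnn call.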
void bfKnn_single_query_shard(
GpuResourcesProvider* prov,
const GpuDistanceParams& args,
size_t vectorsMemoryLimit) {
if (vectorsMemoryLimit == 0) {
bfKnn(prov, args);
return;
}
FAISS_THROW_IF_NOT_MSG(
args.numVectors > 0, "bfKnn_tiling: numVectors must be > 0");
FAISS_THROW_IF_NOT_MSG(
args.vectors,
"bfKnn_tiling: vectors must be provided (passed null)");
FAISS_THROW_IF_NOT_MSG(
getDeviceForAddress(args.vectors) == -1,
"bfKnn_tiling: vectors should be in CPU memory when vectorsMemoryLimit > 0");
FAISS_THROW_IF_NOT_MSG(
args.vectorsRowMajor,
"bfKnn_tiling: tiling vectors is only supported in row major mode");
FAISS_THROW_IF_NOT_MSG(
args.k > 0,
"bfKnn_tiling: tiling vectors is only supported for k > 0");
size_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
: args.vectorType == DistanceDataType::F16 ? 2
: 0;
FAISS_THROW_IF_NOT_MSG(
distance_size > 0, "bfKnn_tiling: unknown vectorType");
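// Number of database vectors that fit within vectorsMemoryLimit bytes.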
size_t shard_size = vectorsMemoryLimit / (args.dims * distance_size);
FAISS_THROW_IF_NOT_MSG(
shard_size > 0, "bfKnn_tiling: vectorsMemoryLimit is too low");
if (args.numVectors <= shard_size) {
bfKnn(prov, args);
return;
}
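// Similarity metrics (e.g. inner product) keep the largest values, so the
// merge heaps use CMin ordering; distance metrics keep the smallest, so
// they use CMax.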
if (is_similarity_metric(args.metric)) {
if (args.outIndicesType == IndicesDataType::I64) {
bfKnn_shard_database<CMin<float, int64_t>>(
prov, args, shard_size, distance_size);
} else if (args.outIndicesType == IndicesDataType::I32) {
bfKnn_shard_database<CMin<float, int32_t>>(
prov, args, shard_size, distance_size);
} else {
FAISS_THROW_MSG("bfKnn_tiling: unknown outIndicesType");
}
} else {
if (args.outIndicesType == IndicesDataType::I64) {
bfKnn_shard_database<CMax<float, int64_t>>(
prov, args, shard_size, distance_size);
} else if (args.outIndicesType == IndicesDataType::I32) {
bfKnn_shard_database<CMax<float, int32_t>>(
prov, args, shard_size, distance_size);
} else {
FAISS_THROW_MSG("bfKnn_tiling: unknown outIndicesType");
}
}
}

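// Public entry point: shards the queries under queriesMemoryLimit and
// hands each query shard to bfKnn_single_query_shard, which may further
// shard the database under vectorsMemoryLimit.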
void bfKnn_tiling(
GpuResourcesProvider* prov,
const GpuDistanceParams& args,
size_t vectorsMemoryLimit,
size_t queriesMemoryLimit) {
if (queriesMemoryLimit == 0) {
bfKnn_single_query_shard(prov, args, vectorsMemoryLimit);
return;
}
FAISS_THROW_IF_NOT_MSG(
args.numQueries > 0, "bfKnn_tiling: numQueries must be > 0");
FAISS_THROW_IF_NOT_MSG(
args.queries,
"bfKnn_tiling: queries must be provided (passed null)");
FAISS_THROW_IF_NOT_MSG(
getDeviceForAddress(args.queries) == -1,
"bfKnn_tiling: queries should be in CPU memory when queriesMemoryLimit > 0");
FAISS_THROW_IF_NOT_MSG(
args.queriesRowMajor,
"bfKnn_tiling: tiling queries is only supported in row major mode");
FAISS_THROW_IF_NOT_MSG(
args.k > 0,
"bfKnn_tiling: tiling queries is only supported for k > 0");
size_t distance_size = args.queryType == DistanceDataType::F32 ? 4
: args.queryType == DistanceDataType::F16 ? 2
: 0;
FAISS_THROW_IF_NOT_MSG(
distance_size > 0, "bfKnn_tiling: unknown queryType");
size_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8
: args.outIndicesType == IndicesDataType::I32 ? 4
: 0;
FAISS_THROW_IF_NOT_MSG(
label_size > 0, "bfKnn_tiling: unknown outIndicesType");
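// Each query costs dims * distance_size bytes of input plus
// k * (distance_size + label_size) bytes of output (distances and labels).
// For example, with dims = 128, k = 10, float32 queries and int64 labels,
// that is 128 * 4 + 10 * (4 + 8) = 632 bytes per query, so a 64 MiB
// queriesMemoryLimit yields shards of roughly 106,000 queries.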
size_t shard_size = queriesMemoryLimit /
(args.k * (distance_size + label_size) + args.dims * distance_size);
FAISS_THROW_IF_NOT_MSG(
shard_size > 0, "bfKnn_tiling: queriesMemoryLimit is too low");
FAISS_THROW_IF_NOT_MSG(
args.outIndices,
"bfKnn: outIndices must be provided (passed null)");
for (idx_t i = 0; i < args.numQueries; i += shard_size) {
GpuDistanceParams args_batch = args;
args_batch.numQueries = min(shard_size, args.numQueries - i);
args_batch.queries =
(char*)args.queries + distance_size * args.dims * i;
if (!args_batch.ignoreOutDistances) {
args_batch.outDistances = args.outDistances + args.k * i;
}
args_batch.outIndices =
(char*)args.outIndices + args.k * label_size * i;
bfKnn_single_query_shard(prov, args_batch, vectorsMemoryLimit);
}
}

// legacy version
void bruteForceKnn(
GpuResourcesProvider* res,
18 changes: 18 additions & 0 deletions faiss/gpu/GpuDistance.h
@@ -123,6 +123,24 @@ struct GpuDistanceParams {
/// nearest neighbors with respect to the given metric
void bfKnn(GpuResourcesProvider* resources, const GpuDistanceParams& args);

/// A version of bfKnn that takes two extra parameters to cap the GPU
/// memory used for the database vectors and for the queries, the latter
/// including the memory required for the results.
/// If a limit is 0, the corresponding input must fit into GPU memory.
/// If a limit is greater than 0, the function uses at most this much GPU
/// memory (in bytes) for that input: vectors are broken up into chunks
/// of at most vectorsMemoryLimit bytes, and queries into chunks of at
/// most queriesMemoryLimit bytes.
/// The tiles resulting from the product of the query and vector chunks
/// are processed sequentially on the GPU.
/// Only supported for row-major matrices and k > 0. Any input that needs
/// sharding must reside in CPU memory.
void bfKnn_tiling(
GpuResourcesProvider* resources,
const GpuDistanceParams& args,
size_t vectorsMemoryLimit,
size_t queriesMemoryLimit);

/// Deprecated legacy implementation
void bruteForceKnn(
GpuResourcesProvider* resources,
30 changes: 28 additions & 2 deletions faiss/gpu/test/test_gpu_basics.py
@@ -225,6 +225,14 @@ def make_t(num, d, clamp=False, seed=None):

class TestKnn(unittest.TestCase):
def test_input_types(self):
self.do_test_input_types(0, 0)

def test_input_types_tiling(self):
self.do_test_input_types(0, 500)
self.do_test_input_types(1000, 0)
self.do_test_input_types(1000, 500)

def do_test_input_types(self, vectorsMemoryLimit, queriesMemoryLimit):
d = 33
k = 5
nb = 1000
@@ -243,6 +251,8 @@ def test_input_types(self):
out_d = np.empty((nq, k), dtype=np.float32)
out_i = np.empty((nq, k), dtype=np.int64)

gpu_id = random.randrange(0, faiss.get_num_gpus())

# Try f32 data/queries, i64 out indices
params = faiss.GpuDistanceParams()
params.k = k
@@ -253,9 +263,24 @@ def test_input_types(self):
params.numQueries = nq
params.outDistances = faiss.swig_ptr(out_d)
params.outIndices = faiss.swig_ptr(out_i)
params.device = gpu_id

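# Exercise the tiling entry point whenever a memory limit is set;
# otherwise fall back to the plain bfKnn path.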
if vectorsMemoryLimit > 0 or queriesMemoryLimit > 0:
faiss.bfKnn_tiling(
res,
params,
vectorsMemoryLimit,
queriesMemoryLimit)
else:
faiss.bfKnn(res, params)

self.assertTrue(np.allclose(ref_d, out_d, atol=1e-5))
self.assertGreaterEqual((out_i == ref_i).sum(), ref_i.size)

out_d, out_i = faiss.knn_gpu(
res, qs, xs, k, device=gpu_id,
vectorsMemoryLimit=vectorsMemoryLimit,
queriesMemoryLimit=queriesMemoryLimit)

self.assertTrue(np.allclose(ref_d, out_d, atol=1e-5))
self.assertGreaterEqual((out_i == ref_i).sum(), ref_i.size)
@@ -266,6 +291,7 @@ def test_input_types(self):
params.outIndicesType = faiss.IndicesDataType_I32

faiss.bfKnn(res, params)

self.assertEqual((out_i32 == ref_i).sum(), ref_i.size)

# Try float16 data/queries, i64 out indices
15 changes: 13 additions & 2 deletions faiss/python/gpu_wrappers.py
@@ -54,7 +54,7 @@ def index_cpu_to_gpus_list(index, co=None, gpus=None, ngpu=-1):
# allows numpy ndarray usage with bfKnn


def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2, device=-1, use_raft=False, vectorsMemoryLimit=0, queriesMemoryLimit=0):
"""
Compute the k nearest neighbors of a vector on one GPU without constructing an index
@@ -82,6 +82,14 @@ def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2, device=-1, use_raf
(can also be set via torch.cuda.set_device in PyTorch)
Otherwise, an integer 0 <= device < numDevices indicates the GPU on which
the computation should be run
vectorsMemoryLimit: int, optional
queriesMemoryLimit: int, optional
    Memory limits, in bytes, for the database vectors and the queries.
    If nonzero, the GPU uses at most this much memory for the
    corresponding input: vectors are processed in chunks of at most
    vectorsMemoryLimit bytes and queries in chunks of at most
    queriesMemoryLimit bytes, with the query limit also covering the
    memory required for the results.
Returns
-------
@@ -172,7 +180,10 @@ def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2, device=-1, use_raf

# no stream synchronization needed, inputs and outputs are guaranteed to
# be on the CPU (numpy arrays)
if vectorsMemoryLimit > 0 or queriesMemoryLimit > 0:
bfKnn_tiling(res, args, vectorsMemoryLimit, queriesMemoryLimit)
else:
bfKnn(res, args)

return D, I

2 changes: 2 additions & 0 deletions faiss/utils/Heap.cpp
@@ -136,6 +136,8 @@ void HeapArray<C>::per_line_extrema(T* out_val, TI* out_ids) const {

template struct HeapArray<CMin<float, int64_t>>;
template struct HeapArray<CMax<float, int64_t>>;
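// int32_t label instantiations are needed by the GPU bfKnn_tiling path
// when outIndicesType is I32.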
template struct HeapArray<CMin<float, int32_t>>;
template struct HeapArray<CMax<float, int32_t>>;
template struct HeapArray<CMin<int, int64_t>>;
template struct HeapArray<CMax<int, int64_t>>;

