Skip to content

Commit

Permalink
support range search from GPU (#2860)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2860

Optimized range search function where the GPU computes by default and falls back on the CPU for queries where there are too many results.

Parallelize the CPU to GPU cloning, it seems to work.

Support range_search_preassigned in Python

Fix long-standing issue with SWIG exposed functions that did not release the GIL (in particular the MapLong2Long).

Adds a MapInt64ToInt64 that is more efficient than MapLong2Long.

Reviewed By: algoriddle

Differential Revision: D45672301

fbshipit-source-id: 2e77397c40083818584dbafa5427149359a2abfd
  • Loading branch information
mdouze authored and facebook-github-bot committed May 16, 2023
1 parent 54d331e commit b9ea339
Show file tree
Hide file tree
Showing 19 changed files with 711 additions and 181 deletions.
12 changes: 6 additions & 6 deletions contrib/evaluation.py
Expand Up @@ -226,7 +226,7 @@ def compute_PR_for(q):
# Functions that compare search results with a reference result.
# They are intended for use in tests

def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
""" test that knn search results are identical, raise if not """
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
# here we have to be careful because of draws
Expand All @@ -243,14 +243,14 @@ def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))


def test_ref_range_results(lims_ref, Dref, Iref,
lims_new, Dnew, Inew):
def check_ref_range_results(Lref, Dref, Iref,
Lnew, Dnew, Inew):
""" compare range search results wrt. a reference result,
throw if it fails """
np.testing.assert_array_equal(lims_ref, lims_new)
nq = len(lims_ref) - 1
np.testing.assert_array_equal(Lref, Lnew)
nq = len(Lref) - 1
for i in range(nq):
l0, l1 = lims_ref[i], lims_ref[i + 1]
l0, l1 = Lref[i], Lref[i + 1]
Ii_ref = Iref[l0:l1]
Ii_new = Inew[l0:l1]
Di_ref = Dref[l0:l1]
Expand Down
102 changes: 68 additions & 34 deletions contrib/exhaustive_search.py
Expand Up @@ -11,7 +11,6 @@

LOG = logging.getLogger(__name__)


def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
"""Computes the exact KNN search results for a dataset that possibly
does not fit in RAM but for which we have an iterator that
Expand Down Expand Up @@ -51,47 +50,82 @@ def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):



def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
    """Emulate range search on a GPU index with k-NN search + CPU fallback.

    The GPU index is searched with k = min(index_gpu.ntotal, gpu_k)
    neighbors. Queries whose k-th neighbor is still inside the threshold may
    have more results than the GPU returned; when a CPU index is available
    they are searched again with index_cpu.range_search. Both partial results
    are merged by faiss.CombinerRangeKNN.

    Parameters
    ----------
    xq : 2D array
        Query vectors, one per row.
    r2 : float
        Range search threshold. Results are kept when their distance is
        below r2, or above r2 when the metric of index_gpu is a similarity.
    index_gpu : index
        Index used for the k-NN search.
    index_cpu : index, numpy array or None
        - a CPU index that supports range search
        - a numpy table, that will be used to construct a Flat index if
          needed
        - None. In that case, at most gpu_k results are returned per query
    gpu_k : int
        Maximum number of neighbors requested from index_gpu.

    Returns
    -------
    L_res : int64 array of size nq + 1
        Result boundaries, filled in by the combiner.
    D_res : float32 array of size L_res[-1]
        Distances of the results.
    I_res : int64 array of size L_res[-1]
        Labels of the results.
    """
    nq, d = xq.shape
    k = min(index_gpu.ntotal, gpu_k)
    # similarity metrics keep the largest values, distances the smallest
    keep_max = faiss.is_similarity_metric(index_gpu.metric_type)
    LOG.debug(f"GPU search {nq} queries with {k=:}")
    t0 = time.time()
    D, I = index_gpu.search(xq, k)
    t1 = time.time() - t0

    t2 = 0
    lim_remain = None
    if index_cpu is not None:
        # the k-th result is still within the threshold: the GPU result for
        # that query may be truncated
        if not keep_max:
            mask = D[:, k - 1] < r2
        else:
            mask = D[:, k - 1] > r2
        if mask.sum() > 0:
            LOG.debug("CPU search remain %d" % mask.sum())
            t0 = time.time()
            if isinstance(index_cpu, np.ndarray):
                # it is in fact an array of vectors that we have to wrap
                # in a flat index
                xb = index_cpu
                index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
                index_cpu.add(xb)
            lim_remain, D_remain, I_remain = index_cpu.range_search(
                xq[mask], r2)
            t2 = time.time() - t0

    LOG.debug("combine")
    t0 = time.time()

    combiner = faiss.CombinerRangeKNN(nq, k, float(r2), keep_max)
    sp = faiss.swig_ptr
    combiner.I = sp(I)
    combiner.D = sp(D)
    if lim_remain is not None:
        combiner.mask = sp(mask)
        combiner.D_remain = sp(D_remain)
        # the limits are handed to the combiner reinterpreted as int64
        combiner.lim_remain = sp(lim_remain.view("int64"))
        combiner.I_remain = sp(I_remain)

    # first pass computes the result boundaries, second pass fills the
    # output arrays that are sized from them
    L_res = np.empty(nq + 1, dtype='int64')
    combiner.compute_sizes(sp(L_res))
    nres = L_res[-1]
    D_res = np.empty(nres, dtype='float32')
    I_res = np.empty(nres, dtype='int64')
    combiner.write_result(sp(D_res), sp(I_res))

    t3 = time.time() - t0
    LOG.debug(f"times {t1:.3f}s {t2:.3f}s {t3:.3f}s")
    return L_res, D_res, I_res


def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,
Expand Down
2 changes: 1 addition & 1 deletion contrib/ivf_tools.py
Expand Up @@ -77,7 +77,7 @@ def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
res = faiss.RangeSearchResult(n)
sp = faiss.swig_ptr

index_ivf.range_search_preassigned(
index_ivf.range_search_preassigned_c(
n, sp(x), radius,
sp(list_nos), sp(coarse_dis),
res
Expand Down
1 change: 1 addition & 0 deletions faiss/gpu/GpuCloner.cpp
Expand Up @@ -309,6 +309,7 @@ Index* ToGpuClonerMultiple::clone_Index_to_shards(const Index* index) {

std::vector<faiss::Index*> shards(n);

#pragma omp parallel for
for (idx_t i = 0; i < n; i++) {
// make a shallow copy
if (reserveVecs) {
Expand Down
50 changes: 48 additions & 2 deletions faiss/gpu/test/test_contrib_gpu.py
Expand Up @@ -12,7 +12,7 @@

from faiss.contrib import datasets, evaluation, big_batch_search
from faiss.contrib.exhaustive_search import knn_ground_truth, \
range_ground_truth
range_ground_truth, range_search_gpu


class TestComputeGT(unittest.TestCase):
Expand Down Expand Up @@ -51,7 +51,7 @@ def do_test_range(self, metric):
xq, ds.database_iterator(bs=100), threshold,
metric_type=metric)

evaluation.test_ref_range_results(
evaluation.check_ref_range_results(
ref_lims, ref_D, ref_I,
new_lims, new_D, new_I
)
Expand Down Expand Up @@ -131,3 +131,49 @@ def knn_function(xq, xb, k, metric=faiss.METRIC_L2, thread_id=None):

def test_Flat(self):
self.do_test("IVF64,Flat")


class TestRangeSearchGpu(unittest.TestCase):
    """Compare range_search_gpu with a plain CPU range search."""

    def do_test(self, factory_string):
        """Build an index from factory_string and check both search modes."""
        gpu_k = 4
        ds = datasets.SyntheticDataset(32, 2000, 4000, 1000)
        k = 10
        index_gpu = faiss.index_cpu_to_all_gpus(
            faiss.index_factory(ds.d, factory_string)
        )
        index_gpu.train(ds.get_train())
        index_gpu.add(ds.get_database())

        # a k-NN search is only used to pick a reasonable threshold
        D, _ = index_gpu.search(ds.get_queries(), k)
        threshold = np.median(D[:, 5])

        # reference result, computed by the CPU copy of the index
        index_cpu = faiss.index_gpu_to_cpu(index_gpu)
        Lref, Dref, Iref = index_cpu.range_search(ds.get_queries(), threshold)
        nres_per_query = Lref[1:] - Lref[:-1]

        # some queries must exceed gpu_k results (CPU path) and some must
        # not (GPU path), otherwise only one code path is exercised
        exceeds = nres_per_query > gpu_k
        assert np.any(exceeds) and not np.all(exceeds)

        # mixed GPU / CPU search must reproduce the reference
        Lnew, Dnew, Inew = range_search_gpu(
            ds.get_queries(), threshold, index_gpu, index_cpu, gpu_k=gpu_k)
        evaluation.check_ref_range_results(
            Lref, Dref, Iref,
            Lnew, Dnew, Inew
        )

        # GPU-only search: results are capped, so for queries with many
        # results we can only check that what is returned is correct
        Lnew2, Dnew2, Inew2 = range_search_gpu(
            ds.get_queries(), threshold, index_gpu, None, gpu_k=gpu_k)
        for q in range(ds.nq):
            expected = Iref[Lref[q]:Lref[q + 1]]
            found = Inew2[Lnew2[q]:Lnew2[q + 1]]
            if nres_per_query[q] > gpu_k:
                expected = set(expected)
                for label in found:
                    self.assertIn(label, expected)
            else:
                self.assertEqual(set(expected), set(found))

    def test_ivf(self):
        self.do_test("IVF64,Flat")
2 changes: 1 addition & 1 deletion faiss/python/__init__.py
Expand Up @@ -22,7 +22,7 @@
from faiss.extra_wrappers import kmin, kmax, pairwise_distances, rand, randint, \
lrand, randn, rand_smooth_vectors, eval_intersection, normalize_L2, \
ResultHeap, knn, Kmeans, checksum, matrix_bucket_sort_inplace, bucket_sort, \
merge_knn_results
merge_knn_results, MapInt64ToInt64


__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR,
Expand Down
97 changes: 96 additions & 1 deletion faiss/python/class_wrappers.py
Expand Up @@ -544,6 +544,7 @@ def replacement_range_search(self, x, thresh, *, params=None):
n, d = x.shape
assert d == self.d
x = np.ascontiguousarray(x, dtype='float32')
thresh = float(thresh)

res = RangeSearchResult(n)
self.range_search_c(n, swig_ptr(x), thresh, res, params)
Expand Down Expand Up @@ -618,6 +619,64 @@ def replacement_search_preassigned(self, x, k, Iq, Dq, *, params=None, D=None, I
)
return D, I

def replacement_range_search_preassigned(self, x, thresh, Iq, Dq, *, params=None):
    """Range search in the inverted lists that were pre-selected per query.

    Parameters
    ----------
    x : array_like
        Query vectors, shape (n, d) where d is appropriate for the index.
        Converted to a contiguous float32 array.
    thresh : float
        Threshold to select neighbors. All elements within this radius are
        returned, except for maximum inner product indexes, where the
        elements above the threshold are returned.
    Iq : array_like
        Nearest centroids, size (n, nprobe). Converted to int64.
    Dq : array_like or None
        Distance array to the centroids, size (n, nprobe). Converted to
        float32 when it is not None.
    params : SearchParameters
        Not supported here: must be left to None.

    Returns
    -------
    lims : array_like
        Starting index of the results for each query vector, size n+1.
    D : array_like
        Distances of the nearest neighbors, shape `lims[n]`. The distances
        for query i are in `D[lims[i]:lims[i+1]]`.
    I : array_like
        Labels of nearest neighbors, shape `lims[n]`. The labels for query i
        are in `I[lims[i]:lims[i+1]]`.
    """
    n, d = x.shape
    assert d == self.d
    queries = np.ascontiguousarray(x, dtype='float32')

    list_ids = np.ascontiguousarray(Iq, dtype='int64')
    assert params is None, "params not supported"
    assert list_ids.shape == (n, self.nprobe)

    coarse_dis = Dq
    if coarse_dis is not None:
        coarse_dis = np.ascontiguousarray(coarse_dis, dtype='float32')
        assert coarse_dis.shape == list_ids.shape

    radius = float(thresh)
    result = RangeSearchResult(n)
    self.range_search_preassigned_c(
        n, swig_ptr(queries), radius,
        swig_ptr(list_ids), swig_ptr(coarse_dis),
        result
    )

    # the buffers belong to the result object: copy them before returning
    limits = rev_swig_ptr(result.lims, n + 1).copy()
    total = int(limits[-1])
    distances = rev_swig_ptr(result.distances, total).copy()
    labels = rev_swig_ptr(result.labels, total).copy()
    return limits, distances, labels

def replacement_sa_encode(self, x, codes=None):
n, d = x.shape
assert d == self.d
Expand Down Expand Up @@ -675,8 +734,12 @@ def replacement_permute_entries(self, perm):
ignore_missing=True)
replace_method(the_class, 'search_and_reconstruct',
replacement_search_and_reconstruct, ignore_missing=True)

# these ones are IVF-specific
replace_method(the_class, 'search_preassigned',
replacement_search_preassigned, ignore_missing=True)
replace_method(the_class, 'range_search_preassigned',
replacement_range_search_preassigned, ignore_missing=True)
replace_method(the_class, 'sa_encode', replacement_sa_encode)
replace_method(the_class, 'sa_decode', replacement_sa_decode)
replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes,
Expand Down Expand Up @@ -776,6 +839,36 @@ def replacement_range_search(self, x, thresh):
I = rev_swig_ptr(res.labels, nd).copy()
return lims, D, I

def replacement_range_search_preassigned(self, x, thresh, Iq, Dq, *, params=None):
    """Range search of binary vectors in pre-selected inverted lists.

    x holds the queries as rows of uint8 codes: the row size times 8 must
    equal the index dimension. Iq has shape (n, nprobe) and is converted to
    int64; Dq, when not None, is converted to int32 and must have the same
    shape. thresh is converted to int. params must be None.

    Returns (lims, D, I): lims has size n + 1, and the distances and labels
    of query i are in D[lims[i]:lims[i+1]] and I[lims[i]:lims[i+1]].
    """
    n, d = x.shape
    codes = _check_dtype_uint8(x)
    assert d * 8 == self.d

    list_ids = np.ascontiguousarray(Iq, dtype='int64')
    assert params is None, "params not supported"
    assert list_ids.shape == (n, self.nprobe)

    coarse_dis = Dq
    if coarse_dis is not None:
        coarse_dis = np.ascontiguousarray(coarse_dis, dtype='int32')
        assert coarse_dis.shape == list_ids.shape

    radius = int(thresh)
    result = RangeSearchResult(n)
    self.range_search_preassigned_c(
        n, swig_ptr(codes), radius,
        swig_ptr(list_ids), swig_ptr(coarse_dis),
        result
    )

    # the buffers belong to the result object: copy them before returning
    limits = rev_swig_ptr(result.lims, n + 1).copy()
    total = int(limits[-1])
    distances = rev_swig_ptr(result.distances, total).copy()
    labels = rev_swig_ptr(result.labels, total).copy()
    return limits, distances, labels




def replacement_remove_ids(self, x):
if isinstance(x, IDSelector):
sel = x
Expand All @@ -794,6 +887,8 @@ def replacement_remove_ids(self, x):
replace_method(the_class, 'remove_ids', replacement_remove_ids)
replace_method(the_class, 'search_preassigned',
replacement_search_preassigned, ignore_missing=True)
replace_method(the_class, 'range_search_preassigned',
replacement_range_search_preassigned, ignore_missing=True)


def handle_VectorTransform(the_class):
Expand Down Expand Up @@ -937,7 +1032,7 @@ def handle_MapLong2Long(the_class):

def replacement_map_add(self, keys, vals):
    """Add key / value pairs to the map.

    keys and vals must be 1D arrays with the same number of elements; both
    are passed to add_c through swig_ptr.
    """
    # unpacking fails unless keys is 1D
    n, = keys.shape
    # every key needs a value: vals must have exactly the shape of keys
    assert (n,) == vals.shape
    self.add_c(n, swig_ptr(keys), swig_ptr(vals))

def replacement_map_search_multiple(self, keys):
Expand Down

0 comments on commit b9ea339

Please sign in to comment.