Skip to content

Commit

Permalink
Fix IndexBinary.assign Python method
Browse files Browse the repository at this point in the history
Summary: Fixes #3343

Reviewed By: kuarora, junjieqi

Differential Revision: D56526842

fbshipit-source-id: b7c4377495db4e68283cf4ce2b7c8fae008cd404
  • Loading branch information
Amir Sadoughi authored and facebook-github-bot committed Apr 24, 2024
1 parent 2379b45 commit 03750f5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
34 changes: 34 additions & 0 deletions faiss/python/class_wrappers.py
Expand Up @@ -956,10 +956,44 @@ def replacement_remove_ids(self, x):
sel = IDSelectorBatch(x.size, swig_ptr(x))
return self.remove_ids_c(sel)

def replacement_assign(self, x, k, labels=None):
"""Find the k nearest neighbors of the set of vectors x in the index.
This is the same as the `search` method, but discards the distances.
Parameters
----------
x : array_like
Query vectors, shape (n, d) where d is appropriate for the index.
`dtype` must be uint8.
k : int
Number of nearest neighbors.
labels : array_like, optional
Labels array to store the results.
Returns
-------
labels: array_like
Labels of the nearest neighbors, shape (n, k).
When not enough results are found, the label is set to -1
"""
n, d = x.shape
x = _check_dtype_uint8(x)
assert d == self.code_size
assert k > 0

if labels is None:
labels = np.empty((n, k), dtype=np.int64)
else:
assert labels.shape == (n, k)

self.assign_c(n, swig_ptr(x), swig_ptr(labels), k)
return labels

replace_method(the_class, 'add', replacement_add)
replace_method(the_class, 'add_with_ids', replacement_add_with_ids)
replace_method(the_class, 'train', replacement_train)
replace_method(the_class, 'search', replacement_search)
replace_method(the_class, 'assign', replacement_assign)
replace_method(the_class, 'range_search', replacement_range_search)
replace_method(the_class, 'reconstruct', replacement_reconstruct)
replace_method(the_class, 'reconstruct_n', replacement_reconstruct_n)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_index_binary.py
Expand Up @@ -100,6 +100,9 @@ def test_flat(self):
index.add(self.xb)
D, I = index.search(self.xq, 3)

I2 = index.assign(x=self.xq, k=3, labels=None)
assert np.all(I == I2)

for i in range(nq):
for j, dj in zip(I[i], D[i]):
ref_dis = binary_dis(self.xq[i], self.xb[j])
Expand Down

0 comments on commit 03750f5

Please sign in to comment.