From 753a10960fd3361bc82feffa6335d8a2afc97abd Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 5 Oct 2023 14:46:56 -0700 Subject: [PATCH 01/46] start integration of cagra --- faiss/gpu/CMakeLists.txt | 4 +- faiss/gpu/impl/RaftCagra.cu | 134 +++++++++++++++++++++++++++++++++++ faiss/gpu/impl/RaftCagra.cuh | 98 +++++++++++++++++++++++++ 3 files changed, 235 insertions(+), 1 deletion(-) create mode 100644 faiss/gpu/impl/RaftCagra.cu create mode 100644 faiss/gpu/impl/RaftCagra.cuh diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt index ad7d2103fa..b18209ca64 100644 --- a/faiss/gpu/CMakeLists.txt +++ b/faiss/gpu/CMakeLists.txt @@ -238,11 +238,13 @@ generate_ivf_interleaved_code() if(FAISS_ENABLE_RAFT) list(APPEND FAISS_GPU_HEADERS + impl/RaftCagra.cuh impl/RaftIVFFlat.cuh impl/RaftFlatIndex.cuh) list(APPEND FAISS_GPU_SRC + impl/RaftCagra.cu impl/RaftFlatIndex.cu - impl/RaftIVFFlat.cu) + impl/RaftIVFFlat.cu) target_compile_definitions(faiss PUBLIC USE_NVIDIA_RAFT=1) target_compile_definitions(faiss_avx2 PUBLIC USE_NVIDIA_RAFT=1) diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu new file mode 100644 index 0000000000..0520209b17 --- /dev/null +++ b/faiss/gpu/impl/RaftCagra.cu @@ -0,0 +1,134 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include +#include + +namespace faiss { +namespace gpu { + +RaftCagra::RaftCagra( + GpuResources* resources, + int dim, + idx_t intermediate_graph_degree, + idx_t graph_degree, + faiss::cagra_build_algo graph_build_algo, + faiss::MetricType metric, + float metricArg) + : resources_(resources), + dim_(dim), + metric_(metric), + metricArg_(metricArg), + index_pams_() { + FAISS_THROW_IF_NOT_MSG( + metric == faiss::METRIC_L2, + "CAGRA currently only supports L2 metric."); + + index_pams_.intermediate_graph_degree = intermediate_graph_degree; + index_pams_.graph_degree = graph_degree; + index_pams_.build_algo = + static_cast( + graph_build_algo); +} + +void RaftCagra::train(idx_t n, const float* x) { + const raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + if (getDeviceForAddress(x) >= 0) { + raft_knn_index = raft::neighbors::cagra::build( + raft_handle, + index_pams_, + raft::make_device_matrix_view(x, n, dim_)); + } else { + raft_knn_index = raft::neighbors::cagra::build( + raft_handle, + index_pams_, + raft::make_host_matrix_view(x, n, dim_)); + } +} + +void RaftCagra::search( + Tensor& queries, + int k, + Tensor& outDistances, + Tensor& outIndices, + idx_t max_queries, + idx_t itopk_size, + idx_t max_iterations, + faiss::cagra_search_algo graph_search_algo, + idx_t team_size, + idx_t search_width, + idx_t min_iterations, + idx_t thread_block_size, + faiss::cagra_hash_mode hash_mode, + idx_t hashmap_min_bitlen, + float hashmap_max_fill_rate, + idx_t num_random_samplings, + idx_t rand_xor_mask) { + const raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + idx_t numQueries = queries.getSize(0); + idx_t cols = queries.getSize(1); + idx_t k_ = k; + + FAISS_ASSERT(raft_knn_index.has_value()); + FAISS_ASSERT(numQueries > 0); + FAISS_ASSERT(cols == dim_); + + auto queries_view = raft::make_device_matrix_view( + queries.data(), numQueries, cols); + auto distances_view = raft::make_device_matrix_view( + outDistances.data(), numQueries, k_); + auto indices_view = raft::make_device_matrix_view( + outIndices.data(), numQueries, k_); + + raft::neighbors::cagra::search_params search_pams; + search_pams.max_queries = max_queries; + search_pams.itopk_size = itopk_size; + search_pams.max_iterations = max_iterations; + search_pams.algo = + static_cast(graph_search_algo); + search_pams.team_size = team_size; + search_pams.search_width = search_width; + search_pams.min_iterations = min_iterations; + search_pams.thread_block_size = thread_block_size; + search_pams.hashmap_mode = + static_cast(hash_mode); + search_pams.hashmap_min_bitlen = hashmap_min_bitlen; + search_pams.hashmap_max_fill_rate = hashmap_max_fill_rate; + search_pams.num_random_samplings = num_random_samplings; + search_pams.rand_xor_mask = rand_xor_mask; + + raft::neighbors::cagra::search( + raft_handle, + search_pams, + raft_knn_index.value(), + queries_view, + indices_view, + distances_view); +} + +} // namespace gpu +} // namespace faiss \ No newline at end of file diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh new file mode 100644 index 0000000000..ccdffa28f0 --- /dev/null +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -0,0 +1,98 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +#include + +namespace faiss { + +enum class cagra_build_algo { IVF_PQ, NN_DESCENT }; + +enum class cagra_search_algo { SINGLE_CTA, MULTI_CTA }; + +enum class cagra_hash_mode { HASH, SMALL, AUTO }; + +namespace gpu { + +class RaftCagra { + public: + RaftCagra( + GpuResources* resources, + int dim, + idx_t intermediate_graph_degree, + idx_t graph_degree, + faiss::cagra_build_algo graph_build_algo, + faiss::MetricType metric, + float metricArg); + + ~RaftCagra() = default; + + void train(idx_t n, const float* x); + + void search( + Tensor& queries, + int k, + Tensor& outDistances, + Tensor& outIndices, + idx_t max_queries, + idx_t itopk_size, + idx_t max_iterations, + faiss::cagra_search_algo graph_search_algo, + idx_t team_size, + idx_t search_width, + idx_t min_iterations, + idx_t thread_block_size, + faiss::cagra_hash_mode hash_mode, + idx_t hashmap_min_bitlen, + float hashmap_max_fill_rate, + idx_t num_random_samplings, + idx_t rand_xor_mask); + + private: + /// Collection of GPU resources that we use + GpuResources* resources_; + + /// Expected dimensionality of the vectors + const int dim_; + + /// Metric type of the index + faiss::MetricType metric_; + + /// Metric arg + float metricArg_; + + /// Parameters to build RAFT CAGRA index + raft::neighbors::cagra::index_params index_pams_; + + /// Instance of trained RAFT CAGRA index + std::optional> raft_knn_index{ + std::nullopt}; +}; + +} // namespace gpu +} // namespace faiss \ No newline at end of file From f21c1f1c89ffc947249aae58ffddf52e78b0ea6f Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 30 Jan 2024 15:11:43 -0800 Subject: [PATCH 02/46] add public API layer --- faiss/gpu/CMakeLists.txt | 2 + faiss/gpu/GpuIndexCagra.cu | 84 ++++++++++++++++++++++++ faiss/gpu/GpuIndexCagra.h | 124 +++++++++++++++++++++++++++++++++++ faiss/gpu/impl/RaftCagra.cu | 4 ++ faiss/gpu/impl/RaftCagra.cuh | 4 +- 5 files changed, 217 insertions(+), 1 deletion(-) create mode 100644 faiss/gpu/GpuIndexCagra.cu create mode 100644 faiss/gpu/GpuIndexCagra.h diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt index 2cdc7e8a19..b76a3a0fb3 100644 --- a/faiss/gpu/CMakeLists.txt +++ b/faiss/gpu/CMakeLists.txt @@ -29,6 +29,7 @@ set(FAISS_GPU_SRC GpuIndexIVFFlat.cu GpuIndexIVFPQ.cu GpuIndexIVFScalarQuantizer.cu + GpuIndexCagra.cu GpuResources.cpp StandardGpuResources.cpp impl/BinaryDistance.cu @@ -91,6 +92,7 @@ set(FAISS_GPU_HEADERS GpuFaissAssert.h GpuIndex.h GpuIndexBinaryFlat.h + GpuIndexCagra.h GpuIndexFlat.h GpuIndexIVF.h GpuIndexIVFFlat.h diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu new file mode 100644 index 0000000000..814ab3e9d7 --- /dev/null +++ b/faiss/gpu/GpuIndexCagra.cu @@ -0,0 +1,84 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include "GpuIndexCagra.h" + +namespace faiss { +namespace gpu { + +GpuIndexCagra::GpuIndexCagra( + GpuResourcesProvider* provider, + int dims, + faiss::MetricType metric, + GpuIndexCagraConfig config) + : GpuIndex(provider->getResources(), dims, metric, 0.0f, config), + cagraConfig_(config) {} + +void GpuIndexCagra::train(idx_t n, const float* x) { + if (this->is_trained) { + FAISS_ASSERT(index_); + return; + } + + FAISS_ASSERT(!index_); + + index_ = std::make_shared( + this->resources_.get(), + this->d, + cagraConfig_.intermediate_graph_degree, + cagraConfig_.graph_degree, + static_cast(cagraConfig_.build_algo), + cagraConfig_.nn_descent_niter, + this->metric_type, + this->metric_arg); + + index_->train(n, x); + + this->is_trained = true; + this->ntotal += n; +} + +void GpuIndexCagra::searchImpl_( + idx_t n, + const float* x, + int k, + float* distances, + idx_t* labels, + const SearchParameters* search_params) const { + FAISS_ASSERT(this->is_trained && index_); + FAISS_ASSERT(n > 0); + + Tensor queries(const_cast(x), {n, this->d}); + Tensor outDistances(distances, {n, k}); + Tensor outLabels(const_cast(labels), {n, k}); + + auto params = dynamic_cast(search_params); + + index_->search( + queries, + k, + outDistances, + outLabels, + params->max_queries, + params->itopk_size, + params->max_iterations, + static_cast(params->algo), + params->team_size, + params->search_width, + params->min_iterations, + params->thread_block_size, + static_cast(params->hashmap_mode), + params->hashmap_min_bitlen, + params->hashmap_max_fill_rate, + params->num_random_samplings, + params->rand_xor_mask); +} + +} // namespace gpu +} // namespace faiss \ No newline at end of file diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h new file mode 100644 index 0000000000..0ce5090674 --- /dev/null +++ b/faiss/gpu/GpuIndexCagra.h @@ -0,0 +1,124 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace faiss { +namespace gpu { + +class RaftCagra; + +enum class graph_build_algo { + /* Use IVF-PQ to build all-neighbors knn graph */ + IVF_PQ, + /* Experimental, use NN-Descent to build all-neighbors knn graph */ + NN_DESCENT +}; + +struct GpuIndexCagraConfig : public GpuIndexConfig { + /** Degree of input graph for pruning. */ + size_t intermediate_graph_degree = 128; + /** Degree of output graph. */ + size_t graph_degree = 64; + /** ANN algorithm to build knn graph. */ + graph_build_algo build_algo = graph_build_algo::IVF_PQ; + /** Number of Iterations to run if building with NN_DESCENT */ + size_t nn_descent_niter = 20; +}; + +enum class search_algo { + /** For large batch sizes. */ + SINGLE_CTA, + /** For small batch sizes. */ + MULTI_CTA, + MULTI_KERNEL, + AUTO +}; + +enum class hash_mode { HASH, SMALL, AUTO }; + +struct SearchParametersCagra : SearchParameters { + /** Maximum number of queries to search at the same time (batch size). Auto + * select when 0.*/ + size_t max_queries = 0; + + /** Number of intermediate search results retained during the search. + * + * This is the main knob to adjust trade off between accuracy and search + * speed. Higher values improve the search accuracy. + */ + size_t itopk_size = 64; + + /** Upper limit of search iterations. Auto select when 0.*/ + size_t max_iterations = 0; + + // In the following we list additional search parameters for fine tuning. 
+ // Reasonable default values are automatically chosen. + + /** Which search implementation to use. */ + search_algo algo = search_algo::AUTO; + + /** Number of threads used to calculate a single distance. 4, 8, 16, or 32. + */ + size_t team_size = 0; + + /** Number of graph nodes to select as the starting point for the search in + * each iteration. aka search width?*/ + size_t search_width = 1; + /** Lower limit of search iterations. */ + size_t min_iterations = 0; + + /** Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. */ + size_t thread_block_size = 0; + /** Hashmap type. Auto selection when AUTO. */ + hash_mode hashmap_mode = hash_mode::AUTO; + /** Lower limit of hashmap bit length. More than 8. */ + size_t hashmap_min_bitlen = 0; + /** Upper limit of hashmap fill rate. More than 0.1, less than 0.9.*/ + float hashmap_max_fill_rate = 0.5; + + /** Number of iterations of initial random seed node selection. 1 or more. + */ + uint32_t num_random_samplings = 1; + /** Bit mask used for initial random seed node selection. */ + uint64_t rand_xor_mask = 0x128394; +}; + +struct GpuIndexCagra : public GpuIndex { + public: + GpuIndexCagra( + GpuResourcesProvider* provider, + int dims, + faiss::MetricType metric = faiss::METRIC_L2, + GpuIndexCagraConfig config = GpuIndexCagraConfig()); + + ~GpuIndexCagra(); + + /// Trains CAGRA based on the given vector data + void train(idx_t n, const float* x) override; + + protected: + /// Called from GpuIndex for search + void searchImpl_( + idx_t n, + const float* x, + int k, + float* distances, + idx_t* labels, + const SearchParameters* search_params) const override; + + /// Our configuration options + const GpuIndexCagraConfig cagraConfig_; + + /// Instance that we own; contains the inverted lists + std::shared_ptr index_; +}; + +} // namespace gpu +} // namespace faiss \ No newline at end of file diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu index 0520209b17..972c95c576 100644 --- a/faiss/gpu/impl/RaftCagra.cu +++ b/faiss/gpu/impl/RaftCagra.cu @@ -21,10 +21,12 @@ */ #include +#include #include #include #include +#include namespace faiss { namespace gpu { @@ -35,6 +37,7 @@ RaftCagra::RaftCagra( idx_t intermediate_graph_degree, idx_t graph_degree, faiss::cagra_build_algo graph_build_algo, + size_t nn_descent_niter, faiss::MetricType metric, float metricArg) : resources_(resources), @@ -51,6 +54,7 @@ RaftCagra::RaftCagra( index_pams_.build_algo = static_cast( graph_build_algo); + index_pams_.nn_descent_niter = nn_descent_niter; } void RaftCagra::train(idx_t n, const float* x) { diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh index ccdffa28f0..cbe9fde857 100644 --- a/faiss/gpu/impl/RaftCagra.cuh +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -23,11 +23,12 @@ #pragma once #include +#include #include #include -#include +#include namespace faiss { @@ -47,6 +48,7 @@ class RaftCagra { idx_t intermediate_graph_degree, idx_t graph_degree, faiss::cagra_build_algo graph_build_algo, + size_t nn_descent_niter, faiss::MetricType metric, float metricArg); From 656f493f45b702121bc658d5fb10bb628bb9668c Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 31 Jan 2024 18:49:34 -0800 Subject: [PATCH 03/46] write tests, figure out a way to compare --- faiss/gpu/GpuIndexCagra.cu | 36 ++++++++- faiss/gpu/GpuIndexCagra.h | 24 +++++- faiss/gpu/impl/RaftCagra.cu | 10 ++- faiss/gpu/impl/RaftCagra.cuh | 8 +- faiss/gpu/test/CMakeLists.txt | 4 +- faiss/gpu/test/TestGpuIndexCagra.cpp | 115 +++++++++++++++++++++++++++ 6 files 
changed, 184 insertions(+), 13 deletions(-) create mode 100644 faiss/gpu/test/TestGpuIndexCagra.cpp diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index 814ab3e9d7..db266ddbed 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -4,6 +4,21 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include @@ -18,7 +33,9 @@ GpuIndexCagra::GpuIndexCagra( faiss::MetricType metric, GpuIndexCagraConfig config) : GpuIndex(provider->getResources(), dims, metric, 0.0f, config), - cagraConfig_(config) {} + cagraConfig_(config) { + this->is_trained = false; + } void GpuIndexCagra::train(idx_t n, const float* x) { if (this->is_trained) { @@ -36,7 +53,8 @@ void GpuIndexCagra::train(idx_t n, const float* x) { static_cast(cagraConfig_.build_algo), cagraConfig_.nn_descent_niter, this->metric_type, - this->metric_arg); + this->metric_arg, + faiss::gpu::INDICES_64_BIT); index_->train(n, x); @@ -58,7 +76,13 @@ void GpuIndexCagra::searchImpl_( Tensor outDistances(distances, {n, k}); Tensor outLabels(const_cast(labels), {n, k}); - auto params = dynamic_cast(search_params); + SearchParametersCagra* params; + if (search_params) { + params = dynamic_cast(const_cast(search_params)); + } + else { + params = new SearchParametersCagra{}; + } index_->search( queries, @@ -78,7 +102,11 @@ void GpuIndexCagra::searchImpl_( params->hashmap_max_fill_rate, params->num_random_samplings, params->rand_xor_mask); + + if (not search_params) { + delete params; + } } } // namespace gpu -} // namespace faiss \ No newline at end of file +} // namespace faiss diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index 0ce5090674..c17183635f 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -4,6 +4,21 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #pragma once @@ -98,12 +113,17 @@ struct GpuIndexCagra : public GpuIndex { faiss::MetricType metric = faiss::METRIC_L2, GpuIndexCagraConfig config = GpuIndexCagraConfig()); - ~GpuIndexCagra(); + ~GpuIndexCagra() {} /// Trains CAGRA based on the given vector data void train(idx_t n, const float* x) override; + void reset() {} + protected: + bool addImplRequiresIDs_() const {} + + void addImpl_(idx_t n, const float* x, const idx_t* ids) {} /// Called from GpuIndex for search void searchImpl_( idx_t n, @@ -121,4 +141,4 @@ struct GpuIndexCagra : public GpuIndex { }; } // namespace gpu -} // namespace faiss \ No newline at end of file +} // namespace faiss diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu index 972c95c576..6253213fde 100644 --- a/faiss/gpu/impl/RaftCagra.cu +++ b/faiss/gpu/impl/RaftCagra.cu @@ -5,7 +5,7 @@ * LICENSE file in the root directory of this source tree. */ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,8 @@ RaftCagra::RaftCagra( faiss::cagra_build_algo graph_build_algo, size_t nn_descent_niter, faiss::MetricType metric, - float metricArg) + float metricArg, + IndicesOptions indicesOptions) : resources_(resources), dim_(dim), metric_(metric), @@ -48,6 +49,9 @@ RaftCagra::RaftCagra( FAISS_THROW_IF_NOT_MSG( metric == faiss::METRIC_L2, "CAGRA currently only supports L2 metric."); + FAISS_THROW_IF_NOT_MSG( + indicesOptions == faiss::gpu::INDICES_64_BIT, + "only INDICES_64_BIT is supported for RAFT CAGRA index"); index_pams_.intermediate_graph_degree = intermediate_graph_degree; index_pams_.graph_degree = graph_degree; @@ -135,4 +139,4 @@ void RaftCagra::search( } } // namespace gpu -} // namespace faiss \ No newline at end of file +} // namespace faiss diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh index cbe9fde857..f92d04e38d 100644 --- a/faiss/gpu/impl/RaftCagra.cuh +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -5,7 +5,7 @@ * LICENSE file in the root directory of this source tree. */ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -50,7 +51,8 @@ class RaftCagra { faiss::cagra_build_algo graph_build_algo, size_t nn_descent_niter, faiss::MetricType metric, - float metricArg); + float metricArg, + IndicesOptions indicesOptions); ~RaftCagra() = default; @@ -97,4 +99,4 @@ class RaftCagra { }; } // namespace gpu -} // namespace faiss \ No newline at end of file +} // namespace faiss diff --git a/faiss/gpu/test/CMakeLists.txt b/faiss/gpu/test/CMakeLists.txt index 9300deead9..4b654f534d 100644 --- a/faiss/gpu/test/CMakeLists.txt +++ b/faiss/gpu/test/CMakeLists.txt @@ -21,7 +21,6 @@ find_package(CUDAToolkit REQUIRED) # Defines `gtest_discover_tests()`. 
include(GoogleTest) - add_library(faiss_gpu_test_helper TestUtils.cpp) target_link_libraries(faiss_gpu_test_helper PUBLIC faiss gtest CUDA::cudart $<$:raft::raft> $<$:raft::compiled>) @@ -42,6 +41,9 @@ faiss_gpu_test(TestGpuIndexIVFPQ.cpp) faiss_gpu_test(TestGpuIndexIVFScalarQuantizer.cpp) faiss_gpu_test(TestGpuDistance.cu) faiss_gpu_test(TestGpuSelect.cu) +if(FAISS_ENABLE_RAFT) + faiss_gpu_test(TestGpuIndexCagra.cpp) +endif() add_executable(demo_ivfpq_indexing_gpu EXCLUDE_FROM_ALL demo_ivfpq_indexing_gpu.cpp) diff --git a/faiss/gpu/test/TestGpuIndexCagra.cpp b/faiss/gpu/test/TestGpuIndexCagra.cpp new file mode 100644 index 0000000000..1179d3a3cb --- /dev/null +++ b/faiss/gpu/test/TestGpuIndexCagra.cpp @@ -0,0 +1,115 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "faiss/MetricType.h" + +struct Options { + Options() { + numTrain = 2 * faiss::gpu::randVal(2000, 5000); + dim = faiss::gpu::randVal(64, 200); + + graphDegree = faiss::gpu::randSelect({16, 32}); + intermediateGraphDegree = faiss::gpu::randSelect({32, 64}); + buildAlgo = faiss::gpu::randSelect( + {faiss::gpu::graph_build_algo::IVF_PQ, faiss::gpu::graph_build_algo::NN_DESCENT}); + + numQuery = faiss::gpu::randVal(32, 100); + k = faiss::gpu::randVal(10, 30); + + device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + } + + std::string toString() const { + std::stringstream str; + str << "CAGRA device " << device << " numVecs " << numTrain << " dim " + << dim << " graphDegree " << graphDegree << " intermediateGraphDegree " << intermediateGraphDegree + << "buildAlgo " << static_cast(buildAlgo) + << " numQuery " << numQuery << " k " << k; + + return str.str(); + } + + int numTrain; + int dim; + size_t graphDegree; + size_t intermediateGraphDegree; + faiss::gpu::graph_build_algo buildAlgo; + int numQuery; + int k; + int device; +}; + +void queryTest() { + for (int tries = 0; tries < 2; ++tries) { + Options opt; + + std::vector trainVecs = + faiss::gpu::randVecs(opt.numTrain, opt.dim); + + faiss::IndexHNSWFlat cpuIndex( + opt.dim, opt.graphDegree / 2); + cpuIndex.train(opt.numTrain, trainVecs.data()); + cpuIndex.add(opt.numTrain, trainVecs.data()); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexCagraConfig config; + config.device = opt.device; + config.graph_degree = opt.graphDegree; + config.intermediate_graph_degree = opt.intermediateGraphDegree; + config.build_algo = opt.buildAlgo; + + faiss::gpu::GpuIndexCagra gpuIndex( + &res, cpuIndex.d, faiss::METRIC_L2, config); + gpuIndex.train(opt.numTrain, trainVecs.data()); + + faiss::gpu::compareIndices( + cpuIndex, + gpuIndex, + opt.numQuery, + opt.dim, + opt.k, + opt.toString(), + 0.15f, + 1.0f, + 0.15f); + } +} + 
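+// Entry point: runs the randomized CPU-vs-GPU comparison implemented in queryTest() above,
+// checking the GPU CAGRA index against the CPU IndexHNSWFlat reference via compareIndices.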
+TEST(TestGpuIndexCagra, Float32_Query_L2) { + queryTest(); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + // just run with a fixed test seed + faiss::gpu::setTestSeed(100); + + return RUN_ALL_TESTS(); +} From ed32954e13b754e56c9d040f28fa91282b4d8835 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 7 Feb 2024 14:01:41 -0800 Subject: [PATCH 04/46] passing tests --- faiss/gpu/test/CMakeLists.txt | 2 +- faiss/gpu/test/TestGpuIndexCagra.cu | 159 ++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 1 deletion(-) create mode 100644 faiss/gpu/test/TestGpuIndexCagra.cu diff --git a/faiss/gpu/test/CMakeLists.txt b/faiss/gpu/test/CMakeLists.txt index 4b654f534d..60f78ef74f 100644 --- a/faiss/gpu/test/CMakeLists.txt +++ b/faiss/gpu/test/CMakeLists.txt @@ -42,7 +42,7 @@ faiss_gpu_test(TestGpuIndexIVFScalarQuantizer.cpp) faiss_gpu_test(TestGpuDistance.cu) faiss_gpu_test(TestGpuSelect.cu) if(FAISS_ENABLE_RAFT) - faiss_gpu_test(TestGpuIndexCagra.cpp) + faiss_gpu_test(TestGpuIndexCagra.cu) endif() add_executable(demo_ivfpq_indexing_gpu EXCLUDE_FROM_ALL diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu new file mode 100644 index 0000000000..90215c07f3 --- /dev/null +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -0,0 +1,159 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +struct Options { + Options() { + numTrain = 2 * faiss::gpu::randVal(2000, 5000); + dim = faiss::gpu::randVal(4, 10); + + graphDegree = faiss::gpu::randSelect({32, 64}); + intermediateGraphDegree = faiss::gpu::randSelect({64, 98}); + buildAlgo = faiss::gpu::randSelect( + {faiss::gpu::graph_build_algo::IVF_PQ, faiss::gpu::graph_build_algo::NN_DESCENT}); + + numQuery = faiss::gpu::randVal(32, 100); + k = faiss::gpu::randVal(10, 30); + + device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); + } + + std::string toString() const { + std::stringstream str; + str << "CAGRA device " << device << " numVecs " << numTrain << " dim " + << dim << " graphDegree " << graphDegree << " intermediateGraphDegree " << intermediateGraphDegree + << "buildAlgo " << static_cast(buildAlgo) + << " numQuery " << numQuery << " k " << k; + + return str.str(); + } + + int numTrain; + int dim; + size_t graphDegree; + size_t intermediateGraphDegree; + faiss::gpu::graph_build_algo buildAlgo; + int numQuery; + int k; + int device; +}; + +void queryTest() { + for (int tries = 0; tries < 2; ++tries) { + Options opt; + + std::vector trainVecs = + faiss::gpu::randVecs(opt.numTrain, opt.dim); + + faiss::IndexHNSWFlat cpuIndex( + opt.dim, opt.graphDegree / 2); + cpuIndex.hnsw.efConstruction = opt.k * 2; + cpuIndex.train(opt.numTrain, trainVecs.data()); + cpuIndex.add(opt.numTrain, trainVecs.data()); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + faiss::gpu::GpuIndexCagraConfig config; + config.device = opt.device; + config.graph_degree = opt.graphDegree; + config.intermediate_graph_degree = opt.intermediateGraphDegree; + config.build_algo = opt.buildAlgo; + + faiss::gpu::GpuIndexCagra gpuIndex( + &res, cpuIndex.d, faiss::METRIC_L2, config); + gpuIndex.train(opt.numTrain, trainVecs.data()); + + auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); + + std::vector refDistance(opt.numQuery * opt.k, 0); + std::vector refIndices(opt.numQuery * opt.k, -1); + faiss::SearchParametersHNSW cpuSearchParams; + cpuSearchParams.efSearch = opt.k * 2; + cpuIndex.search( + opt.numQuery, + queryVecs.data(), + opt.k, + refDistance.data(), + refIndices.data(), + &cpuSearchParams); + + auto gpuRes = res.getResources(); + auto devAlloc = faiss::gpu::makeDevAlloc(faiss::gpu::AllocType::FlatData, gpuRes->getDefaultStreamCurrentDevice()); + faiss::gpu::DeviceTensor testDistance(gpuRes.get(), devAlloc, {opt.numQuery, opt.k}); + faiss::gpu::DeviceTensor testIndices(gpuRes.get(), devAlloc, {opt.numQuery, opt.k}); + gpuIndex.search( + opt.numQuery, + queryVecs.data(), + opt.k, + testDistance.data(), + testIndices.data()); + + auto refDistanceDev = faiss::gpu::toDeviceTemporary(gpuRes.get(), refDistance, gpuRes->getDefaultStreamCurrentDevice()); + auto refIndicesDev = faiss::gpu::toDeviceTemporary(gpuRes.get(), refIndices, gpuRes->getDefaultStreamCurrentDevice()); + + auto raft_handle = gpuRes->getRaftHandleCurrentDevice(); + + auto ref_dis_mds = raft::make_device_matrix_view(refDistanceDev.data(), opt.numQuery, opt.k); + auto ref_dis_mds_opt = std::optional>(ref_dis_mds); + auto ref_ind_mds = raft::make_device_matrix_view(refIndicesDev.data(), opt.numQuery, opt.k); + + auto test_dis_mds = raft::make_device_matrix_view(testDistance.data(), opt.numQuery, opt.k); + auto test_dis_mds_opt = std::optional>(test_dis_mds); + + auto test_ind_mds = 
raft::make_device_matrix_view(testIndices.data(), opt.numQuery, opt.k); + + double scalar_init = 0; + auto recall_score = raft::make_host_scalar(scalar_init); + + raft::stats::neighborhood_recall(raft_handle, test_ind_mds, ref_ind_mds, recall_score.view(), test_dis_mds_opt, ref_dis_mds_opt); + ASSERT_TRUE(*recall_score.data_handle() > 0.98); + + } +} + +TEST(TestGpuIndexCagra, Float32_Query_L2) { + queryTest(); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + // just run with a fixed test seed + faiss::gpu::setTestSeed(100); + + return RUN_ALL_TESTS(); +} From 42ca86227937d2e5f3add800c745f8483f939dec Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 7 Feb 2024 14:04:14 -0800 Subject: [PATCH 05/46] remove cpp test file --- faiss/gpu/test/TestGpuIndexCagra.cpp | 115 --------------------------- 1 file changed, 115 deletions(-) delete mode 100644 faiss/gpu/test/TestGpuIndexCagra.cpp diff --git a/faiss/gpu/test/TestGpuIndexCagra.cpp b/faiss/gpu/test/TestGpuIndexCagra.cpp deleted file mode 100644 index 1179d3a3cb..0000000000 --- a/faiss/gpu/test/TestGpuIndexCagra.cpp +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include -#include "faiss/MetricType.h" - -struct Options { - Options() { - numTrain = 2 * faiss::gpu::randVal(2000, 5000); - dim = faiss::gpu::randVal(64, 200); - - graphDegree = faiss::gpu::randSelect({16, 32}); - intermediateGraphDegree = faiss::gpu::randSelect({32, 64}); - buildAlgo = faiss::gpu::randSelect( - {faiss::gpu::graph_build_algo::IVF_PQ, faiss::gpu::graph_build_algo::NN_DESCENT}); - - numQuery = faiss::gpu::randVal(32, 100); - k = faiss::gpu::randVal(10, 30); - - device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); - } - - std::string toString() const { - std::stringstream str; - str << "CAGRA device " << device << " numVecs " << numTrain << " dim " - << dim << " graphDegree " << graphDegree << " intermediateGraphDegree " << intermediateGraphDegree - << "buildAlgo " << static_cast(buildAlgo) - << " numQuery " << numQuery << " k " << k; - - return str.str(); - } - - int numTrain; - int dim; - size_t graphDegree; - size_t intermediateGraphDegree; - faiss::gpu::graph_build_algo buildAlgo; - int numQuery; - int k; - int device; -}; - -void queryTest() { - for (int tries = 0; tries < 2; ++tries) { - Options opt; - - std::vector trainVecs = - faiss::gpu::randVecs(opt.numTrain, opt.dim); - - faiss::IndexHNSWFlat cpuIndex( - opt.dim, opt.graphDegree / 2); - cpuIndex.train(opt.numTrain, trainVecs.data()); - cpuIndex.add(opt.numTrain, trainVecs.data()); - - faiss::gpu::StandardGpuResources res; - res.noTempMemory(); - - faiss::gpu::GpuIndexCagraConfig config; - config.device = opt.device; - config.graph_degree = opt.graphDegree; - config.intermediate_graph_degree = opt.intermediateGraphDegree; - config.build_algo = opt.buildAlgo; - - faiss::gpu::GpuIndexCagra gpuIndex( - &res, cpuIndex.d, faiss::METRIC_L2, config); - gpuIndex.train(opt.numTrain, trainVecs.data()); - - faiss::gpu::compareIndices( - cpuIndex, - gpuIndex, - opt.numQuery, - opt.dim, - opt.k, - opt.toString(), - 0.15f, - 1.0f, - 0.15f); - } -} - -TEST(TestGpuIndexCagra, Float32_Query_L2) { - queryTest(); -} - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - - // just run with a fixed test seed - faiss::gpu::setTestSeed(100); - - return RUN_ALL_TESTS(); -} From 2c9e965d8ae88bd06b898afde4e1e056555698eb Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 7 Feb 2024 14:24:20 -0800 Subject: [PATCH 06/46] style check --- faiss/gpu/GpuIndexCagra.cu | 10 ++-- faiss/gpu/GpuIndexCagra.h | 4 +- faiss/gpu/impl/RaftCagra.cuh | 2 +- faiss/gpu/test/TestGpuIndexCagra.cu | 72 ++++++++++++++++++++--------- 4 files changed, 57 insertions(+), 31 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index db266ddbed..b0a60268d3 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -34,8 +34,8 @@ GpuIndexCagra::GpuIndexCagra( GpuIndexCagraConfig config) : GpuIndex(provider->getResources(), dims, metric, 0.0f, config), cagraConfig_(config) { - this->is_trained = false; - } + this->is_trained = false; +} void GpuIndexCagra::train(idx_t n, const float* x) { if (this->is_trained) { @@ -78,9 +78,9 @@ void GpuIndexCagra::searchImpl_( SearchParametersCagra* params; if (search_params) { - params = dynamic_cast(const_cast(search_params)); - } - else { + params = dynamic_cast( + const_cast(search_params)); + } else { params = new SearchParametersCagra{}; } diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index c17183635f..a812ebafee 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ 
b/faiss/gpu/GpuIndexCagra.h @@ -121,9 +121,9 @@ struct GpuIndexCagra : public GpuIndex { void reset() {} protected: - bool addImplRequiresIDs_() const {} + bool addImplRequiresIDs_() const {} - void addImpl_(idx_t n, const float* x, const idx_t* ids) {} + void addImpl_(idx_t n, const float* x, const idx_t* ids) {} /// Called from GpuIndex for search void searchImpl_( idx_t n, diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh index f92d04e38d..5783cbf706 100644 --- a/faiss/gpu/impl/RaftCagra.cuh +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -22,10 +22,10 @@ #pragma once +#include #include #include #include -#include #include diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index 90215c07f3..3a99ba35f0 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -21,14 +21,14 @@ */ #include +#include #include #include #include #include +#include #include #include -#include -#include #include #include @@ -43,7 +43,8 @@ struct Options { graphDegree = faiss::gpu::randSelect({32, 64}); intermediateGraphDegree = faiss::gpu::randSelect({64, 98}); buildAlgo = faiss::gpu::randSelect( - {faiss::gpu::graph_build_algo::IVF_PQ, faiss::gpu::graph_build_algo::NN_DESCENT}); + {faiss::gpu::graph_build_algo::IVF_PQ, + faiss::gpu::graph_build_algo::NN_DESCENT}); numQuery = faiss::gpu::randVal(32, 100); k = faiss::gpu::randVal(10, 30); @@ -54,9 +55,10 @@ struct Options { std::string toString() const { std::stringstream str; str << "CAGRA device " << device << " numVecs " << numTrain << " dim " - << dim << " graphDegree " << graphDegree << " intermediateGraphDegree " << intermediateGraphDegree - << "buildAlgo " << static_cast(buildAlgo) - << " numQuery " << numQuery << " k " << k; + << dim << " graphDegree " << graphDegree + << " intermediateGraphDegree " << intermediateGraphDegree + << "buildAlgo " << static_cast(buildAlgo) << " numQuery " + << numQuery << " k " << k; return str.str(); } @@ -78,8 +80,7 @@ void queryTest() { std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); - faiss::IndexHNSWFlat cpuIndex( - opt.dim, opt.graphDegree / 2); + faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2); cpuIndex.hnsw.efConstruction = opt.k * 2; cpuIndex.train(opt.numTrain, trainVecs.data()); cpuIndex.add(opt.numTrain, trainVecs.data()); @@ -112,9 +113,13 @@ void queryTest() { &cpuSearchParams); auto gpuRes = res.getResources(); - auto devAlloc = faiss::gpu::makeDevAlloc(faiss::gpu::AllocType::FlatData, gpuRes->getDefaultStreamCurrentDevice()); - faiss::gpu::DeviceTensor testDistance(gpuRes.get(), devAlloc, {opt.numQuery, opt.k}); - faiss::gpu::DeviceTensor testIndices(gpuRes.get(), devAlloc, {opt.numQuery, opt.k}); + auto devAlloc = faiss::gpu::makeDevAlloc( + faiss::gpu::AllocType::FlatData, + gpuRes->getDefaultStreamCurrentDevice()); + faiss::gpu::DeviceTensor testDistance( + gpuRes.get(), devAlloc, {opt.numQuery, opt.k}); + faiss::gpu::DeviceTensor testIndices( + gpuRes.get(), devAlloc, {opt.numQuery, opt.k}); gpuIndex.search( opt.numQuery, queryVecs.data(), @@ -122,26 +127,47 @@ void queryTest() { testDistance.data(), testIndices.data()); - auto refDistanceDev = faiss::gpu::toDeviceTemporary(gpuRes.get(), refDistance, gpuRes->getDefaultStreamCurrentDevice()); - auto refIndicesDev = faiss::gpu::toDeviceTemporary(gpuRes.get(), refIndices, gpuRes->getDefaultStreamCurrentDevice()); + auto refDistanceDev = faiss::gpu::toDeviceTemporary( + gpuRes.get(), + refDistance, + gpuRes->getDefaultStreamCurrentDevice()); + 
auto refIndicesDev = faiss::gpu::toDeviceTemporary( + gpuRes.get(), + refIndices, + gpuRes->getDefaultStreamCurrentDevice()); auto raft_handle = gpuRes->getRaftHandleCurrentDevice(); - auto ref_dis_mds = raft::make_device_matrix_view(refDistanceDev.data(), opt.numQuery, opt.k); - auto ref_dis_mds_opt = std::optional>(ref_dis_mds); - auto ref_ind_mds = raft::make_device_matrix_view(refIndicesDev.data(), opt.numQuery, opt.k); - - auto test_dis_mds = raft::make_device_matrix_view(testDistance.data(), opt.numQuery, opt.k); - auto test_dis_mds_opt = std::optional>(test_dis_mds); - - auto test_ind_mds = raft::make_device_matrix_view(testIndices.data(), opt.numQuery, opt.k); + auto ref_dis_mds = raft::make_device_matrix_view( + refDistanceDev.data(), opt.numQuery, opt.k); + auto ref_dis_mds_opt = + std::optional>( + ref_dis_mds); + auto ref_ind_mds = + raft::make_device_matrix_view( + refIndicesDev.data(), opt.numQuery, opt.k); + + auto test_dis_mds = raft::make_device_matrix_view( + testDistance.data(), opt.numQuery, opt.k); + auto test_dis_mds_opt = + std::optional>( + test_dis_mds); + + auto test_ind_mds = + raft::make_device_matrix_view( + testIndices.data(), opt.numQuery, opt.k); double scalar_init = 0; auto recall_score = raft::make_host_scalar(scalar_init); - raft::stats::neighborhood_recall(raft_handle, test_ind_mds, ref_ind_mds, recall_score.view(), test_dis_mds_opt, ref_dis_mds_opt); + raft::stats::neighborhood_recall( + raft_handle, + test_ind_mds, + ref_ind_mds, + recall_score.view(), + test_dis_mds_opt, + ref_dis_mds_opt); ASSERT_TRUE(*recall_score.data_handle() > 0.98); - } } From 2e434feb1f5af848591d5792d1f56ca4ebb1ceea Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 7 Feb 2024 15:54:07 -0800 Subject: [PATCH 07/46] add required methods --- faiss/gpu/GpuIndexCagra.cu | 21 ++++++++++++++++++++- faiss/gpu/GpuIndexCagra.h | 9 +++++---- faiss/gpu/impl/RaftCagra.cu | 6 ++++++ faiss/gpu/impl/RaftCagra.cuh | 2 ++ faiss/gpu/test/TestGpuIndexCagra.cu | 1 + 5 files changed, 34 insertions(+), 5 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index b0a60268d3..1a8eb382c0 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -59,9 +59,17 @@ void GpuIndexCagra::train(idx_t n, const float* x) { index_->train(n, x); this->is_trained = true; - this->ntotal += n; + this->ntotal = n; } +bool GpuIndexCagra::addImplRequiresIDs_() const { + return false; +}; + +void GpuIndexCagra::addImpl_(idx_t n, const float* x, const idx_t* ids) { + FAISS_THROW_MSG("adding vectors is not supported by GpuIndexCagra."); +}; + void GpuIndexCagra::searchImpl_( idx_t n, const float* x, @@ -108,5 +116,16 @@ void GpuIndexCagra::searchImpl_( } } +void GpuIndexCagra::reset() { + DeviceScope scope(config_.device); + + if (index_) { + index_->reset(); + this->ntotal = 0; + } else { + FAISS_ASSERT(this->ntotal == 0); + } +} + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index a812ebafee..2c31ab9f59 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -113,17 +113,18 @@ struct GpuIndexCagra : public GpuIndex { faiss::MetricType metric = faiss::METRIC_L2, GpuIndexCagraConfig config = GpuIndexCagraConfig()); - ~GpuIndexCagra() {} + ~GpuIndexCagra() override = default; /// Trains CAGRA based on the given vector data void train(idx_t n, const float* x) override; - void reset() {} + void reset() override; protected: - bool addImplRequiresIDs_() const {} + bool addImplRequiresIDs_() const 
override; + + void addImpl_(idx_t n, const float* x, const idx_t* ids) override; - void addImpl_(idx_t n, const float* x, const idx_t* ids) {} /// Called from GpuIndex for search void searchImpl_( idx_t n, diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu index 6253213fde..c0f7bbba69 100644 --- a/faiss/gpu/impl/RaftCagra.cu +++ b/faiss/gpu/impl/RaftCagra.cu @@ -59,6 +59,8 @@ RaftCagra::RaftCagra( static_cast( graph_build_algo); index_pams_.nn_descent_niter = nn_descent_niter; + + reset(); } void RaftCagra::train(idx_t n, const float* x) { @@ -138,5 +140,9 @@ void RaftCagra::search( distances_view); } +void RaftCagra::reset() { + raft_knn_index.reset(); +} + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh index 5783cbf706..7f2b8b485c 100644 --- a/faiss/gpu/impl/RaftCagra.cuh +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -77,6 +77,8 @@ class RaftCagra { idx_t num_random_samplings, idx_t rand_xor_mask); + void reset(); + private: /// Collection of GPU resources that we use GpuResources* resources_; diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index 3a99ba35f0..8ba11e63ac 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -82,6 +82,7 @@ void queryTest() { faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2); cpuIndex.hnsw.efConstruction = opt.k * 2; + // Training IndexHNSW is a no-op cpuIndex.train(opt.numTrain, trainVecs.data()); cpuIndex.add(opt.numTrain, trainVecs.data()); From 382c178cde27b13d8d7ccfe0af29ca713fe9295f Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 8 Feb 2024 12:11:44 -0800 Subject: [PATCH 08/46] conditionally compile cagra --- faiss/gpu/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt index 1bc08826ee..2efa622fe3 100644 --- a/faiss/gpu/CMakeLists.txt +++ b/faiss/gpu/CMakeLists.txt @@ -29,7 +29,7 @@ set(FAISS_GPU_SRC GpuIndexIVFFlat.cu GpuIndexIVFPQ.cu GpuIndexIVFScalarQuantizer.cu - GpuIndexCagra.cu + $<$:GpuIndexCagra.cu> GpuResources.cpp StandardGpuResources.cpp impl/BinaryDistance.cu @@ -92,7 +92,7 @@ set(FAISS_GPU_HEADERS GpuFaissAssert.h GpuIndex.h GpuIndexBinaryFlat.h - GpuIndexCagra.h + $<$:GpuIndexCagra.h> GpuIndexFlat.h GpuIndexIVF.h GpuIndexIVFFlat.h From 867597429a5686c5ad3e124f5fac468255c2db5f Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 14 Feb 2024 13:39:26 -0800 Subject: [PATCH 09/46] copyTo and copyFrom --- faiss/IndexHNSW.cpp | 27 +++- faiss/IndexHNSW.h | 8 + faiss/gpu/GpuIndexCagra.cu | 80 +++++++++- faiss/gpu/GpuIndexCagra.h | 14 ++ faiss/gpu/impl/RaftCagra.cu | 106 +++++++++++++ faiss/gpu/impl/RaftCagra.cuh | 17 ++ faiss/gpu/test/TestGpuIndexCagra.cu | 235 +++++++++++++++++++++++++++- faiss/impl/HNSW.cpp | 56 +++++-- faiss/impl/HNSW.h | 9 +- 9 files changed, 527 insertions(+), 25 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 9a67332d67..8305764df8 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -192,7 +192,7 @@ void hnsw_add_vertices( int i1 = n; - for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) { + for (int pt_level = hist.size() - 1; pt_level >= !index_hnsw.init_level0; pt_level--) { int i0 = i1 - hist[pt_level]; if (verbose) { @@ -228,7 +228,7 @@ void hnsw_add_vertices( continue; } - hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt); + hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt, index_hnsw.keep_max_size_level0); if 
(prev_display >= 0 && i - i0 > prev_display + 10000) { prev_display = i - i0; @@ -248,7 +248,12 @@ void hnsw_add_vertices( } i1 = i0; } - FAISS_ASSERT(i1 == 0); + if (index_hnsw.init_level0) { + FAISS_ASSERT(i1 == 0); + } + else { + FAISS_ASSERT((i1 - hist[0]) == 0); + } } if (verbose) { printf("Done in %.3f ms\n", getmillisecs() - t0); @@ -914,4 +919,20 @@ void IndexHNSW2Level::flip_to_ivf() { delete storage2l; } +/************************************************************** + * IndexHNSWCagra implementation + **************************************************************/ + +IndexHNSWCagra::IndexHNSWCagra() { + is_trained = true; +} + +IndexHNSWCagra::IndexHNSWCagra(int d, int M) + : IndexHNSW(new IndexFlatL2(d), M) { + own_fields = true; + is_trained = true; + init_level0 = true; + keep_max_size_level0 = true; +} + } // namespace faiss diff --git a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h index e0b65fca9d..3d3162e423 100644 --- a/faiss/IndexHNSW.h +++ b/faiss/IndexHNSW.h @@ -34,6 +34,9 @@ struct IndexHNSW : Index { bool own_fields = false; Index* storage = nullptr; + bool init_level0 = true; + bool keep_max_size_level0 = false; + explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2); explicit IndexHNSW(Index* storage, int M = 32); @@ -148,4 +151,9 @@ struct IndexHNSW2Level : IndexHNSW { const SearchParameters* params = nullptr) const override; }; +struct IndexHNSWCagra : IndexHNSW { + IndexHNSWCagra(); + IndexHNSWCagra(int d, int M); +}; + } // namespace faiss diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index 1a8eb382c0..e20dbc8663 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -20,9 +20,10 @@ * limitations under the License. */ +#include #include +#include #include -#include "GpuIndexCagra.h" namespace faiss { namespace gpu { @@ -54,7 +55,7 @@ void GpuIndexCagra::train(idx_t n, const float* x) { cagraConfig_.nn_descent_niter, this->metric_type, this->metric_arg, - faiss::gpu::INDICES_64_BIT); + INDICES_64_BIT); index_->train(n, x); @@ -116,16 +117,91 @@ void GpuIndexCagra::searchImpl_( } } +void GpuIndexCagra::copyFrom(const faiss::IndexHNSWCagra* index) { + FAISS_ASSERT(index); + + auto base_index = index->storage; + auto l2_index = dynamic_cast(base_index); + FAISS_ASSERT(l2_index); + auto distances = l2_index->get_xb(); + + auto hnsw = index->hnsw; + // copy level 0 to a dense knn graph matrix + std::vector knn_graph; + knn_graph.reserve(index->ntotal * hnsw.nb_neighbors(0)); + +#pragma omp parallel for + for (size_t i = 0; i < index->ntotal; ++i) { + size_t begin, end; + hnsw.neighbor_range(i, 0, &begin, &end); + for (size_t j = begin; j < end; j++) { + // knn_graph.push_back(hnsw.neighbors[j]); + knn_graph[i * hnsw.nb_neighbors(0) + (j - begin)] = + hnsw.neighbors[j]; + } + } + + index_ = std::make_shared( + this->resources_.get(), + this->d, + index->ntotal, + hnsw.nb_neighbors(0), + distances, + knn_graph.data(), + this->metric_type, + this->metric_arg, + INDICES_64_BIT); + + this->is_trained = true; +} + +void GpuIndexCagra::copyTo(faiss::IndexHNSWCagra* index) const { + FAISS_ASSERT(index_ && this->is_trained && index); + + auto graph_degree = index_->get_knngraph_degree(); + FAISS_THROW_IF_NOT_MSG( + (index->hnsw.nb_neighbors(0)) == graph_degree, + "IndexHNSWCagra.hnsw.nb_neighbors(0) should be equal to GpuIndexCagraConfig.graph_degree"); + + auto n_train = this->ntotal; + auto train_dataset = index_->get_training_dataset(); + + // turn off as level 0 is copied from CAGRA graph + index->init_level0 = 
false; + index->add(n_train, train_dataset.data()); + + auto graph = get_knngraph(); + +#pragma omp parallel for + for (idx_t i = 0; i < n_train; i++) { + size_t begin, end; + index->hnsw.neighbor_range(i, 0, &begin, &end); + for (size_t j = begin; j < end; j++) { + index->hnsw.neighbors[j] = graph[i * graph_degree + (j - begin)]; + } + } + + // turn back on to allow new vectors to be added to level 0 + index->init_level0 = true; +} + void GpuIndexCagra::reset() { DeviceScope scope(config_.device); if (index_) { index_->reset(); this->ntotal = 0; + this->is_trained = false; } else { FAISS_ASSERT(this->ntotal == 0); } } +std::vector GpuIndexCagra::get_knngraph() const { + FAISS_ASSERT(index_ && this->is_trained); + + return index_->get_knngraph(); +} + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index 2c31ab9f59..902a0d34e7 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -24,6 +24,10 @@ #include +namespace faiss { +struct IndexHNSWCagra; +} + namespace faiss { namespace gpu { @@ -118,8 +122,18 @@ struct GpuIndexCagra : public GpuIndex { /// Trains CAGRA based on the given vector data void train(idx_t n, const float* x) override; + /// Initialize ourselves from the given CPU index; will overwrite + /// all data in ourselves + void copyFrom(const faiss::IndexHNSWCagra* index); + + /// Copy ourselves to the given CPU index; will overwrite all data + /// in the index instance + void copyTo(faiss::IndexHNSWCagra* index) const; + void reset() override; + std::vector get_knngraph() const; + protected: bool addImplRequiresIDs_() const override; diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu index c0f7bbba69..0a55901a1b 100644 --- a/faiss/gpu/impl/RaftCagra.cu +++ b/faiss/gpu/impl/RaftCagra.cu @@ -63,6 +63,61 @@ RaftCagra::RaftCagra( reset(); } +RaftCagra::RaftCagra( + GpuResources* resources, + int dim, + idx_t n, + int graph_degree, + const float* distances, + const idx_t* knn_graph, + faiss::MetricType metric, + float metricArg, + IndicesOptions indicesOptions) + : resources_(resources), + dim_(dim), + metric_(metric), + metricArg_(metricArg) { + FAISS_THROW_IF_NOT_MSG( + metric == faiss::METRIC_L2, + "CAGRA currently only supports L2 metric."); + FAISS_THROW_IF_NOT_MSG( + indicesOptions == faiss::gpu::INDICES_64_BIT, + "only INDICES_64_BIT is supported for RAFT CAGRA index"); + + auto distances_on_gpu = getDeviceForAddress(distances) >= 0; + auto knn_graph_on_gpu = getDeviceForAddress(knn_graph) >= 0; + + FAISS_ASSERT(distances_on_gpu == knn_graph_on_gpu); + + const raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + if (distances_on_gpu && knn_graph_on_gpu) { + auto distances_mds = + raft::make_device_matrix_view( + distances, n, dim); + auto knn_graph_mds = + raft::make_device_matrix_view( + knn_graph, n, graph_degree); + + raft_knn_index = raft::neighbors::cagra::index( + raft_handle, + raft::distance::DistanceType::L2Expanded, + distances_mds, + knn_graph_mds); + } else { + auto distances_mds = raft::make_host_matrix_view( + distances, n, dim); + auto knn_graph_mds = raft::make_host_matrix_view( + knn_graph, n, graph_degree); + + raft_knn_index = raft::neighbors::cagra::index( + raft_handle, + raft::distance::DistanceType::L2Expanded, + distances_mds, + knn_graph_mds); + } +} + void RaftCagra::train(idx_t n, const float* x) { const raft::device_resources& raft_handle = resources_->getRaftHandleCurrentDevice(); @@ -144,5 +199,56 @@ void 
RaftCagra::reset() { raft_knn_index.reset(); } +idx_t RaftCagra::get_knngraph_degree() const { + FAISS_ASSERT(raft_knn_index.has_value()); + return static_cast(raft_knn_index.value().graph_degree()); +} + +std::vector RaftCagra::get_knngraph() const { + FAISS_ASSERT(raft_knn_index.has_value()); + const raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + auto stream = raft_handle.get_stream(); + + auto device_graph = raft_knn_index.value().graph(); + + std::vector host_graph( + device_graph.extent(0) * device_graph.extent(1)); + + raft::update_host( + host_graph.data(), + device_graph.data_handle(), + host_graph.size(), + stream); + raft_handle.sync_stream(); + + return host_graph; +} + +std::vector RaftCagra::get_training_dataset() const { + FAISS_ASSERT(raft_knn_index.has_value()); + const raft::device_resources& raft_handle = + resources_->getRaftHandleCurrentDevice(); + auto stream = raft_handle.get_stream(); + + auto device_dataset = raft_knn_index.value().dataset(); + + std::vector host_dataset( + device_dataset.extent(0) * device_dataset.extent(1)); + + RAFT_CUDA_TRY(cudaMemcpy2DAsync( + host_dataset.data(), + sizeof(float) * dim_, + device_dataset.data_handle(), + sizeof(float) * device_dataset.stride(0), + sizeof(float) * dim_, + device_dataset.extent(0), + cudaMemcpyDefault, + raft_handle.get_stream())); + raft_handle.sync_stream(); + + return host_dataset; +} + } // namespace gpu } // namespace faiss diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh index 7f2b8b485c..0fddb3b39f 100644 --- a/faiss/gpu/impl/RaftCagra.cuh +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -54,6 +54,17 @@ class RaftCagra { float metricArg, IndicesOptions indicesOptions); + RaftCagra( + GpuResources* resources, + int dim, + idx_t n, + int graph_degree, + const float* distances, + const idx_t* knn_graph, + faiss::MetricType metric, + float metricArg, + IndicesOptions indicesOptions); + ~RaftCagra() = default; void train(idx_t n, const float* x); @@ -79,6 +90,12 @@ class RaftCagra { void reset(); + idx_t get_knngraph_degree() const; + + std::vector get_knngraph() const; + + std::vector get_training_dataset() const; + private: /// Collection of GPU resources that we use GpuResources* resources_; diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index 8ba11e63ac..658ec8858f 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -39,6 +39,7 @@ struct Options { Options() { numTrain = 2 * faiss::gpu::randVal(2000, 5000); dim = faiss::gpu::randVal(4, 10); + numAdd = faiss::gpu::randVal(1000, 3000); graphDegree = faiss::gpu::randSelect({32, 64}); intermediateGraphDegree = faiss::gpu::randSelect({64, 98}); @@ -64,6 +65,7 @@ struct Options { } int numTrain; + int numAdd; int dim; size_t graphDegree; size_t intermediateGraphDegree; @@ -74,18 +76,18 @@ struct Options { }; void queryTest() { - for (int tries = 0; tries < 2; ++tries) { + for (int tries = 0; tries < 5; ++tries) { Options opt; std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + // train cpu index faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2); cpuIndex.hnsw.efConstruction = opt.k * 2; - // Training IndexHNSW is a no-op - cpuIndex.train(opt.numTrain, trainVecs.data()); cpuIndex.add(opt.numTrain, trainVecs.data()); + // train gpu index faiss::gpu::StandardGpuResources res; res.noTempMemory(); @@ -99,6 +101,7 @@ void queryTest() { &res, cpuIndex.d, faiss::METRIC_L2, config); gpuIndex.train(opt.numTrain, 
trainVecs.data()); + // query auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); std::vector refDistance(opt.numQuery * opt.k, 0); @@ -113,6 +116,7 @@ void queryTest() { refIndices.data(), &cpuSearchParams); + // test quality of searches auto gpuRes = res.getResources(); auto devAlloc = faiss::gpu::makeDevAlloc( faiss::gpu::AllocType::FlatData, @@ -176,6 +180,231 @@ TEST(TestGpuIndexCagra, Float32_Query_L2) { queryTest(); } +void copyToTest() { + for (int tries = 0; tries < 5; ++tries) { + Options opt; + + std::vector trainVecs = + faiss::gpu::randVecs(opt.numTrain, opt.dim); + std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + // train gpu index and copy to cpu index + faiss::gpu::GpuIndexCagraConfig config; + config.device = opt.device; + config.graph_degree = opt.graphDegree; + config.intermediate_graph_degree = opt.intermediateGraphDegree; + config.build_algo = opt.buildAlgo; + + faiss::gpu::GpuIndexCagra gpuIndex( + &res, opt.dim, faiss::METRIC_L2, config); + gpuIndex.train(opt.numTrain, trainVecs.data()); + + faiss::IndexHNSWCagra copiedCpuIndex(opt.dim, opt.graphDegree / 2); + copiedCpuIndex.hnsw.efConstruction = opt.k * 2; + gpuIndex.copyTo(&copiedCpuIndex); + + // add more vecs to copied cpu index + copiedCpuIndex.add(opt.numAdd, addVecs.data()); + + // train cpu index + faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2); + cpuIndex.hnsw.efConstruction = opt.k * 2; + cpuIndex.add(opt.numTrain, trainVecs.data()); + + // add more vecs to cpu index + cpuIndex.add(opt.numAdd, addVecs.data()); + + // query indexes + auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); + + std::vector refDistance(opt.numQuery * opt.k, 0); + std::vector refIndices(opt.numQuery * opt.k, -1); + faiss::SearchParametersHNSW cpuSearchParams; + cpuSearchParams.efSearch = opt.k * 2; + cpuIndex.search( + opt.numQuery, + queryVecs.data(), + opt.k, + refDistance.data(), + refIndices.data(), + &cpuSearchParams); + + std::vector copyRefDistance(opt.numQuery * opt.k, 0); + std::vector copyRefIndices(opt.numQuery * opt.k, -1); + faiss::SearchParametersHNSW cpuSearchParamstwo; + cpuSearchParamstwo.efSearch = opt.k * 2; + copiedCpuIndex.search( + opt.numQuery, + queryVecs.data(), + opt.k, + copyRefDistance.data(), + copyRefIndices.data(), + &cpuSearchParamstwo); + + // test quality of search + auto gpuRes = res.getResources(); + + auto refDistanceDev = faiss::gpu::toDeviceTemporary( + gpuRes.get(), + refDistance, + gpuRes->getDefaultStreamCurrentDevice()); + auto refIndicesDev = faiss::gpu::toDeviceTemporary( + gpuRes.get(), + refIndices, + gpuRes->getDefaultStreamCurrentDevice()); + + auto copyRefDistanceDev = faiss::gpu::toDeviceTemporary( + gpuRes.get(), + copyRefDistance, + gpuRes->getDefaultStreamCurrentDevice()); + auto copyRefIndicesDev = faiss::gpu::toDeviceTemporary( + gpuRes.get(), + copyRefIndices, + gpuRes->getDefaultStreamCurrentDevice()); + + auto raft_handle = gpuRes->getRaftHandleCurrentDevice(); + + auto ref_dis_mds = raft::make_device_matrix_view( + refDistanceDev.data(), opt.numQuery, opt.k); + auto ref_dis_mds_opt = + std::optional>( + ref_dis_mds); + auto ref_ind_mds = + raft::make_device_matrix_view( + refIndicesDev.data(), opt.numQuery, opt.k); + + auto copy_ref_dis_mds = raft::make_device_matrix_view( + copyRefDistanceDev.data(), opt.numQuery, opt.k); + auto copy_ref_dis_mds_opt = + std::optional>( + copy_ref_dis_mds); + auto copy_ref_ind_mds = + raft::make_device_matrix_view( + 
copyRefIndicesDev.data(), opt.numQuery, opt.k); + + double scalar_init = 0; + auto recall_score = raft::make_host_scalar(scalar_init); + + raft::stats::neighborhood_recall( + raft_handle, + copy_ref_ind_mds, + ref_ind_mds, + recall_score.view(), + copy_ref_dis_mds_opt, + ref_dis_mds_opt); + ASSERT_TRUE(*recall_score.data_handle() > 0.99); + } +} + +TEST(TestGpuIndexCagra, Float32_CopyTo_L2) { + copyToTest(); +} + +void copyFromTest() { + for (int tries = 0; tries < 5; ++tries) { + Options opt; + + std::vector trainVecs = + faiss::gpu::randVecs(opt.numTrain, opt.dim); + + // train cpu index + faiss::IndexHNSWCagra cpuIndex(opt.dim, opt.graphDegree / 2); + cpuIndex.hnsw.efConstruction = opt.k * 2; + cpuIndex.add(opt.numTrain, trainVecs.data()); + + faiss::gpu::StandardGpuResources res; + res.noTempMemory(); + + // convert to gpu index + faiss::gpu::GpuIndexCagra copiedGpuIndex( + &res, cpuIndex.d, faiss::METRIC_L2); + copiedGpuIndex.copyFrom(&cpuIndex); + + // train gpu index + faiss::gpu::GpuIndexCagraConfig config; + config.device = opt.device; + config.graph_degree = opt.graphDegree; + config.intermediate_graph_degree = opt.intermediateGraphDegree; + config.build_algo = opt.buildAlgo; + + faiss::gpu::GpuIndexCagra gpuIndex( + &res, opt.dim, faiss::METRIC_L2, config); + gpuIndex.train(opt.numTrain, trainVecs.data()); + + // query + auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); + + auto gpuRes = res.getResources(); + auto devAlloc = faiss::gpu::makeDevAlloc( + faiss::gpu::AllocType::FlatData, + gpuRes->getDefaultStreamCurrentDevice()); + faiss::gpu::DeviceTensor copyTestDistance( + gpuRes.get(), devAlloc, {opt.numQuery, opt.k}); + faiss::gpu::DeviceTensor copyTestIndices( + gpuRes.get(), devAlloc, {opt.numQuery, opt.k}); + copiedGpuIndex.search( + opt.numQuery, + queryVecs.data(), + opt.k, + copyTestDistance.data(), + copyTestIndices.data()); + + faiss::gpu::DeviceTensor testDistance( + gpuRes.get(), devAlloc, {opt.numQuery, opt.k}); + faiss::gpu::DeviceTensor testIndices( + gpuRes.get(), devAlloc, {opt.numQuery, opt.k}); + gpuIndex.search( + opt.numQuery, + queryVecs.data(), + opt.k, + testDistance.data(), + testIndices.data()); + + // test quality of searches + auto raft_handle = gpuRes->getRaftHandleCurrentDevice(); + + auto test_dis_mds = raft::make_device_matrix_view( + testDistance.data(), opt.numQuery, opt.k); + auto test_dis_mds_opt = + std::optional>( + test_dis_mds); + + auto test_ind_mds = + raft::make_device_matrix_view( + testIndices.data(), opt.numQuery, opt.k); + + auto copy_test_dis_mds = + raft::make_device_matrix_view( + copyTestDistance.data(), opt.numQuery, opt.k); + auto copy_test_dis_mds_opt = + std::optional>( + copy_test_dis_mds); + + auto copy_test_ind_mds = + raft::make_device_matrix_view( + copyTestIndices.data(), opt.numQuery, opt.k); + + double scalar_init = 0; + auto recall_score = raft::make_host_scalar(scalar_init); + + raft::stats::neighborhood_recall( + raft_handle, + copy_test_ind_mds, + test_ind_mds, + recall_score.view(), + copy_test_dis_mds_opt, + test_dis_mds_opt); + ASSERT_TRUE(*recall_score.data_handle() > 0.99); + } +} + +TEST(TestGpuIndexCagra, Float32_CopyFrom_L2) { + copyFromTest(); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index a9fb9daf5b..c886f7d5df 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -7,6 +7,7 @@ #include +#include #include #include @@ -110,8 +111,8 @@ void HNSW::print_neighbor_stats(int level) const { level, 
nb_neighbors(level)); size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0; -#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \ - reduction(+: tot_reciprocal) reduction(+: n_node) +#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \ + reduction(+ : tot_reciprocal) reduction(+ : n_node) for (int i = 0; i < levels.size(); i++) { if (levels[i] > level) { n_node++; @@ -215,8 +216,8 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) { if (pt_level > max_level) max_level = pt_level; offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1)); - neighbors.resize(offsets.back(), -1); } + neighbors.resize(offsets.back(), -1); return max_level; } @@ -229,7 +230,14 @@ void HNSW::shrink_neighbor_list( DistanceComputer& qdis, std::priority_queue& input, std::vector& output, - int max_size) { + int max_size, + bool keep_max_size_level0) { + // This prevents number of neighbors at + // level 0 from being shrunk to less than 2 * M. + // This is essential in making sure + // `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional + std::vector outsiders; + while (input.size() > 0) { NodeDistFarther v1 = input.top(); input.pop(); @@ -250,8 +258,14 @@ void HNSW::shrink_neighbor_list( if (output.size() >= max_size) { return; } + } else if (keep_max_size_level0) { + outsiders.push_back(v1); } } + size_t idx = 0; + while (keep_max_size_level0 && output.size() < max_size) { + output.push_back(outsiders[idx++]); + } } namespace { @@ -268,7 +282,8 @@ using NodeDistFarther = HNSW::NodeDistFarther; void shrink_neighbor_list( DistanceComputer& qdis, std::priority_queue& resultSet1, - int max_size) { + int max_size, + bool keep_max_size_level0 = false) { if (resultSet1.size() < max_size) { return; } @@ -280,7 +295,8 @@ void shrink_neighbor_list( resultSet1.pop(); } - HNSW::shrink_neighbor_list(qdis, resultSet, returnlist, max_size); + HNSW::shrink_neighbor_list( + qdis, resultSet, returnlist, max_size, keep_max_size_level0); for (NodeDistFarther curen2 : returnlist) { resultSet1.emplace(curen2.d, curen2.id); @@ -294,7 +310,8 @@ void add_link( DistanceComputer& qdis, storage_idx_t src, storage_idx_t dest, - int level) { + int level, + bool keep_max_size_level0 = false) { size_t begin, end; hnsw.neighbor_range(src, level, &begin, &end); if (hnsw.neighbors[end - 1] == -1) { @@ -319,7 +336,8 @@ void add_link( resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh); } - shrink_neighbor_list(qdis, resultSet, end - begin); + shrink_neighbor_list( + qdis, resultSet, end - begin, keep_max_size_level0 && level == 0); // ...and back size_t i = begin; @@ -429,7 +447,8 @@ void HNSW::add_links_starting_from( float d_nearest, int level, omp_lock_t* locks, - VisitedTable& vt) { + VisitedTable& vt, + bool keep_max_size_level0) { std::priority_queue link_targets; search_neighbors_to_add( @@ -438,13 +457,14 @@ void HNSW::add_links_starting_from( // but we can afford only this many neighbors int M = nb_neighbors(level); - ::faiss::shrink_neighbor_list(ptdis, link_targets, M); + ::faiss::shrink_neighbor_list( + ptdis, link_targets, M, keep_max_size_level0 && level == 0); std::vector neighbors; neighbors.reserve(link_targets.size()); while (!link_targets.empty()) { storage_idx_t other_id = link_targets.top().id; - add_link(*this, ptdis, pt_id, other_id, level); + add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0); neighbors.push_back(other_id); link_targets.pop(); } @@ -452,7 +472,7 @@ void HNSW::add_links_starting_from( 
omp_unset_lock(&locks[pt_id]); for (storage_idx_t other_id : neighbors) { omp_set_lock(&locks[other_id]); - add_link(*this, ptdis, other_id, pt_id, level); + add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0); omp_unset_lock(&locks[other_id]); } omp_set_lock(&locks[pt_id]); @@ -467,7 +487,8 @@ void HNSW::add_with_locks( int pt_level, int pt_id, std::vector& locks, - VisitedTable& vt) { + VisitedTable& vt, + bool keep_max_size_level0) { // greedy search on upper levels storage_idx_t nearest; @@ -496,7 +517,14 @@ void HNSW::add_with_locks( for (; level >= 0; level--) { add_links_starting_from( - ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt); + ptdis, + pt_id, + nearest, + d_nearest, + level, + locks.data(), + vt, + keep_max_size_level0); } omp_unset_lock(&locks[pt_id]); diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index cb6b422c3d..d1b9a955a6 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -184,7 +184,8 @@ struct HNSW { float d_nearest, int level, omp_lock_t* locks, - VisitedTable& vt); + VisitedTable& vt, + bool keep_max_size_level0 = false); /** add point pt_id on all levels <= pt_level and build the link * structure for them. */ @@ -193,7 +194,8 @@ struct HNSW { int pt_level, int pt_id, std::vector& locks, - VisitedTable& vt); + VisitedTable& vt, + bool keep_max_size_level0 = false); /// search interface for 1 point, single thread HNSWStats search( @@ -224,7 +226,8 @@ struct HNSW { DistanceComputer& qdis, std::priority_queue& input, std::vector& output, - int max_size); + int max_size, + bool keep_max_size_level0 = false); void permute_entries(const idx_t* map); }; From c7fcf4a030bd300d41668b97093eb151b0a9a890 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 14 Feb 2024 13:56:26 -0800 Subject: [PATCH 10/46] style check --- faiss/IndexHNSW.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 8305764df8..1589b10a6c 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -192,7 +192,9 @@ void hnsw_add_vertices( int i1 = n; - for (int pt_level = hist.size() - 1; pt_level >= !index_hnsw.init_level0; pt_level--) { + for (int pt_level = hist.size() - 1; + pt_level >= !index_hnsw.init_level0; + pt_level--) { int i0 = i1 - hist[pt_level]; if (verbose) { @@ -228,7 +230,13 @@ void hnsw_add_vertices( continue; } - hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt, index_hnsw.keep_max_size_level0); + hnsw.add_with_locks( + *dis, + pt_level, + pt_id, + locks, + vt, + index_hnsw.keep_max_size_level0); if (prev_display >= 0 && i - i0 > prev_display + 10000) { prev_display = i - i0; @@ -250,8 +258,7 @@ void hnsw_add_vertices( } if (index_hnsw.init_level0) { FAISS_ASSERT(i1 == 0); - } - else { + } else { FAISS_ASSERT((i1 - hist[0]) == 0); } } From 065f912d6daa23fbf7e1c345982f7c5858ba7988 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 20 Feb 2024 13:57:09 -0800 Subject: [PATCH 11/46] add read/write --- faiss/impl/index_read.cpp | 3 +++ faiss/impl/index_write.cpp | 2 ++ 2 files changed, 5 insertions(+) diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index ac62e0269e..8622b99c06 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -958,7 +958,10 @@ Index* read_index(IOReader* f, int io_flags) { idxhnsw = new IndexHNSWSQ(); if (h == fourcc("IHN2")) idxhnsw = new IndexHNSW2Level(); + if (h == fourcc("IHNc")) + idxhnsw = new IndexHNSWCagra(); read_index_header(idxhnsw, f); + READ1(idxhnsw->keep_max_size_level0); 
read_HNSW(&idxhnsw->hnsw, f); idxhnsw->storage = read_index(f, io_flags); idxhnsw->own_fields = true; diff --git a/faiss/impl/index_write.cpp b/faiss/impl/index_write.cpp index b2808d7170..1f27a68451 100644 --- a/faiss/impl/index_write.cpp +++ b/faiss/impl/index_write.cpp @@ -760,10 +760,12 @@ void write_index(const Index* idx, IOWriter* f) { : dynamic_cast(idx) ? fourcc("IHNp") : dynamic_cast(idx) ? fourcc("IHNs") : dynamic_cast(idx) ? fourcc("IHN2") + : dynamic_cast(idx) ? fourcc("IHNc") : 0; FAISS_THROW_IF_NOT(h != 0); WRITE1(h); write_index_header(idxhnsw, f); + WRITE1(idxhnsw->keep_max_size_level0); write_HNSW(&idxhnsw->hnsw, f); write_index(idxhnsw->storage, f); } else if (const IndexNSG* idxnsg = dynamic_cast(idx)) { From 2b0ea76e2b581f74eb47c460a80e6ef56208f12a Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 20 Feb 2024 14:43:49 -0800 Subject: [PATCH 12/46] add destructor --- faiss/gpu/GpuIndexCagra.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index 902a0d34e7..324ef9c089 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -117,7 +117,7 @@ struct GpuIndexCagra : public GpuIndex { faiss::MetricType metric = faiss::METRIC_L2, GpuIndexCagraConfig config = GpuIndexCagraConfig()); - ~GpuIndexCagra() override = default; + ~GpuIndexCagra() override {}; /// Trains CAGRA based on the given vector data void train(idx_t n, const float* x) override; From 8c83bd23bae8d7bb68a455e4c259a17b94f417b2 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 21 Feb 2024 08:50:38 -0800 Subject: [PATCH 13/46] destructor body, copyto reset --- faiss/gpu/GpuIndexCagra.cu | 9 ++++++--- faiss/gpu/GpuIndexCagra.h | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index e20dbc8663..81c7b79d00 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -24,6 +24,7 @@ #include #include #include +#include "GpuIndexCagra.h" namespace faiss { namespace gpu { @@ -38,6 +39,8 @@ GpuIndexCagra::GpuIndexCagra( this->is_trained = false; } +GpuIndexCagra::~GpuIndexCagra() {} + void GpuIndexCagra::train(idx_t n, const float* x) { if (this->is_trained) { FAISS_ASSERT(index_); @@ -158,10 +161,10 @@ void GpuIndexCagra::copyFrom(const faiss::IndexHNSWCagra* index) { void GpuIndexCagra::copyTo(faiss::IndexHNSWCagra* index) const { FAISS_ASSERT(index_ && this->is_trained && index); + index->reset(); + auto graph_degree = index_->get_knngraph_degree(); - FAISS_THROW_IF_NOT_MSG( - (index->hnsw.nb_neighbors(0)) == graph_degree, - "IndexHNSWCagra.hnsw.nb_neighbors(0) should be equal to GpuIndexCagraConfig.graph_degree"); + index->hnsw.M = graph_degree / 2; auto n_train = this->ntotal; auto train_dataset = index_->get_training_dataset(); diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index 324ef9c089..35783a848c 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -117,7 +117,7 @@ struct GpuIndexCagra : public GpuIndex { faiss::MetricType metric = faiss::METRIC_L2, GpuIndexCagraConfig config = GpuIndexCagraConfig()); - ~GpuIndexCagra() override {}; + ~GpuIndexCagra() override; /// Trains CAGRA based on the given vector data void train(idx_t n, const float* x) override; From 39fb35a9d49295b191a9153e4bc76f5c25dc696b Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 21 Feb 2024 09:14:42 -0800 Subject: [PATCH 14/46] remove destructor --- faiss/gpu/GpuIndexCagra.cu | 6 +++--- faiss/gpu/GpuIndexCagra.h | 2 -- 2 
files changed, 3 insertions(+), 5 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index 81c7b79d00..f94b5cbcf5 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -39,8 +39,6 @@ GpuIndexCagra::GpuIndexCagra( this->is_trained = false; } -GpuIndexCagra::~GpuIndexCagra() {} - void GpuIndexCagra::train(idx_t n, const float* x) { if (this->is_trained) { FAISS_ASSERT(index_); @@ -164,7 +162,9 @@ void GpuIndexCagra::copyTo(faiss::IndexHNSWCagra* index) const { index->reset(); auto graph_degree = index_->get_knngraph_degree(); - index->hnsw.M = graph_degree / 2; + auto M = graph_degree / 2; + index->hnsw.set_default_probas(M, 1.0 / log(M)); + index->hnsw.offsets.push_back(0); auto n_train = this->ntotal; auto train_dataset = index_->get_training_dataset(); diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index 35783a848c..62c0b489fb 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -117,8 +117,6 @@ struct GpuIndexCagra : public GpuIndex { faiss::MetricType metric = faiss::METRIC_L2, GpuIndexCagraConfig config = GpuIndexCagraConfig()); - ~GpuIndexCagra() override; - /// Trains CAGRA based on the given vector data void train(idx_t n, const float* x) override; From 49e261018618b1e2c8bb71f7ba8766e162df9564 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 21 Feb 2024 09:38:08 -0800 Subject: [PATCH 15/46] move cmake sources around --- faiss/gpu/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt index 2efa622fe3..b060d2efe1 100644 --- a/faiss/gpu/CMakeLists.txt +++ b/faiss/gpu/CMakeLists.txt @@ -29,7 +29,6 @@ set(FAISS_GPU_SRC GpuIndexIVFFlat.cu GpuIndexIVFPQ.cu GpuIndexIVFScalarQuantizer.cu - $<$:GpuIndexCagra.cu> GpuResources.cpp StandardGpuResources.cpp impl/BinaryDistance.cu @@ -92,7 +91,6 @@ set(FAISS_GPU_HEADERS GpuFaissAssert.h GpuIndex.h GpuIndexBinaryFlat.h - $<$:GpuIndexCagra.h> GpuIndexFlat.h GpuIndexIVF.h GpuIndexIVFFlat.h @@ -240,11 +238,13 @@ generate_ivf_interleaved_code() if(FAISS_ENABLE_RAFT) list(APPEND FAISS_GPU_HEADERS + GpuIndexCagra.h impl/RaftCagra.cuh impl/RaftUtils.h impl/RaftIVFFlat.cuh impl/RaftFlatIndex.cuh) list(APPEND FAISS_GPU_SRC + GpuIndexCagra.cu impl/RaftCagra.cu impl/RaftFlatIndex.cu impl/RaftIVFFlat.cu) From d4434bb45532e4651c8b28a41e0fefa42f3305a1 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 21 Feb 2024 14:41:51 -0800 Subject: [PATCH 16/46] more protections for copying --- faiss/gpu/GpuIndexCagra.cu | 13 +++++++++++-- faiss/gpu/test/TestGpuIndexCagra.cu | 2 +- faiss/impl/HNSW.cpp | 2 ++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index f94b5cbcf5..d69958d0be 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -121,6 +121,10 @@ void GpuIndexCagra::searchImpl_( void GpuIndexCagra::copyFrom(const faiss::IndexHNSWCagra* index) { FAISS_ASSERT(index); + DeviceScope scope(config_.device); + + GpuIndex::copyFrom(index); + auto base_index = index->storage; auto l2_index = dynamic_cast(base_index); FAISS_ASSERT(l2_index); @@ -159,12 +163,17 @@ void GpuIndexCagra::copyFrom(const faiss::IndexHNSWCagra* index) { void GpuIndexCagra::copyTo(faiss::IndexHNSWCagra* index) const { FAISS_ASSERT(index_ && this->is_trained && index); - index->reset(); + DeviceScope scope(config_.device); + + // + // Index information + // + GpuIndex::copyTo(index); auto graph_degree = index_->get_knngraph_degree(); 
auto M = graph_degree / 2; + index->reset(); index->hnsw.set_default_probas(M, 1.0 / log(M)); - index->hnsw.offsets.push_back(0); auto n_train = this->ntotal; auto train_dataset = index_->get_training_dataset(); diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index 658ec8858f..aac7dcc3f0 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -203,8 +203,8 @@ void copyToTest() { gpuIndex.train(opt.numTrain, trainVecs.data()); faiss::IndexHNSWCagra copiedCpuIndex(opt.dim, opt.graphDegree / 2); - copiedCpuIndex.hnsw.efConstruction = opt.k * 2; gpuIndex.copyTo(&copiedCpuIndex); + copiedCpuIndex.hnsw.efConstruction = opt.k * 2; // add more vecs to copied cpu index copiedCpuIndex.add(opt.numAdd, addVecs.data()); diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index c886f7d5df..f1b00fd3e0 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -99,6 +99,8 @@ void HNSW::clear_neighbor_tables(int level) { void HNSW::reset() { max_level = -1; entry_point = -1; + assign_probas.clear(); + cum_nneighbor_per_level.clear(); offsets.clear(); offsets.push_back(0); levels.clear(); From ac65c2d14b609305a2e3d548ff5cdc780dbc99c2 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 21 Feb 2024 16:34:36 -0800 Subject: [PATCH 17/46] support default constructed IndexHnswCagra in copyTo --- faiss/gpu/GpuIndexCagra.cu | 11 ++++++++++- faiss/gpu/test/TestGpuIndexCagra.cu | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index d69958d0be..b70ae78f79 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -169,10 +169,19 @@ void GpuIndexCagra::copyTo(faiss::IndexHNSWCagra* index) const { // Index information // GpuIndex::copyTo(index); + // This needs to be zeroed out as this implementation adds vectors to the + // cpuIndex instead of copying fields + index->ntotal = 0; auto graph_degree = index_->get_knngraph_degree(); auto M = graph_degree / 2; - index->reset(); + if (index->storage and index->own_fields) { + delete index->storage; + } + index->storage = new IndexFlatL2(index->d); + index->own_fields = true; + index->keep_max_size_level0 = true; + index->hnsw.reset(); index->hnsw.set_default_probas(M, 1.0 / log(M)); auto n_train = this->ntotal; diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index aac7dcc3f0..54987cd5f5 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -202,7 +202,7 @@ void copyToTest() { &res, opt.dim, faiss::METRIC_L2, config); gpuIndex.train(opt.numTrain, trainVecs.data()); - faiss::IndexHNSWCagra copiedCpuIndex(opt.dim, opt.graphDegree / 2); + faiss::IndexHNSWCagra copiedCpuIndex; gpuIndex.copyTo(&copiedCpuIndex); copiedCpuIndex.hnsw.efConstruction = opt.k * 2; From 619c37662b0ec8148f060f395a45dfa75ac152bf Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 22 Feb 2024 14:58:11 -0800 Subject: [PATCH 18/46] fix failing binary hnsw tests --- faiss/gpu/GpuIndexCagra.cu | 2 ++ faiss/impl/HNSW.cpp | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index b70ae78f79..541a1caa75 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -182,6 +182,8 @@ void GpuIndexCagra::copyTo(faiss::IndexHNSWCagra* index) const { index->own_fields = true; index->keep_max_size_level0 = true; index->hnsw.reset(); + index->hnsw.assign_probas.clear(); + 
index->hnsw.cum_nneighbor_per_level.clear(); index->hnsw.set_default_probas(M, 1.0 / log(M)); auto n_train = this->ntotal; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index f1b00fd3e0..c886f7d5df 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -99,8 +99,6 @@ void HNSW::clear_neighbor_tables(int level) { void HNSW::reset() { max_level = -1; entry_point = -1; - assign_probas.clear(); - cum_nneighbor_per_level.clear(); offsets.clear(); offsets.push_back(0); levels.clear(); From e25f8a4883271c0825f6372a9822455ec7b4b631 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 23 Feb 2024 09:48:04 -0800 Subject: [PATCH 19/46] link faiss_gpu target to OpenMP --- faiss/gpu/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt index d9e4775e02..ec72c48d7c 100644 --- a/faiss/gpu/CMakeLists.txt +++ b/faiss/gpu/CMakeLists.txt @@ -320,5 +320,5 @@ __nv_relfatbin : { *(__nv_relfatbin) } target_link_options(faiss_gpu PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") find_package(CUDAToolkit REQUIRED) -target_link_libraries(faiss_gpu PRIVATE CUDA::cudart CUDA::cublas $<$:raft::raft> $<$:raft::compiled> $<$:nvidia::cutlass::cutlass>) +target_link_libraries(faiss_gpu PRIVATE CUDA::cudart CUDA::cublas $<$:raft::raft> $<$:raft::compiled> $<$:nvidia::cutlass::cutlass> OpenMP::OpenMP_CXX) target_compile_options(faiss_gpu PRIVATE $<$:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr>) From e8351503c9d104ce81a4c48db1426bf35adf2fc0 Mon Sep 17 00:00:00 2001 From: divyegala Date: Fri, 23 Feb 2024 10:14:21 -0800 Subject: [PATCH 20/46] raft still can't find openmp --- faiss/gpu/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt index ec72c48d7c..f48a86d8e3 100644 --- a/faiss/gpu/CMakeLists.txt +++ b/faiss/gpu/CMakeLists.txt @@ -320,5 +320,5 @@ __nv_relfatbin : { *(__nv_relfatbin) } target_link_options(faiss_gpu PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") find_package(CUDAToolkit REQUIRED) -target_link_libraries(faiss_gpu PRIVATE CUDA::cudart CUDA::cublas $<$:raft::raft> $<$:raft::compiled> $<$:nvidia::cutlass::cutlass> OpenMP::OpenMP_CXX) -target_compile_options(faiss_gpu PRIVATE $<$:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr>) +target_link_libraries(faiss_gpu PRIVATE CUDA::cudart CUDA::cublas $<$:raft::raft> $<$:raft::compiled> $<$:nvidia::cutlass::cutlass> $<$:OpenMP::OpenMP_CXX>) +target_compile_options(faiss_gpu PRIVATE $<$:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr> $<$:-fopenmp>) From aeabe122b5035e99b8be5d177a82e710704ec2f2 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 26 Feb 2024 15:10:51 -0800 Subject: [PATCH 21/46] openmp flags and uint32 IndexType --- faiss/gpu/CMakeLists.txt | 2 +- faiss/gpu/impl/RaftCagra.cu | 64 ++++++++++++++++++++++++++---------- faiss/gpu/impl/RaftCagra.cuh | 4 +-- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt index f48a86d8e3..d20f3b7f8e 100644 --- a/faiss/gpu/CMakeLists.txt +++ b/faiss/gpu/CMakeLists.txt @@ -321,4 +321,4 @@ target_link_options(faiss_gpu PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") find_package(CUDAToolkit REQUIRED) target_link_libraries(faiss_gpu PRIVATE CUDA::cudart CUDA::cublas $<$:raft::raft> $<$:raft::compiled> $<$:nvidia::cutlass::cutlass> $<$:OpenMP::OpenMP_CXX>) -target_compile_options(faiss_gpu PRIVATE 
$<$:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr> $<$:-fopenmp>) +target_compile_options(faiss_gpu PRIVATE $<$:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr $<$:-Xcompiler=${OpenMP_CXX_FLAGS}>>) diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu index 0a55901a1b..858bc4fe88 100644 --- a/faiss/gpu/impl/RaftCagra.cu +++ b/faiss/gpu/impl/RaftCagra.cu @@ -22,10 +22,12 @@ #include #include +#include #include #include #include +#include #include namespace faiss { @@ -91,30 +93,47 @@ RaftCagra::RaftCagra( const raft::device_resources& raft_handle = resources_->getRaftHandleCurrentDevice(); + if (distances_on_gpu && knn_graph_on_gpu) { + raft_handle.sync_stream(); + // Copying to host so that raft::neighbors::cagra::index + // creates an owning copy of the knn graph on device + auto knn_graph_copy = + raft::make_host_matrix(n, graph_degree); + thrust::copy( + thrust::device_ptr(knn_graph), + thrust::device_ptr(knn_graph + (n * graph_degree)), + knn_graph_copy.data_handle()); + auto distances_mds = raft::make_device_matrix_view( distances, n, dim); - auto knn_graph_mds = - raft::make_device_matrix_view( - knn_graph, n, graph_degree); - raft_knn_index = raft::neighbors::cagra::index( + raft_knn_index = raft::neighbors::cagra::index( raft_handle, raft::distance::DistanceType::L2Expanded, distances_mds, - knn_graph_mds); - } else { + raft::make_const_mdspan(knn_graph_copy.view())); + } else if (!distances_on_gpu && !knn_graph_on_gpu) { + // copy idx_t (int64_t) host knn_graph to uint32_t host knn_graph + auto knn_graph_copy = + raft::make_host_matrix(n, graph_degree); + std::copy( + knn_graph, + knn_graph + (n * graph_degree), + knn_graph_copy.data_handle()); + auto distances_mds = raft::make_host_matrix_view( distances, n, dim); - auto knn_graph_mds = raft::make_host_matrix_view( - knn_graph, n, graph_degree); - raft_knn_index = raft::neighbors::cagra::index( + raft_knn_index = raft::neighbors::cagra::index( raft_handle, raft::distance::DistanceType::L2Expanded, distances_mds, - knn_graph_mds); + raft::make_const_mdspan(knn_graph_copy.view())); + } else { + FAISS_THROW_MSG( + "distances and knn_graph must both be in device or host memory"); } } @@ -122,12 +141,12 @@ void RaftCagra::train(idx_t n, const float* x) { const raft::device_resources& raft_handle = resources_->getRaftHandleCurrentDevice(); if (getDeviceForAddress(x) >= 0) { - raft_knn_index = raft::neighbors::cagra::build( + raft_knn_index = raft::neighbors::cagra::build( raft_handle, index_pams_, raft::make_device_matrix_view(x, n, dim_)); } else { - raft_knn_index = raft::neighbors::cagra::build( + raft_knn_index = raft::neighbors::cagra::build( raft_handle, index_pams_, raft::make_host_matrix_view(x, n, dim_)); @@ -186,13 +205,21 @@ void RaftCagra::search( search_pams.num_random_samplings = num_random_samplings; search_pams.rand_xor_mask = rand_xor_mask; + auto indices_copy = raft::make_device_matrix( + raft_handle, numQueries, k_); + raft::neighbors::cagra::search( raft_handle, search_pams, raft_knn_index.value(), queries_view, - indices_view, + indices_copy.view(), distances_view); + thrust::copy( + raft::resource::get_thrust_policy(raft_handle), + indices_copy.data_handle(), + indices_copy.data_handle() + indices_copy.size(), + indices_view.data_handle()); } void RaftCagra::reset() { @@ -215,13 +242,14 @@ std::vector RaftCagra::get_knngraph() const { std::vector host_graph( device_graph.extent(0) * device_graph.extent(1)); - raft::update_host( - host_graph.data(), - 
device_graph.data_handle(), - host_graph.size(), - stream); raft_handle.sync_stream(); + thrust::copy( + thrust::device_ptr(device_graph.data_handle()), + thrust::device_ptr( + device_graph.data_handle() + device_graph.size()), + host_graph.data()); + return host_graph; } diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh index 0fddb3b39f..6d0bf69c17 100644 --- a/faiss/gpu/impl/RaftCagra.cuh +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -113,8 +113,8 @@ class RaftCagra { raft::neighbors::cagra::index_params index_pams_; /// Instance of trained RAFT CAGRA index - std::optional> raft_knn_index{ - std::nullopt}; + std::optional> + raft_knn_index{std::nullopt}; }; } // namespace gpu From 4e80586fd17c83f3d86d68bc96b0c5fab0495b01 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 26 Feb 2024 15:34:12 -0800 Subject: [PATCH 22/46] forgot conditional check in index_read --- faiss/impl/index_read.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index 8622b99c06..1b84f4a453 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -948,7 +948,7 @@ Index* read_index(IOReader* f, int io_flags) { idx = idxp; } else if ( h == fourcc("IHNf") || h == fourcc("IHNp") || h == fourcc("IHNs") || - h == fourcc("IHN2")) { + h == fourcc("IHN2") || h == fourcc("IHNc")) { IndexHNSW* idxhnsw = nullptr; if (h == fourcc("IHNf")) idxhnsw = new IndexHNSWFlat(); From c4bcabae3ba57eef938c215077d98f39a0b4926d Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 7 Mar 2024 09:33:24 -0800 Subject: [PATCH 23/46] minor changes --- faiss/IndexHNSW.cpp | 5 ++++- faiss/impl/HNSW.cpp | 8 +++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 1589b10a6c..5fdc774b92 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -7,6 +7,8 @@ // -*- c++ -*- +#include + #include #include @@ -196,6 +198,7 @@ void hnsw_add_vertices( pt_level >= !index_hnsw.init_level0; pt_level--) { int i0 = i1 - hist[pt_level]; + // std::cout << "level: " << pt_level << "points: " << hist[pt_level] << std::endl; if (verbose) { printf("Adding %d elements at level %d\n", i1 - i0, pt_level); @@ -236,7 +239,7 @@ void hnsw_add_vertices( pt_id, locks, vt, - index_hnsw.keep_max_size_level0); + index_hnsw.keep_max_size_level0 && (pt_level == 0)); if (prev_display >= 0 && i - i0 > prev_display + 10000) { prev_display = i - i0; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index c886f7d5df..f449c8446a 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -5,6 +5,8 @@ * LICENSE file in the root directory of this source tree. 
*/ +#include + #include #include @@ -263,7 +265,7 @@ void HNSW::shrink_neighbor_list( } } size_t idx = 0; - while (keep_max_size_level0 && output.size() < max_size) { + while (keep_max_size_level0 && (output.size() < max_size) && (idx < outsiders.size())) { output.push_back(outsiders[idx++]); } } @@ -337,7 +339,7 @@ void add_link( } shrink_neighbor_list( - qdis, resultSet, end - begin, keep_max_size_level0 && level == 0); + qdis, resultSet, end - begin, keep_max_size_level0); // ...and back size_t i = begin; @@ -458,7 +460,7 @@ void HNSW::add_links_starting_from( int M = nb_neighbors(level); ::faiss::shrink_neighbor_list( - ptdis, link_targets, M, keep_max_size_level0 && level == 0); + ptdis, link_targets, M, keep_max_size_level0); std::vector neighbors; neighbors.reserve(link_targets.size()); From 341a3fcef1a07bb284c96cf564674900526f5f62 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 7 Mar 2024 09:35:32 -0800 Subject: [PATCH 24/46] api change --- faiss/gpu/impl/RaftIVFFlat.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/faiss/gpu/impl/RaftIVFFlat.cu b/faiss/gpu/impl/RaftIVFFlat.cu index 1e310723d0..0906a60f46 100644 --- a/faiss/gpu/impl/RaftIVFFlat.cu +++ b/faiss/gpu/impl/RaftIVFFlat.cu @@ -403,7 +403,8 @@ void RaftIVFFlat::copyInvertedListsFrom(const InvertedLists* ivf) { } // Update the pointers and the sizes - raft_knn_index.value().recompute_internal_state(raft_handle); + raft::neighbors::ivf_flat::helpers::recompute_internal_state( + raft_handle, &(raft_knn_index.value())); for (size_t i = 0; i < nlist; ++i) { size_t listSize = ivf->list_size(i); From 172aa6570e0b1a2a5b91cb60e98ce47f39ae6b9e Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 20 Mar 2024 16:06:08 -0700 Subject: [PATCH 25/46] working python --- faiss/python/CMakeLists.txt | 5 +++++ faiss/python/swigfaiss.swig | 13 +++++++++++++ 2 files changed, 18 insertions(+) diff --git a/faiss/python/CMakeLists.txt b/faiss/python/CMakeLists.txt index 8bca710f5f..a2a3fdddf7 100644 --- a/faiss/python/CMakeLists.txt +++ b/faiss/python/CMakeLists.txt @@ -38,6 +38,11 @@ macro(configure_swigfaiss source) set_source_files_properties(${source} PROPERTIES COMPILE_DEFINITIONS GPU_WRAPPER ) + if (FAISS_ENABLE_RAFT) + set_source_files_properties(${source} PROPERTIES + COMPILE_DEFINITIONS FAISS_ENABLE_RAFT + ) + endif() endif() endmacro() diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index fb7f50dd2e..ab069002fa 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -304,6 +304,11 @@ void gpu_sync_all_devices(); #include #include #include + +#ifdef FAISS_ENABLE_RAFT +#include +#endif + #include #include #include @@ -557,6 +562,11 @@ struct faiss::simd16uint16 {}; %include %include %include + +#ifdef FAISS_ENABLE_RAFT +%include +#endif + %include %include %include @@ -673,6 +683,9 @@ struct faiss::simd16uint16 {}; DOWNCAST ( IndexRowwiseMinMax ) DOWNCAST ( IndexRowwiseMinMaxFP16 ) #ifdef GPU_WRAPPER +#ifdef FAISS_ENABLE_RAFT + DOWNCAST_GPU ( GpuIndexCagra ) +#endif DOWNCAST_GPU ( GpuIndexIVFPQ ) DOWNCAST_GPU ( GpuIndexIVFFlat ) DOWNCAST_GPU ( GpuIndexIVFScalarQuantizer ) From 0cd684e19601d67351411c75bbe89c8205f11dce Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 20 Mar 2024 20:07:45 -0700 Subject: [PATCH 26/46] compile option to swig --- CMakeLists.txt | 2 ++ faiss/python/CMakeLists.txt | 2 +- faiss/python/swigfaiss.swig | 6 ------ 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6cdc37c46f..39b5e18325 100644 --- 
a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,6 +46,8 @@ project(faiss LANGUAGES ${FAISS_LANGUAGES}) include(GNUInstallDirs) +set(CMAKE_INSTALL_PREFIX "$ENV{CONDA_PREFIX}") + set(CMAKE_CXX_STANDARD 17) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") diff --git a/faiss/python/CMakeLists.txt b/faiss/python/CMakeLists.txt index a2a3fdddf7..9dc14e9837 100644 --- a/faiss/python/CMakeLists.txt +++ b/faiss/python/CMakeLists.txt @@ -39,7 +39,7 @@ macro(configure_swigfaiss source) COMPILE_DEFINITIONS GPU_WRAPPER ) if (FAISS_ENABLE_RAFT) - set_source_files_properties(${source} PROPERTIES + set_property(SOURCE ${source} APPEND PROPERTY COMPILE_DEFINITIONS FAISS_ENABLE_RAFT ) endif() diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index ab069002fa..9bf5e83aee 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -304,11 +304,7 @@ void gpu_sync_all_devices(); #include #include #include - -#ifdef FAISS_ENABLE_RAFT #include -#endif - #include #include #include @@ -562,11 +558,9 @@ struct faiss::simd16uint16 {}; %include %include %include - #ifdef FAISS_ENABLE_RAFT %include #endif - %include %include %include From 7ff8b3b6c1c93545bc8d4c0f7e6da1aaa87e69dd Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 3 Apr 2024 15:15:34 -0700 Subject: [PATCH 27/46] expose ivf pq params --- faiss/gpu/GpuIndexCagra.cu | 37 +++++++++++- faiss/gpu/GpuIndexCagra.h | 114 +++++++++++++++++++++++++++++++++++ faiss/gpu/impl/RaftCagra.cu | 104 +++++++++++++++++++++++++++----- faiss/gpu/impl/RaftCagra.cuh | 12 +++- 4 files changed, 250 insertions(+), 17 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index 541a1caa75..916f774bc1 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -24,6 +24,7 @@ #include #include #include +#include #include "GpuIndexCagra.h" namespace faiss { @@ -47,6 +48,38 @@ void GpuIndexCagra::train(idx_t n, const float* x) { FAISS_ASSERT(!index_); + std::optional ivf_pq_params = + std::nullopt; + std::optional ivf_pq_search_params = + std::nullopt; + if (cagraConfig_.ivf_pq_params != nullptr) { + ivf_pq_params = + std::make_optional(); + ivf_pq_params->n_lists = cagraConfig_.ivf_pq_params->n_lists; + ivf_pq_params->kmeans_n_iters = + cagraConfig_.ivf_pq_params->kmeans_n_iters; + ivf_pq_params->kmeans_trainset_fraction = + cagraConfig_.ivf_pq_params->kmeans_trainset_fraction; + ivf_pq_params->pq_bits = cagraConfig_.ivf_pq_params->pq_bits; + ivf_pq_params->pq_dim = cagraConfig_.ivf_pq_params->pq_dim; + ivf_pq_params->codebook_kind = + static_cast( + cagraConfig_.ivf_pq_params->codebook_kind); + ivf_pq_params->force_random_rotation = + cagraConfig_.ivf_pq_params->force_random_rotation; + ivf_pq_params->conservative_memory_allocation = + cagraConfig_.ivf_pq_params->conservative_memory_allocation; + } + if (cagraConfig_.ivf_pq_search_params != nullptr) { + ivf_pq_search_params = + std::make_optional(); + ivf_pq_search_params->n_probes = + cagraConfig_.ivf_pq_search_params->n_probes; + ivf_pq_search_params->lut_dtype = + cagraConfig_.ivf_pq_search_params->lut_dtype; + ivf_pq_search_params->preferred_shmem_carveout = + cagraConfig_.ivf_pq_search_params->preferred_shmem_carveout; + } index_ = std::make_shared( this->resources_.get(), this->d, @@ -56,7 +89,9 @@ void GpuIndexCagra::train(idx_t n, const float* x) { cagraConfig_.nn_descent_niter, this->metric_type, this->metric_arg, - INDICES_64_BIT); + INDICES_64_BIT, + ivf_pq_params, + ivf_pq_search_params); index_->train(n, x); diff --git 
a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index 62c0b489fb..5c04259092 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -23,6 +23,9 @@ #pragma once #include +#include +#include +#include "GpuIndexIVFPQ.h" namespace faiss { struct IndexHNSWCagra; @@ -40,6 +43,114 @@ enum class graph_build_algo { NN_DESCENT }; +/** A type for specifying how PQ codebooks are created. */ +enum class codebook_gen { // NOLINT + PER_SUBSPACE = 0, // NOLINT + PER_CLUSTER = 1, // NOLINT +}; + +struct IVFPQBuildCagraConfig { + /** + * The number of inverted lists (clusters) + * + * Hint: the number of vectors per cluster (`n_rows/n_lists`) should be approximately 1,000 to + * 10,000. + */ + uint32_t n_lists = 1024; + /** The number of iterations searching for kmeans centers (index building). */ + uint32_t kmeans_n_iters = 20; + /** The fraction of data to use during iterative kmeans building. */ + double kmeans_trainset_fraction = 0.5; + /** + * The bit length of the vector element after compression by PQ. + * + * Possible values: [4, 5, 6, 7, 8]. + * + * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search + * performance, but the lower the recall. + */ + uint32_t pq_bits = 8; + /** + * The dimensionality of the vector after compression by PQ. When zero, an optimal value is + * selected using a heuristic. + * + * NB: `pq_dim * pq_bits` must be a multiple of 8. + * + * Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but + * lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiple of 8 are + * desirable for good performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. + * For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim' + * should be also a divisor of the dataset dim. + */ + uint32_t pq_dim = 0; + /** How PQ codebooks are created. */ + codebook_gen codebook_kind = codebook_gen::PER_SUBSPACE; + /** + * Apply a random rotation matrix on the input data and queries even if `dim % pq_dim == 0`. + * + * Note: if `dim` is not multiple of `pq_dim`, a random rotation is always applied to the input + * data and queries to transform the working space from `dim` to `rot_dim`, which may be slightly + * larger than the original space and and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`). + * However, this transform is not necessary when `dim` is multiple of `pq_dim` + * (`dim == rot_dim`, hence no need in adding "extra" data columns / features). + * + * By default, if `dim == rot_dim`, the rotation transform is initialized with the identity + * matrix. When `force_random_rotation == true`, a random orthogonal transform matrix is generated + * regardless of the values of `dim` and `pq_dim`. + */ + bool force_random_rotation = false; + /** + * By default, the algorithm allocates more space than necessary for individual clusters + * (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of + * data copies during repeated calls to `extend` (extending the database). + * + * The alternative is the conservative allocation behavior; when enabled, the algorithm always + * allocates the minimum amount of memory required to store the given number of records. Set this + * flag to `true` if you prefer to use as little GPU memory for the database as possible. + */ + bool conservative_memory_allocation = false; +}; + +struct IVFPQSearchCagraConfig { + /** The number of clusters to search. 
*/ + uint32_t n_probes = 20; + /** + * Data type of look up table to be created dynamically at search time. + * + * Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U] + * + * The use of low-precision types reduces the amount of shared memory required at search time, so + * fast shared memory kernels can be used even for datasets with large dimansionality. Note that + * the recall is slightly degraded when low-precision type is selected. + */ + cudaDataType_t lut_dtype = CUDA_R_32F; + /** + * Storage data type for distance/similarity computed at search time. + * + * Possible values: [CUDA_R_16F, CUDA_R_32F] + * + * If the performance limiter at search time is device memory access, selecting FP16 will improve + * performance slightly. + */ + cudaDataType_t internal_distance_dtype = CUDA_R_32F; + /** + * Preferred fraction of SM's unified memory / L1 cache to be used as shared memory. + * + * Possible values: [0.0 - 1.0] as a fraction of the `sharedMemPerMultiprocessor`. + * + * One wants to increase the carveout to make sure a good GPU occupancy for the main search + * kernel, but not to keep it too high to leave some memory to be used as L1 cache. Note, this + * value is interpreted only as a hint. Moreover, a GPU usually allows only a fixed set of cache + * configurations, so the provided value is rounded up to the nearest configuration. Refer to the + * NVIDIA tuning guide for the target GPU architecture. + * + * Note, this is a low-level tuning parameter that can have drastic negative effects on the search + * performance if tweaked incorrectly. + */ + double preferred_shmem_carveout = 1.0; +}; + + struct GpuIndexCagraConfig : public GpuIndexConfig { /** Degree of input graph for pruning. */ size_t intermediate_graph_degree = 128; @@ -49,6 +160,9 @@ struct GpuIndexCagraConfig : public GpuIndexConfig { graph_build_algo build_algo = graph_build_algo::IVF_PQ; /** Number of Iterations to run if building with NN_DESCENT */ size_t nn_descent_niter = 20; + + IVFPQBuildCagraConfig *ivf_pq_params = nullptr; + IVFPQSearchCagraConfig *ivf_pq_search_params = nullptr; }; enum class search_algo { diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu index 858bc4fe88..c2657103f2 100644 --- a/faiss/gpu/impl/RaftCagra.cu +++ b/faiss/gpu/impl/RaftCagra.cu @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace faiss { @@ -42,12 +43,17 @@ RaftCagra::RaftCagra( size_t nn_descent_niter, faiss::MetricType metric, float metricArg, - IndicesOptions indicesOptions) + IndicesOptions indicesOptions, + std::optional ivf_pq_params, + std::optional + ivf_pq_search_params) : resources_(resources), dim_(dim), metric_(metric), metricArg_(metricArg), - index_pams_() { + index_pams_(), + ivf_pq_params_(ivf_pq_params), + ivf_pq_search_params_(ivf_pq_search_params) { FAISS_THROW_IF_NOT_MSG( metric == faiss::METRIC_L2, "CAGRA currently only supports L2 metric."); @@ -62,6 +68,15 @@ RaftCagra::RaftCagra( graph_build_algo); index_pams_.nn_descent_niter = nn_descent_niter; + if (!ivf_pq_params_) { + ivf_pq_params_ = + std::make_optional(); + } + if (!ivf_pq_search_params_) { + ivf_pq_search_params_ = + std::make_optional(); + } + reset(); } @@ -140,16 +155,75 @@ RaftCagra::RaftCagra( void RaftCagra::train(idx_t n, const float* x) { const raft::device_resources& raft_handle = resources_->getRaftHandleCurrentDevice(); - if (getDeviceForAddress(x) >= 0) { - raft_knn_index = raft::neighbors::cagra::build( - raft_handle, - index_pams_, - raft::make_device_matrix_view(x, n, dim_)); + if 
(index_pams_.build_algo == + raft::neighbors::cagra::graph_build_algo::IVF_PQ) { + std::optional> knn_graph( + raft::make_host_matrix( + n, index_pams_.intermediate_graph_degree)); + if (getDeviceForAddress(x) >= 0) { + auto dataset_d = + raft::make_device_matrix_view( + x, n, dim_); + raft::neighbors::cagra::build_knn_graph( + raft_handle, + dataset_d, + knn_graph->view(), + 1.0f, + ivf_pq_params_, + ivf_pq_search_params_); + } else { + auto dataset_h = raft::make_host_matrix_view( + x, n, dim_); + raft::neighbors::cagra::build_knn_graph( + raft_handle, + dataset_h, + knn_graph->view(), + 1.0f, + ivf_pq_params_, + ivf_pq_search_params_); + } + auto cagra_graph = raft::make_host_matrix( + n, index_pams_.graph_degree); + + raft::neighbors::cagra::optimize( + raft_handle, knn_graph->view(), cagra_graph.view()); + + // free intermediate graph before trying to create the index + knn_graph.reset(); + + if (getDeviceForAddress(x) >= 0) { + auto dataset_d = + raft::make_device_matrix_view( + x, n, dim_); + raft_knn_index = raft::neighbors::cagra::index( + raft_handle, + index_pams_.metric, + dataset_d, + raft::make_const_mdspan(cagra_graph.view())); + } else { + auto dataset_h = raft::make_host_matrix_view( + x, n, dim_); + raft_knn_index = raft::neighbors::cagra::index( + raft_handle, + index_pams_.metric, + dataset_h, + raft::make_const_mdspan(cagra_graph.view())); + } + } else { - raft_knn_index = raft::neighbors::cagra::build( - raft_handle, - index_pams_, - raft::make_host_matrix_view(x, n, dim_)); + if (getDeviceForAddress(x) >= 0) { + raft_knn_index = raft::neighbors::cagra::build( + raft_handle, + index_pams_, + raft::make_device_matrix_view( + x, n, dim_)); + } else { + raft_knn_index = raft::neighbors::cagra::build( + raft_handle, + index_pams_, + raft::make_host_matrix_view( + x, n, dim_)); + } } } @@ -181,11 +255,11 @@ void RaftCagra::search( FAISS_ASSERT(numQueries > 0); FAISS_ASSERT(cols == dim_); - auto queries_view = raft::make_device_matrix_view( + auto queries_view = raft::make_device_matrix_view( queries.data(), numQueries, cols); - auto distances_view = raft::make_device_matrix_view( + auto distances_view = raft::make_device_matrix_view( outDistances.data(), numQueries, k_); - auto indices_view = raft::make_device_matrix_view( + auto indices_view = raft::make_device_matrix_view( outIndices.data(), numQueries, k_); raft::neighbors::cagra::search_params search_pams; @@ -205,7 +279,7 @@ void RaftCagra::search( search_pams.num_random_samplings = num_random_samplings; search_pams.rand_xor_mask = rand_xor_mask; - auto indices_copy = raft::make_device_matrix( + auto indices_copy = raft::make_device_matrix( raft_handle, numQueries, k_); raft::neighbors::cagra::search( diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh index 6d0bf69c17..d75ca29fc1 100644 --- a/faiss/gpu/impl/RaftCagra.cuh +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -26,10 +26,12 @@ #include #include #include +#include #include #include +#include namespace faiss { @@ -52,7 +54,11 @@ class RaftCagra { size_t nn_descent_niter, faiss::MetricType metric, float metricArg, - IndicesOptions indicesOptions); + IndicesOptions indicesOptions, + std::optional ivf_pq_params = + std::nullopt, + std::optional + ivf_pq_search_params = std::nullopt); RaftCagra( GpuResources* resources, @@ -112,6 +118,10 @@ class RaftCagra { /// Parameters to build RAFT CAGRA index raft::neighbors::cagra::index_params index_pams_; + /// Parameters to build CAGRA graph using IVF PQ + std::optional ivf_pq_params_; + std::optional 
ivf_pq_search_params_; + /// Instance of trained RAFT CAGRA index std::optional> raft_knn_index{std::nullopt};
From 66d236f52f64707cc29012ecd57ae7a481c34959 Mon Sep 17 00:00:00 2001
From: divyegala
Date: Mon, 22 Apr 2024 13:53:01 -0700
Subject: [PATCH 28/46] update comments style
---
 faiss/gpu/GpuIndexCagra.h | 230 +++++++++++++++++------------------
 faiss/gpu/impl/RaftCagra.cuh | 1 +
 2 files changed, 116 insertions(+), 115 deletions(-)
diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h
index 5c04259092..0bccc27562 100644
--- a/faiss/gpu/GpuIndexCagra.h
+++ b/faiss/gpu/GpuIndexCagra.h
@@ -37,138 +37,138 @@ namespace gpu { class RaftCagra; enum class graph_build_algo {
- /* Use IVF-PQ to build all-neighbors knn graph */
+ /// Use IVF-PQ to build all-neighbors knn graph
 IVF_PQ,
- /* Experimental, use NN-Descent to build all-neighbors knn graph */
+ /// Experimental, use NN-Descent to build all-neighbors knn graph
 NN_DESCENT };
-/** A type for specifying how PQ codebooks are created. */
+/// A type for specifying how PQ codebooks are created.
 enum class codebook_gen { // NOLINT PER_SUBSPACE = 0, // NOLINT PER_CLUSTER = 1, // NOLINT }; struct IVFPQBuildCagraConfig {
- /**
- * The number of inverted lists (clusters)
- *
- * Hint: the number of vectors per cluster (`n_rows/n_lists`) should be approximately 1,000 to
- * 10,000.
- */
+ ///
+ /// The number of inverted lists (clusters)
+ ///
+ /// Hint: the number of vectors per cluster (`n_rows/n_lists`) should be approximately 1,000 to
+ /// 10,000.
+ uint32_t n_lists = 1024;
- /** The number of iterations searching for kmeans centers (index building). */
+ /// The number of iterations searching for kmeans centers (index building).
 uint32_t kmeans_n_iters = 20;
- /** The fraction of data to use during iterative kmeans building. */
+ /// The fraction of data to use during iterative kmeans building.
 double kmeans_trainset_fraction = 0.5;
- /**
- * The bit length of the vector element after compression by PQ.
- *
- * Possible values: [4, 5, 6, 7, 8].
- *
- * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search
- * performance, but the lower the recall.
- */
+ ///
+ /// The bit length of the vector element after compression by PQ.
+ ///
+ /// Possible values: [4, 5, 6, 7, 8].
+ ///
+ /// Hint: the smaller the 'pq_bits', the smaller the index size and the better the search
+ /// performance, but the lower the recall.
+ uint32_t pq_bits = 8;
- /**
- * The dimensionality of the vector after compression by PQ. When zero, an optimal value is
- * selected using a heuristic.
- *
- * NB: `pq_dim * pq_bits` must be a multiple of 8.
- *
- * Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but
- * lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiple of 8 are
- * desirable for good performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8.
- * For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim'
- * should be also a divisor of the dataset dim.
- */
+ ///
+ /// The dimensionality of the vector after compression by PQ. When zero, an optimal value is
+ /// selected using a heuristic.
+ ///
+ /// NB: `pq_dim * pq_bits` must be a multiple of 8.
+ ///
+ /// Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but
+ /// lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiple of 8 are
+ /// desirable for good performance.
If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. + /// For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim' + /// should be also a divisor of the dataset dim. + uint32_t pq_dim = 0; - /** How PQ codebooks are created. */ + /// How PQ codebooks are created. codebook_gen codebook_kind = codebook_gen::PER_SUBSPACE; - /** - * Apply a random rotation matrix on the input data and queries even if `dim % pq_dim == 0`. - * - * Note: if `dim` is not multiple of `pq_dim`, a random rotation is always applied to the input - * data and queries to transform the working space from `dim` to `rot_dim`, which may be slightly - * larger than the original space and and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`). - * However, this transform is not necessary when `dim` is multiple of `pq_dim` - * (`dim == rot_dim`, hence no need in adding "extra" data columns / features). - * - * By default, if `dim == rot_dim`, the rotation transform is initialized with the identity - * matrix. When `force_random_rotation == true`, a random orthogonal transform matrix is generated - * regardless of the values of `dim` and `pq_dim`. - */ + /// + /// Apply a random rotation matrix on the input data and queries even if `dim % pq_dim == 0`. + /// + /// Note: if `dim` is not multiple of `pq_dim`, a random rotation is always applied to the input + /// data and queries to transform the working space from `dim` to `rot_dim`, which may be slightly + /// larger than the original space and and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`). + /// However, this transform is not necessary when `dim` is multiple of `pq_dim` + /// (`dim == rot_dim`, hence no need in adding "extra" data columns / features). + /// + /// By default, if `dim == rot_dim`, the rotation transform is initialized with the identity + /// matrix. When `force_random_rotation == true`, a random orthogonal transform matrix is generated + /// regardless of the values of `dim` and `pq_dim`. + bool force_random_rotation = false; - /** - * By default, the algorithm allocates more space than necessary for individual clusters - * (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of - * data copies during repeated calls to `extend` (extending the database). - * - * The alternative is the conservative allocation behavior; when enabled, the algorithm always - * allocates the minimum amount of memory required to store the given number of records. Set this - * flag to `true` if you prefer to use as little GPU memory for the database as possible. - */ + /// + /// By default, the algorithm allocates more space than necessary for individual clusters + /// (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of + /// data copies during repeated calls to `extend` (extending the database). + /// + /// The alternative is the conservative allocation behavior; when enabled, the algorithm always + /// allocates the minimum amount of memory required to store the given number of records. Set this + /// flag to `true` if you prefer to use as little GPU memory for the database as possible. + bool conservative_memory_allocation = false; }; struct IVFPQSearchCagraConfig { - /** The number of clusters to search. */ + /// The number of clusters to search. uint32_t n_probes = 20; - /** - * Data type of look up table to be created dynamically at search time. 
- * - * Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U] - * - * The use of low-precision types reduces the amount of shared memory required at search time, so - * fast shared memory kernels can be used even for datasets with large dimansionality. Note that - * the recall is slightly degraded when low-precision type is selected. - */ + /// + /// Data type of look up table to be created dynamically at search time. + /// + /// Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U] + /// + /// The use of low-precision types reduces the amount of shared memory required at search time, so + /// fast shared memory kernels can be used even for datasets with large dimansionality. Note that + /// the recall is slightly degraded when low-precision type is selected. + cudaDataType_t lut_dtype = CUDA_R_32F; - /** - * Storage data type for distance/similarity computed at search time. - * - * Possible values: [CUDA_R_16F, CUDA_R_32F] - * - * If the performance limiter at search time is device memory access, selecting FP16 will improve - * performance slightly. - */ + /// + /// Storage data type for distance/similarity computed at search time. + /// + /// Possible values: [CUDA_R_16F, CUDA_R_32F] + /// + /// If the performance limiter at search time is device memory access, selecting FP16 will improve + /// performance slightly. + cudaDataType_t internal_distance_dtype = CUDA_R_32F; - /** - * Preferred fraction of SM's unified memory / L1 cache to be used as shared memory. - * - * Possible values: [0.0 - 1.0] as a fraction of the `sharedMemPerMultiprocessor`. - * - * One wants to increase the carveout to make sure a good GPU occupancy for the main search - * kernel, but not to keep it too high to leave some memory to be used as L1 cache. Note, this - * value is interpreted only as a hint. Moreover, a GPU usually allows only a fixed set of cache - * configurations, so the provided value is rounded up to the nearest configuration. Refer to the - * NVIDIA tuning guide for the target GPU architecture. - * - * Note, this is a low-level tuning parameter that can have drastic negative effects on the search - * performance if tweaked incorrectly. - */ + /// + /// Preferred fraction of SM's unified memory / L1 cache to be used as shared memory. + /// + /// Possible values: [0.0 - 1.0] as a fraction of the `sharedMemPerMultiprocessor`. + /// + /// One wants to increase the carveout to make sure a good GPU occupancy for the main search + /// kernel, but not to keep it too high to leave some memory to be used as L1 cache. Note, this + /// value is interpreted only as a hint. Moreover, a GPU usually allows only a fixed set of cache + /// configurations, so the provided value is rounded up to the nearest configuration. Refer to the + /// NVIDIA tuning guide for the target GPU architecture. + /// + /// Note, this is a low-level tuning parameter that can have drastic negative effects on the search + /// performance if tweaked incorrectly. + double preferred_shmem_carveout = 1.0; }; struct GpuIndexCagraConfig : public GpuIndexConfig { - /** Degree of input graph for pruning. */ + /// Degree of input graph for pruning. size_t intermediate_graph_degree = 128; - /** Degree of output graph. */ + /// Degree of output graph. size_t graph_degree = 64; - /** ANN algorithm to build knn graph. */ + /// ANN algorithm to build knn graph. 
graph_build_algo build_algo = graph_build_algo::IVF_PQ; - /** Number of Iterations to run if building with NN_DESCENT */ + /// Number of Iterations to run if building with NN_DESCENT size_t nn_descent_niter = 20; - IVFPQBuildCagraConfig *ivf_pq_params = nullptr; - IVFPQSearchCagraConfig *ivf_pq_search_params = nullptr; + IVFPQBuildCagraConfig///ivf_pq_params = nullptr; + IVFPQSearchCagraConfig///ivf_pq_search_params = nullptr; }; enum class search_algo { - /** For large batch sizes. */ + /// For large batch sizes. SINGLE_CTA, - /** For small batch sizes. */ + /// For small batch sizes. MULTI_CTA, MULTI_KERNEL, AUTO @@ -177,49 +177,49 @@ enum class search_algo { enum class hash_mode { HASH, SMALL, AUTO }; struct SearchParametersCagra : SearchParameters { - /** Maximum number of queries to search at the same time (batch size). Auto - * select when 0.*/ + /// Maximum number of queries to search at the same time (batch size). Auto + /// select when 0. size_t max_queries = 0; - /** Number of intermediate search results retained during the search. - * - * This is the main knob to adjust trade off between accuracy and search - * speed. Higher values improve the search accuracy. - */ + /// Number of intermediate search results retained during the search. + /// + /// This is the main knob to adjust trade off between accuracy and search + /// speed. Higher values improve the search accuracy. + size_t itopk_size = 64; - /** Upper limit of search iterations. Auto select when 0.*/ + /// Upper limit of search iterations. Auto select when 0. size_t max_iterations = 0; // In the following we list additional search parameters for fine tuning. // Reasonable default values are automatically chosen. - /** Which search implementation to use. */ + /// Which search implementation to use. search_algo algo = search_algo::AUTO; - /** Number of threads used to calculate a single distance. 4, 8, 16, or 32. - */ + /// Number of threads used to calculate a single distance. 4, 8, 16, or 32. + size_t team_size = 0; - /** Number of graph nodes to select as the starting point for the search in - * each iteration. aka search width?*/ + /// Number of graph nodes to select as the starting point for the search in + /// each iteration. aka search width? size_t search_width = 1; - /** Lower limit of search iterations. */ + /// Lower limit of search iterations. size_t min_iterations = 0; - /** Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. */ + /// Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. size_t thread_block_size = 0; - /** Hashmap type. Auto selection when AUTO. */ + /// Hashmap type. Auto selection when AUTO. hash_mode hashmap_mode = hash_mode::AUTO; - /** Lower limit of hashmap bit length. More than 8. */ + /// Lower limit of hashmap bit length. More than 8. size_t hashmap_min_bitlen = 0; - /** Upper limit of hashmap fill rate. More than 0.1, less than 0.9.*/ + /// Upper limit of hashmap fill rate. More than 0.1, less than 0.9. float hashmap_max_fill_rate = 0.5; - /** Number of iterations of initial random seed node selection. 1 or more. - */ + /// Number of iterations of initial random seed node selection. 1 or more. + uint32_t num_random_samplings = 1; - /** Bit mask used for initial random seed node selection. */ + /// Bit mask used for initial random seed node selection. 
uint64_t rand_xor_mask = 0x128394; }; diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh index d75ca29fc1..878198d609 100644 --- a/faiss/gpu/impl/RaftCagra.cuh +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -35,6 +35,7 @@ namespace faiss { +/// Algorithm used to build underlying CAGRA graph enum class cagra_build_algo { IVF_PQ, NN_DESCENT }; enum class cagra_search_algo { SINGLE_CTA, MULTI_CTA }; From 1d6e6b1674a1219483e4195eda6dd4f310bab099 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 22 Apr 2024 14:26:40 -0700 Subject: [PATCH 29/46] use raft::runtime where possible --- faiss/gpu/GpuIndexCagra.h | 4 ++-- faiss/gpu/impl/RaftCagra.cu | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index 0bccc27562..cf34ec1900 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -161,8 +161,8 @@ struct GpuIndexCagraConfig : public GpuIndexConfig { /// Number of Iterations to run if building with NN_DESCENT size_t nn_descent_niter = 20; - IVFPQBuildCagraConfig///ivf_pq_params = nullptr; - IVFPQSearchCagraConfig///ivf_pq_search_params = nullptr; + IVFPQBuildCagraConfig *ivf_pq_params = nullptr; + IVFPQSearchCagraConfig *ivf_pq_search_params = nullptr; }; enum class search_algo { diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu index c2657103f2..2f51b6b35a 100644 --- a/faiss/gpu/impl/RaftCagra.cu +++ b/faiss/gpu/impl/RaftCagra.cu @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -212,13 +213,13 @@ void RaftCagra::train(idx_t n, const float* x) { } else { if (getDeviceForAddress(x) >= 0) { - raft_knn_index = raft::neighbors::cagra::build( + raft_knn_index = raft::runtime::neighbors::cagra::build( raft_handle, index_pams_, raft::make_device_matrix_view( x, n, dim_)); } else { - raft_knn_index = raft::neighbors::cagra::build( + raft_knn_index = raft::runtime::neighbors::cagra::build( raft_handle, index_pams_, raft::make_host_matrix_view( @@ -282,7 +283,7 @@ void RaftCagra::search( auto indices_copy = raft::make_device_matrix( raft_handle, numQueries, k_); - raft::neighbors::cagra::search( + raft::runtime::neighbors::cagra::search( raft_handle, search_pams, raft_knn_index.value(), From 4a01ad4a4cbdf1c3da8d9d2e3693d67b2e15fcd1 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 22 Apr 2024 14:39:10 -0700 Subject: [PATCH 30/46] format --- faiss/IndexHNSW.cpp | 1 - faiss/gpu/GpuIndexCagra.h | 180 ++++++++++++++++++------------------- faiss/impl/HNSW.cpp | 9 +- faiss/impl/index_write.cpp | 2 +- 4 files changed, 95 insertions(+), 97 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 5c542f3e89..f6a2ca9587 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -198,7 +198,6 @@ void hnsw_add_vertices( pt_level >= !index_hnsw.init_level0; pt_level--) { int i0 = i1 - hist[pt_level]; - // std::cout << "level: " << pt_level << "points: " << hist[pt_level] << std::endl; if (verbose) { printf("Adding %d elements at level %d\n", i1 - i0, pt_level); diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index cf34ec1900..563208857c 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -22,9 +22,9 @@ #pragma once +#include #include #include -#include #include "GpuIndexIVFPQ.h" namespace faiss { @@ -37,128 +37,128 @@ namespace gpu { class RaftCagra; enum class graph_build_algo { - /// Use IVF-PQ to build all-neighbors knn graph + /// Use IVF-PQ to build all-neighbors knn graph IVF_PQ, - /// Experimental, 
use NN-Descent to build all-neighbors knn graph + /// Experimental, use NN-Descent to build all-neighbors knn graph NN_DESCENT }; -/// A type for specifying how PQ codebooks are created. +/// A type for specifying how PQ codebooks are created. enum class codebook_gen { // NOLINT PER_SUBSPACE = 0, // NOLINT PER_CLUSTER = 1, // NOLINT }; struct IVFPQBuildCagraConfig { - /// - /// The number of inverted lists (clusters) - /// - /// Hint: the number of vectors per cluster (`n_rows/n_lists`) should be approximately 1,000 to - /// 10,000. + /// + /// The number of inverted lists (clusters) + /// + /// Hint: the number of vectors per cluster (`n_rows/n_lists`) should be approximately 1,000 to + /// 10,000. uint32_t n_lists = 1024; - /// The number of iterations searching for kmeans centers (index building). + /// The number of iterations searching for kmeans centers (index building). uint32_t kmeans_n_iters = 20; - /// The fraction of data to use during iterative kmeans building. + /// The fraction of data to use during iterative kmeans building. double kmeans_trainset_fraction = 0.5; - /// - /// The bit length of the vector element after compression by PQ. - /// - /// Possible values: [4, 5, 6, 7, 8]. - /// - /// Hint: the smaller the 'pq_bits', the smaller the index size and the better the search - /// performance, but the lower the recall. + /// + /// The bit length of the vector element after compression by PQ. + /// + /// Possible values: [4, 5, 6, 7, 8]. + /// + /// Hint: the smaller the 'pq_bits', the smaller the index size and the better the search + /// performance, but the lower the recall. uint32_t pq_bits = 8; - /// - /// The dimensionality of the vector after compression by PQ. When zero, an optimal value is - /// selected using a heuristic. - /// - /// NB: `pq_dim/// pq_bits` must be a multiple of 8. - /// - /// Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but - /// lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiple of 8 are - /// desirable for good performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. - /// For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim' - /// should be also a divisor of the dataset dim. + /// + /// The dimensionality of the vector after compression by PQ. When zero, an optimal value is + /// selected using a heuristic. + /// + /// NB: `pq_dim /// pq_bits` must be a multiple of 8. + /// + /// Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but + /// lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiple of 8 are + /// desirable for good performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. + /// For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim' + /// should be also a divisor of the dataset dim. uint32_t pq_dim = 0; - /// How PQ codebooks are created. + /// How PQ codebooks are created. codebook_gen codebook_kind = codebook_gen::PER_SUBSPACE; - /// - /// Apply a random rotation matrix on the input data and queries even if `dim % pq_dim == 0`. - /// - /// Note: if `dim` is not multiple of `pq_dim`, a random rotation is always applied to the input - /// data and queries to transform the working space from `dim` to `rot_dim`, which may be slightly - /// larger than the original space and and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`). 
- /// However, this transform is not necessary when `dim` is multiple of `pq_dim` - /// (`dim == rot_dim`, hence no need in adding "extra" data columns / features). - /// - /// By default, if `dim == rot_dim`, the rotation transform is initialized with the identity - /// matrix. When `force_random_rotation == true`, a random orthogonal transform matrix is generated - /// regardless of the values of `dim` and `pq_dim`. + /// + /// Apply a random rotation matrix on the input data and queries even if `dim % pq_dim == 0`. + /// + /// Note: if `dim` is not multiple of `pq_dim`, a random rotation is always applied to the input + /// data and queries to transform the working space from `dim` to `rot_dim`, which may be slightly + /// larger than the original space and and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`). + /// However, this transform is not necessary when `dim` is multiple of `pq_dim` + /// (`dim == rot_dim`, hence no need in adding "extra" data columns / features). + /// + /// By default, if `dim == rot_dim`, the rotation transform is initialized with the identity + /// matrix. When `force_random_rotation == true`, a random orthogonal transform matrix is generated + /// regardless of the values of `dim` and `pq_dim`. bool force_random_rotation = false; - /// - /// By default, the algorithm allocates more space than necessary for individual clusters - /// (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of - /// data copies during repeated calls to `extend` (extending the database). - /// - /// The alternative is the conservative allocation behavior; when enabled, the algorithm always - /// allocates the minimum amount of memory required to store the given number of records. Set this - /// flag to `true` if you prefer to use as little GPU memory for the database as possible. + /// + /// By default, the algorithm allocates more space than necessary for individual clusters + /// (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of + /// data copies during repeated calls to `extend` (extending the database). + /// + /// The alternative is the conservative allocation behavior; when enabled, the algorithm always + /// allocates the minimum amount of memory required to store the given number of records. Set this + /// flag to `true` if you prefer to use as little GPU memory for the database as possible. bool conservative_memory_allocation = false; }; struct IVFPQSearchCagraConfig { - /// The number of clusters to search. + /// The number of clusters to search. uint32_t n_probes = 20; - /// - /// Data type of look up table to be created dynamically at search time. - /// - /// Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U] - /// - /// The use of low-precision types reduces the amount of shared memory required at search time, so - /// fast shared memory kernels can be used even for datasets with large dimansionality. Note that - /// the recall is slightly degraded when low-precision type is selected. + /// + /// Data type of look up table to be created dynamically at search time. + /// + /// Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U] + /// + /// The use of low-precision types reduces the amount of shared memory required at search time, so + /// fast shared memory kernels can be used even for datasets with large dimansionality. Note that + /// the recall is slightly degraded when low-precision type is selected. 
cudaDataType_t lut_dtype = CUDA_R_32F; - /// - /// Storage data type for distance/similarity computed at search time. - /// - /// Possible values: [CUDA_R_16F, CUDA_R_32F] - /// - /// If the performance limiter at search time is device memory access, selecting FP16 will improve - /// performance slightly. + /// + /// Storage data type for distance/similarity computed at search time. + /// + /// Possible values: [CUDA_R_16F, CUDA_R_32F] + /// + /// If the performance limiter at search time is device memory access, selecting FP16 will improve + /// performance slightly. cudaDataType_t internal_distance_dtype = CUDA_R_32F; - /// - /// Preferred fraction of SM's unified memory / L1 cache to be used as shared memory. - /// - /// Possible values: [0.0 - 1.0] as a fraction of the `sharedMemPerMultiprocessor`. - /// - /// One wants to increase the carveout to make sure a good GPU occupancy for the main search - /// kernel, but not to keep it too high to leave some memory to be used as L1 cache. Note, this - /// value is interpreted only as a hint. Moreover, a GPU usually allows only a fixed set of cache - /// configurations, so the provided value is rounded up to the nearest configuration. Refer to the - /// NVIDIA tuning guide for the target GPU architecture. - /// - /// Note, this is a low-level tuning parameter that can have drastic negative effects on the search - /// performance if tweaked incorrectly. + /// + /// Preferred fraction of SM's unified memory / L1 cache to be used as shared memory. + /// + /// Possible values: [0.0 - 1.0] as a fraction of the `sharedMemPerMultiprocessor`. + /// + /// One wants to increase the carveout to make sure a good GPU occupancy for the main search + /// kernel, but not to keep it too high to leave some memory to be used as L1 cache. Note, this + /// value is interpreted only as a hint. Moreover, a GPU usually allows only a fixed set of cache + /// configurations, so the provided value is rounded up to the nearest configuration. Refer to the + /// NVIDIA tuning guide for the target GPU architecture. + /// + /// Note, this is a low-level tuning parameter that can have drastic negative effects on the search + /// performance if tweaked incorrectly. double preferred_shmem_carveout = 1.0; }; struct GpuIndexCagraConfig : public GpuIndexConfig { - /// Degree of input graph for pruning. + /// Degree of input graph for pruning. size_t intermediate_graph_degree = 128; - /// Degree of output graph. + /// Degree of output graph. size_t graph_degree = 64; - /// ANN algorithm to build knn graph. + /// ANN algorithm to build knn graph. graph_build_algo build_algo = graph_build_algo::IVF_PQ; - /// Number of Iterations to run if building with NN_DESCENT + /// Number of Iterations to run if building with NN_DESCENT size_t nn_descent_niter = 20; IVFPQBuildCagraConfig *ivf_pq_params = nullptr; @@ -166,9 +166,9 @@ struct GpuIndexCagraConfig : public GpuIndexConfig { }; enum class search_algo { - /// For large batch sizes. + /// For large batch sizes. SINGLE_CTA, - /// For small batch sizes. + /// For small batch sizes. MULTI_CTA, MULTI_KERNEL, AUTO @@ -194,7 +194,7 @@ struct SearchParametersCagra : SearchParameters { // In the following we list additional search parameters for fine tuning. // Reasonable default values are automatically chosen. - /// Which search implementation to use. + /// Which search implementation to use. search_algo algo = search_algo::AUTO; /// Number of threads used to calculate a single distance. 4, 8, 16, or 32. 
@@ -204,14 +204,14 @@ struct SearchParametersCagra : SearchParameters { /// Number of graph nodes to select as the starting point for the search in /// each iteration. aka search width? size_t search_width = 1; - /// Lower limit of search iterations. + /// Lower limit of search iterations. size_t min_iterations = 0; - /// Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. + /// Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. size_t thread_block_size = 0; /// Hashmap type. Auto selection when AUTO. hash_mode hashmap_mode = hash_mode::AUTO; - /// Lower limit of hashmap bit length. More than 8. + /// Lower limit of hashmap bit length. More than 8. size_t hashmap_min_bitlen = 0; /// Upper limit of hashmap fill rate. More than 0.1, less than 0.9. float hashmap_max_fill_rate = 0.5; @@ -219,7 +219,7 @@ struct SearchParametersCagra : SearchParameters { /// Number of iterations of initial random seed node selection. 1 or more. uint32_t num_random_samplings = 1; - /// Bit mask used for initial random seed node selection. + /// Bit mask used for initial random seed node selection. uint64_t rand_xor_mask = 0x128394; }; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index efc837f267..26a2860d2b 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -265,7 +265,8 @@ void HNSW::shrink_neighbor_list( } } size_t idx = 0; - while (keep_max_size_level0 && (output.size() < max_size) && (idx < outsiders.size())) { + while (keep_max_size_level0 && (output.size() < max_size) && + (idx < outsiders.size())) { output.push_back(outsiders[idx++]); } } @@ -338,8 +339,7 @@ void add_link( resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh); } - shrink_neighbor_list( - qdis, resultSet, end - begin, keep_max_size_level0); + shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0); // ...and back size_t i = begin; @@ -459,8 +459,7 @@ void HNSW::add_links_starting_from( // but we can afford only this many neighbors int M = nb_neighbors(level); - ::faiss::shrink_neighbor_list( - ptdis, link_targets, M, keep_max_size_level0); + ::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0); std::vector neighbors; neighbors.reserve(link_targets.size()); diff --git a/faiss/impl/index_write.cpp b/faiss/impl/index_write.cpp index 1f27a68451..efdc488112 100644 --- a/faiss/impl/index_write.cpp +++ b/faiss/impl/index_write.cpp @@ -760,7 +760,7 @@ void write_index(const Index* idx, IOWriter* f) { : dynamic_cast(idx) ? fourcc("IHNp") : dynamic_cast(idx) ? fourcc("IHNs") : dynamic_cast(idx) ? fourcc("IHN2") - : dynamic_cast(idx) ? fourcc("IHNc") + : dynamic_cast(idx) ? fourcc("IHNc") : 0; FAISS_THROW_IF_NOT(h != 0); WRITE1(h); From 949e6349f9dece623d097e502631a102db6195d4 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 22 Apr 2024 14:46:52 -0700 Subject: [PATCH 31/46] format properly --- faiss/gpu/GpuIndexCagra.h | 160 ++++++++++++++++++++------------------ faiss/impl/HNSW.cpp | 2 +- 2 files changed, 87 insertions(+), 75 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index 563208857c..a97543f3d9 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -44,112 +44,124 @@ enum class graph_build_algo { }; /// A type for specifying how PQ codebooks are created. 
-enum class codebook_gen { // NOLINT - PER_SUBSPACE = 0, // NOLINT - PER_CLUSTER = 1, // NOLINT +enum class codebook_gen { // NOLINT + PER_SUBSPACE = 0, // NOLINT + PER_CLUSTER = 1, // NOLINT }; struct IVFPQBuildCagraConfig { /// /// The number of inverted lists (clusters) /// - /// Hint: the number of vectors per cluster (`n_rows/n_lists`) should be approximately 1,000 to - /// 10,000. - - uint32_t n_lists = 1024; - /// The number of iterations searching for kmeans centers (index building). - uint32_t kmeans_n_iters = 20; - /// The fraction of data to use during iterative kmeans building. - double kmeans_trainset_fraction = 0.5; + /// Hint: the number of vectors per cluster (`n_rows/n_lists`) should be + /// approximately 1,000 to 10,000. + + uint32_t n_lists = 1024; + /// The number of iterations searching for kmeans centers (index building). + uint32_t kmeans_n_iters = 20; + /// The fraction of data to use during iterative kmeans building. + double kmeans_trainset_fraction = 0.5; /// /// The bit length of the vector element after compression by PQ. /// /// Possible values: [4, 5, 6, 7, 8]. /// - /// Hint: the smaller the 'pq_bits', the smaller the index size and the better the search - /// performance, but the lower the recall. - - uint32_t pq_bits = 8; + /// Hint: the smaller the 'pq_bits', the smaller the index size and the + /// better the search performance, but the lower the recall. + + uint32_t pq_bits = 8; /// - /// The dimensionality of the vector after compression by PQ. When zero, an optimal value is - /// selected using a heuristic. + /// The dimensionality of the vector after compression by PQ. When zero, an + /// optimal value is selected using a heuristic. /// /// NB: `pq_dim /// pq_bits` must be a multiple of 8. /// - /// Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but - /// lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiple of 8 are - /// desirable for good performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. - /// For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim' - /// should be also a divisor of the dataset dim. - - uint32_t pq_dim = 0; - /// How PQ codebooks are created. - codebook_gen codebook_kind = codebook_gen::PER_SUBSPACE; + /// Hint: a smaller 'pq_dim' results in a smaller index size and better + /// search performance, but lower recall. If 'pq_bits' is 8, 'pq_dim' can be + /// set to any number, but multiple of 8 are desirable for good performance. + /// If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. For good + /// performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, + /// 'pq_dim' should be also a divisor of the dataset dim. + + uint32_t pq_dim = 0; + /// How PQ codebooks are created. + codebook_gen codebook_kind = codebook_gen::PER_SUBSPACE; /// - /// Apply a random rotation matrix on the input data and queries even if `dim % pq_dim == 0`. + /// Apply a random rotation matrix on the input data and queries even if + /// `dim % pq_dim == 0`. /// - /// Note: if `dim` is not multiple of `pq_dim`, a random rotation is always applied to the input - /// data and queries to transform the working space from `dim` to `rot_dim`, which may be slightly - /// larger than the original space and and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`). 
- /// However, this transform is not necessary when `dim` is multiple of `pq_dim` - /// (`dim == rot_dim`, hence no need in adding "extra" data columns / features). + /// Note: if `dim` is not multiple of `pq_dim`, a random rotation is always + /// applied to the input data and queries to transform the working space + /// from `dim` to `rot_dim`, which may be slightly larger than the original + /// space and and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`). + /// However, this transform is not necessary when `dim` is multiple of + /// `pq_dim` + /// (`dim == rot_dim`, hence no need in adding "extra" data columns / + /// features). /// - /// By default, if `dim == rot_dim`, the rotation transform is initialized with the identity - /// matrix. When `force_random_rotation == true`, a random orthogonal transform matrix is generated - /// regardless of the values of `dim` and `pq_dim`. - - bool force_random_rotation = false; + /// By default, if `dim == rot_dim`, the rotation transform is initialized + /// with the identity matrix. When `force_random_rotation == true`, a random + /// orthogonal transform matrix is generated regardless of the values of + /// `dim` and `pq_dim`. + + bool force_random_rotation = false; /// - /// By default, the algorithm allocates more space than necessary for individual clusters - /// (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of - /// data copies during repeated calls to `extend` (extending the database). + /// By default, the algorithm allocates more space than necessary for + /// individual clusters + /// (`list_data`). This allows to amortize the cost of memory allocation and + /// reduce the number of data copies during repeated calls to `extend` + /// (extending the database). /// - /// The alternative is the conservative allocation behavior; when enabled, the algorithm always - /// allocates the minimum amount of memory required to store the given number of records. Set this - /// flag to `true` if you prefer to use as little GPU memory for the database as possible. - - bool conservative_memory_allocation = false; + /// The alternative is the conservative allocation behavior; when enabled, + /// the algorithm always allocates the minimum amount of memory required to + /// store the given number of records. Set this flag to `true` if you prefer + /// to use as little GPU memory for the database as possible. + + bool conservative_memory_allocation = false; }; struct IVFPQSearchCagraConfig { - /// The number of clusters to search. - uint32_t n_probes = 20; + /// The number of clusters to search. + uint32_t n_probes = 20; /// /// Data type of look up table to be created dynamically at search time. /// /// Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U] /// - /// The use of low-precision types reduces the amount of shared memory required at search time, so - /// fast shared memory kernels can be used even for datasets with large dimansionality. Note that - /// the recall is slightly degraded when low-precision type is selected. - - cudaDataType_t lut_dtype = CUDA_R_32F; + /// The use of low-precision types reduces the amount of shared memory + /// required at search time, so fast shared memory kernels can be used even + /// for datasets with large dimansionality. Note that the recall is slightly + /// degraded when low-precision type is selected. + + cudaDataType_t lut_dtype = CUDA_R_32F; /// /// Storage data type for distance/similarity computed at search time. 
/// /// Possible values: [CUDA_R_16F, CUDA_R_32F] /// - /// If the performance limiter at search time is device memory access, selecting FP16 will improve - /// performance slightly. - - cudaDataType_t internal_distance_dtype = CUDA_R_32F; + /// If the performance limiter at search time is device memory access, + /// selecting FP16 will improve performance slightly. + + cudaDataType_t internal_distance_dtype = CUDA_R_32F; /// - /// Preferred fraction of SM's unified memory / L1 cache to be used as shared memory. + /// Preferred fraction of SM's unified memory / L1 cache to be used as + /// shared memory. /// - /// Possible values: [0.0 - 1.0] as a fraction of the `sharedMemPerMultiprocessor`. + /// Possible values: [0.0 - 1.0] as a fraction of the + /// `sharedMemPerMultiprocessor`. /// - /// One wants to increase the carveout to make sure a good GPU occupancy for the main search - /// kernel, but not to keep it too high to leave some memory to be used as L1 cache. Note, this - /// value is interpreted only as a hint. Moreover, a GPU usually allows only a fixed set of cache - /// configurations, so the provided value is rounded up to the nearest configuration. Refer to the - /// NVIDIA tuning guide for the target GPU architecture. + /// One wants to increase the carveout to make sure a good GPU occupancy for + /// the main search kernel, but not to keep it too high to leave some memory + /// to be used as L1 cache. Note, this value is interpreted only as a hint. + /// Moreover, a GPU usually allows only a fixed set of cache configurations, + /// so the provided value is rounded up to the nearest configuration. Refer + /// to the NVIDIA tuning guide for the target GPU architecture. /// - /// Note, this is a low-level tuning parameter that can have drastic negative effects on the search - /// performance if tweaked incorrectly. - - double preferred_shmem_carveout = 1.0; -}; + /// Note, this is a low-level tuning parameter that can have drastic + /// negative effects on the search performance if tweaked incorrectly. + double preferred_shmem_carveout = 1.0; +}; struct GpuIndexCagraConfig : public GpuIndexConfig { /// Degree of input graph for pruning. @@ -161,8 +173,8 @@ struct GpuIndexCagraConfig : public GpuIndexConfig { /// Number of Iterations to run if building with NN_DESCENT size_t nn_descent_niter = 20; - IVFPQBuildCagraConfig *ivf_pq_params = nullptr; - IVFPQSearchCagraConfig *ivf_pq_search_params = nullptr; + IVFPQBuildCagraConfig* ivf_pq_params = nullptr; + IVFPQSearchCagraConfig* ivf_pq_search_params = nullptr; }; enum class search_algo { @@ -185,7 +197,7 @@ struct SearchParametersCagra : SearchParameters { /// /// This is the main knob to adjust trade off between accuracy and search /// speed. Higher values improve the search accuracy. - + size_t itopk_size = 64; /// Upper limit of search iterations. Auto select when 0. @@ -198,7 +210,7 @@ struct SearchParametersCagra : SearchParameters { search_algo algo = search_algo::AUTO; /// Number of threads used to calculate a single distance. 4, 8, 16, or 32. - + size_t team_size = 0; /// Number of graph nodes to select as the starting point for the search in @@ -209,7 +221,7 @@ struct SearchParametersCagra : SearchParameters { /// Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. size_t thread_block_size = 0; - /// Hashmap type. Auto selection when AUTO. + /// Hashmap type. Auto selection when AUTO. hash_mode hashmap_mode = hash_mode::AUTO; /// Lower limit of hashmap bit length. More than 8. 
size_t hashmap_min_bitlen = 0; @@ -217,7 +229,7 @@ struct SearchParametersCagra : SearchParameters { float hashmap_max_fill_rate = 0.5; /// Number of iterations of initial random seed node selection. 1 or more. - + uint32_t num_random_samplings = 1; /// Bit mask used for initial random seed node selection. uint64_t rand_xor_mask = 0x128394; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 26a2860d2b..fedfd801a8 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -265,7 +265,7 @@ void HNSW::shrink_neighbor_list( } } size_t idx = 0; - while (keep_max_size_level0 && (output.size() < max_size) && + while (keep_max_size_level0 && (output.size() < max_size) && (idx < outsiders.size())) { output.push_back(outsiders[idx++]); } From bccd54a3f6a8308ed5e10fab748bf2a6c7c6949e Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 30 Apr 2024 11:30:33 -0700 Subject: [PATCH 32/46] InnerProduct --- faiss/gpu/impl/RaftCagra.cu | 50 ++++++++++++++++++----------- faiss/gpu/impl/RaftCagra.cuh | 2 +- faiss/gpu/test/TestGpuIndexCagra.cu | 36 +++++++++++++-------- 3 files changed, 55 insertions(+), 33 deletions(-) diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu index 2f51b6b35a..292079321d 100644 --- a/faiss/gpu/impl/RaftCagra.cu +++ b/faiss/gpu/impl/RaftCagra.cu @@ -52,22 +52,22 @@ RaftCagra::RaftCagra( dim_(dim), metric_(metric), metricArg_(metricArg), - index_pams_(), + index_params_(), ivf_pq_params_(ivf_pq_params), ivf_pq_search_params_(ivf_pq_search_params) { FAISS_THROW_IF_NOT_MSG( - metric == faiss::METRIC_L2, - "CAGRA currently only supports L2 metric."); + metric == faiss::METRIC_L2 || metric == faiss::METRIC_INNER_PRODUCT, + "CAGRA currently only supports L2 or Inner Product metric."); FAISS_THROW_IF_NOT_MSG( indicesOptions == faiss::gpu::INDICES_64_BIT, "only INDICES_64_BIT is supported for RAFT CAGRA index"); - index_pams_.intermediate_graph_degree = intermediate_graph_degree; - index_pams_.graph_degree = graph_degree; - index_pams_.build_algo = + index_params_.intermediate_graph_degree = intermediate_graph_degree; + index_params_.graph_degree = graph_degree; + index_params_.build_algo = static_cast( graph_build_algo); - index_pams_.nn_descent_niter = nn_descent_niter; + index_params_.nn_descent_niter = nn_descent_niter; if (!ivf_pq_params_) { ivf_pq_params_ = @@ -77,6 +77,12 @@ RaftCagra::RaftCagra( ivf_pq_search_params_ = std::make_optional(); } + index_params_.metric = metric_ == faiss::METRIC_L2 + ? raft::distance::DistanceType::L2Expanded + : raft::distance::DistanceType::InnerProduct; + ivf_pq_params_->metric = metric_ == faiss::METRIC_L2 + ? raft::distance::DistanceType::L2Expanded + : raft::distance::DistanceType::InnerProduct; reset(); } @@ -96,8 +102,8 @@ RaftCagra::RaftCagra( metric_(metric), metricArg_(metricArg) { FAISS_THROW_IF_NOT_MSG( - metric == faiss::METRIC_L2, - "CAGRA currently only supports L2 metric."); + metric == faiss::METRIC_L2 || metric == faiss::METRIC_INNER_PRODUCT, + "CAGRA currently only supports L2 or Inner Product metric."); FAISS_THROW_IF_NOT_MSG( indicesOptions == faiss::gpu::INDICES_64_BIT, "only INDICES_64_BIT is supported for RAFT CAGRA index"); @@ -127,7 +133,9 @@ RaftCagra::RaftCagra( raft_knn_index = raft::neighbors::cagra::index( raft_handle, - raft::distance::DistanceType::L2Expanded, + metric_ == faiss::METRIC_L2 + ? 
raft::distance::DistanceType::L2Expanded + : raft::distance::DistanceType::InnerProduct, distances_mds, raft::make_const_mdspan(knn_graph_copy.view())); } else if (!distances_on_gpu && !knn_graph_on_gpu) { @@ -144,7 +152,9 @@ RaftCagra::RaftCagra( raft_knn_index = raft::neighbors::cagra::index( raft_handle, - raft::distance::DistanceType::L2Expanded, + metric_ == faiss::METRIC_L2 + ? raft::distance::DistanceType::L2Expanded + : raft::distance::DistanceType::InnerProduct, distances_mds, raft::make_const_mdspan(knn_graph_copy.view())); } else { @@ -156,11 +166,11 @@ RaftCagra::RaftCagra( void RaftCagra::train(idx_t n, const float* x) { const raft::device_resources& raft_handle = resources_->getRaftHandleCurrentDevice(); - if (index_pams_.build_algo == + if (index_params_.build_algo == raft::neighbors::cagra::graph_build_algo::IVF_PQ) { std::optional> knn_graph( raft::make_host_matrix( - n, index_pams_.intermediate_graph_degree)); + n, index_params_.intermediate_graph_degree)); if (getDeviceForAddress(x) >= 0) { auto dataset_d = raft::make_device_matrix_view( @@ -184,7 +194,7 @@ void RaftCagra::train(idx_t n, const float* x) { ivf_pq_search_params_); } auto cagra_graph = raft::make_host_matrix( - n, index_pams_.graph_degree); + n, index_params_.graph_degree); raft::neighbors::cagra::optimize( raft_handle, knn_graph->view(), cagra_graph.view()); @@ -198,7 +208,9 @@ void RaftCagra::train(idx_t n, const float* x) { x, n, dim_); raft_knn_index = raft::neighbors::cagra::index( raft_handle, - index_pams_.metric, + metric_ == faiss::METRIC_L2 + ? raft::distance::DistanceType::L2Expanded + : raft::distance::DistanceType::InnerProduct, dataset_d, raft::make_const_mdspan(cagra_graph.view())); } else { @@ -206,7 +218,9 @@ void RaftCagra::train(idx_t n, const float* x) { x, n, dim_); raft_knn_index = raft::neighbors::cagra::index( raft_handle, - index_pams_.metric, + metric_ == faiss::METRIC_L2 + ? 
raft::distance::DistanceType::L2Expanded + : raft::distance::DistanceType::InnerProduct, dataset_h, raft::make_const_mdspan(cagra_graph.view())); } @@ -215,13 +229,13 @@ void RaftCagra::train(idx_t n, const float* x) { if (getDeviceForAddress(x) >= 0) { raft_knn_index = raft::runtime::neighbors::cagra::build( raft_handle, - index_pams_, + index_params_, raft::make_device_matrix_view( x, n, dim_)); } else { raft_knn_index = raft::runtime::neighbors::cagra::build( raft_handle, - index_pams_, + index_params_, raft::make_host_matrix_view( x, n, dim_)); } diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh index 878198d609..95f6c03fca 100644 --- a/faiss/gpu/impl/RaftCagra.cuh +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -117,7 +117,7 @@ class RaftCagra { float metricArg_; /// Parameters to build RAFT CAGRA index - raft::neighbors::cagra::index_params index_pams_; + raft::neighbors::cagra::index_params index_params_; /// Parameters to build CAGRA graph using IVF PQ std::optional ivf_pq_params_; diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index 54987cd5f5..c4b6c8a768 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -75,7 +75,7 @@ struct Options { int device; }; -void queryTest() { +void queryTest(faiss::MetricType metric) { for (int tries = 0; tries < 5; ++tries) { Options opt; @@ -97,8 +97,7 @@ void queryTest() { config.intermediate_graph_degree = opt.intermediateGraphDegree; config.build_algo = opt.buildAlgo; - faiss::gpu::GpuIndexCagra gpuIndex( - &res, cpuIndex.d, faiss::METRIC_L2, config); + faiss::gpu::GpuIndexCagra gpuIndex(&res, cpuIndex.d, metric, config); gpuIndex.train(opt.numTrain, trainVecs.data()); // query @@ -177,10 +176,14 @@ void queryTest() { } TEST(TestGpuIndexCagra, Float32_Query_L2) { - queryTest(); + queryTest(faiss::METRIC_L2); } -void copyToTest() { +TEST(TestGpuIndexCagra, Float32_Query_IP) { + queryTest(faiss::METRIC_INNER_PRODUCT); +} + +void copyToTest(faiss::MetricType metric) { for (int tries = 0; tries < 5; ++tries) { Options opt; @@ -198,8 +201,7 @@ void copyToTest() { config.intermediate_graph_degree = opt.intermediateGraphDegree; config.build_algo = opt.buildAlgo; - faiss::gpu::GpuIndexCagra gpuIndex( - &res, opt.dim, faiss::METRIC_L2, config); + faiss::gpu::GpuIndexCagra gpuIndex(&res, opt.dim, metric, config); gpuIndex.train(opt.numTrain, trainVecs.data()); faiss::IndexHNSWCagra copiedCpuIndex; @@ -300,10 +302,14 @@ void copyToTest() { } TEST(TestGpuIndexCagra, Float32_CopyTo_L2) { - copyToTest(); + copyToTest(faiss::METRIC_L2); +} + +TEST(TestGpuIndexCagra, Float32_CopyTo_IP) { + copyToTest(faiss::METRIC_INNER_PRODUCT); } -void copyFromTest() { +void copyFromTest(faiss::MetricType metric) { for (int tries = 0; tries < 5; ++tries) { Options opt; @@ -319,8 +325,7 @@ void copyFromTest() { res.noTempMemory(); // convert to gpu index - faiss::gpu::GpuIndexCagra copiedGpuIndex( - &res, cpuIndex.d, faiss::METRIC_L2); + faiss::gpu::GpuIndexCagra copiedGpuIndex(&res, cpuIndex.d, metric); copiedGpuIndex.copyFrom(&cpuIndex); // train gpu index @@ -330,8 +335,7 @@ void copyFromTest() { config.intermediate_graph_degree = opt.intermediateGraphDegree; config.build_algo = opt.buildAlgo; - faiss::gpu::GpuIndexCagra gpuIndex( - &res, opt.dim, faiss::METRIC_L2, config); + faiss::gpu::GpuIndexCagra gpuIndex(&res, opt.dim, metric, config); gpuIndex.train(opt.numTrain, trainVecs.data()); // query @@ -402,7 +406,11 @@ void copyFromTest() { } TEST(TestGpuIndexCagra, 
Float32_CopyFrom_L2) { - copyFromTest(); + copyFromTest(faiss::METRIC_L2); +} + +TEST(TestGpuIndexCagra, Float32_CopyFrom_IP) { + copyFromTest(faiss::METRIC_INNER_PRODUCT); } int main(int argc, char** argv) { From 2aaa6e91476ee079ade9cf4a5197bb4783f7fda3 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 7 May 2024 16:01:34 -0700 Subject: [PATCH 33/46] passing ip tests --- faiss/IndexHNSW.cpp | 12 ++++++-- faiss/IndexHNSW.h | 2 +- faiss/gpu/GpuIndexCagra.cu | 14 +++++---- faiss/gpu/test/TestGpuIndexCagra.cu | 45 +++++++++++++++++++---------- 4 files changed, 49 insertions(+), 24 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index f6a2ca9587..7d5d37e838 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -932,8 +933,15 @@ IndexHNSWCagra::IndexHNSWCagra() { is_trained = true; } -IndexHNSWCagra::IndexHNSWCagra(int d, int M) - : IndexHNSW(new IndexFlatL2(d), M) { +IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric) + : IndexHNSW( + (metric == METRIC_L2) + ? static_cast(new IndexFlatL2(d)) + : static_cast(new IndexFlatIP(d)), + M) { + FAISS_THROW_IF_NOT_MSG( + ((metric == METRIC_L2) || (metric == METRIC_INNER_PRODUCT)), + "unsupported metric type for IndexHNSWCagra"); own_fields = true; is_trained = true; init_level0 = true; diff --git a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h index 3d3162e423..12a90cabe4 100644 --- a/faiss/IndexHNSW.h +++ b/faiss/IndexHNSW.h @@ -153,7 +153,7 @@ struct IndexHNSW2Level : IndexHNSW { struct IndexHNSWCagra : IndexHNSW { IndexHNSWCagra(); - IndexHNSWCagra(int d, int M); + IndexHNSWCagra(int d, int M, MetricType metric = METRIC_L2); }; } // namespace faiss diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index 916f774bc1..dcd2e6944b 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -160,10 +160,9 @@ void GpuIndexCagra::copyFrom(const faiss::IndexHNSWCagra* index) { GpuIndex::copyFrom(index); - auto base_index = index->storage; - auto l2_index = dynamic_cast(base_index); - FAISS_ASSERT(l2_index); - auto distances = l2_index->get_xb(); + auto base_index = dynamic_cast(index->storage); + FAISS_ASSERT(base_index); + auto distances = base_index->get_xb(); auto hnsw = index->hnsw; // copy level 0 to a dense knn graph matrix @@ -213,7 +212,12 @@ void GpuIndexCagra::copyTo(faiss::IndexHNSWCagra* index) const { if (index->storage and index->own_fields) { delete index->storage; } - index->storage = new IndexFlatL2(index->d); + + if (this->metric_type == METRIC_L2) { + index->storage = new IndexFlatL2(index->d); + } else if (this->metric_type == METRIC_INNER_PRODUCT) { + index->storage = new IndexFlatIP(index->d); + } index->own_fields = true; index->keep_max_size_level0 = true; index->hnsw.reset(); diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index c4b6c8a768..228c9ac39e 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -75,15 +75,19 @@ struct Options { int device; }; -void queryTest(faiss::MetricType metric) { +void queryTest(faiss::MetricType metric, double expected_recall) { for (int tries = 0; tries < 5; ++tries) { Options opt; + if (opt.buildAlgo == faiss::gpu::graph_build_algo::NN_DESCENT && + metric == faiss::METRIC_INNER_PRODUCT) { + continue; + } std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); // train cpu index - faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2); + faiss::IndexHNSWFlat 
cpuIndex(opt.dim, opt.graphDegree / 2, metric); cpuIndex.hnsw.efConstruction = opt.k * 2; cpuIndex.add(opt.numTrain, trainVecs.data()); @@ -171,21 +175,25 @@ void queryTest(faiss::MetricType metric) { recall_score.view(), test_dis_mds_opt, ref_dis_mds_opt); - ASSERT_TRUE(*recall_score.data_handle() > 0.98); + ASSERT_TRUE(*recall_score.data_handle() > expected_recall); } } TEST(TestGpuIndexCagra, Float32_Query_L2) { - queryTest(faiss::METRIC_L2); + queryTest(faiss::METRIC_L2, 0.98); } TEST(TestGpuIndexCagra, Float32_Query_IP) { - queryTest(faiss::METRIC_INNER_PRODUCT); + queryTest(faiss::METRIC_INNER_PRODUCT, 0.88); } -void copyToTest(faiss::MetricType metric) { +void copyToTest(faiss::MetricType metric, double expected_recall) { for (int tries = 0; tries < 5; ++tries) { Options opt; + if (opt.buildAlgo == faiss::gpu::graph_build_algo::NN_DESCENT && + metric == faiss::METRIC_INNER_PRODUCT) { + continue; + } std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); @@ -204,7 +212,8 @@ void copyToTest(faiss::MetricType metric) { faiss::gpu::GpuIndexCagra gpuIndex(&res, opt.dim, metric, config); gpuIndex.train(opt.numTrain, trainVecs.data()); - faiss::IndexHNSWCagra copiedCpuIndex; + faiss::IndexHNSWCagra copiedCpuIndex( + opt.dim, opt.graphDegree / 2, metric); gpuIndex.copyTo(&copiedCpuIndex); copiedCpuIndex.hnsw.efConstruction = opt.k * 2; @@ -212,7 +221,7 @@ void copyToTest(faiss::MetricType metric) { copiedCpuIndex.add(opt.numAdd, addVecs.data()); // train cpu index - faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2); + faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2, metric); cpuIndex.hnsw.efConstruction = opt.k * 2; cpuIndex.add(opt.numTrain, trainVecs.data()); @@ -297,27 +306,31 @@ void copyToTest(faiss::MetricType metric) { recall_score.view(), copy_ref_dis_mds_opt, ref_dis_mds_opt); - ASSERT_TRUE(*recall_score.data_handle() > 0.99); + ASSERT_TRUE(*recall_score.data_handle() > expected_recall); } } TEST(TestGpuIndexCagra, Float32_CopyTo_L2) { - copyToTest(faiss::METRIC_L2); + copyToTest(faiss::METRIC_L2, 0.98); } TEST(TestGpuIndexCagra, Float32_CopyTo_IP) { - copyToTest(faiss::METRIC_INNER_PRODUCT); + copyToTest(faiss::METRIC_INNER_PRODUCT, 0.88); } -void copyFromTest(faiss::MetricType metric) { +void copyFromTest(faiss::MetricType metric, double expected_recall) { for (int tries = 0; tries < 5; ++tries) { Options opt; + if (opt.buildAlgo == faiss::gpu::graph_build_algo::NN_DESCENT && + metric == faiss::METRIC_INNER_PRODUCT) { + continue; + } std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); // train cpu index - faiss::IndexHNSWCagra cpuIndex(opt.dim, opt.graphDegree / 2); + faiss::IndexHNSWCagra cpuIndex(opt.dim, opt.graphDegree / 2, metric); cpuIndex.hnsw.efConstruction = opt.k * 2; cpuIndex.add(opt.numTrain, trainVecs.data()); @@ -401,16 +414,16 @@ void copyFromTest(faiss::MetricType metric) { recall_score.view(), copy_test_dis_mds_opt, test_dis_mds_opt); - ASSERT_TRUE(*recall_score.data_handle() > 0.99); + ASSERT_TRUE(*recall_score.data_handle() > expected_recall); } } TEST(TestGpuIndexCagra, Float32_CopyFrom_L2) { - copyFromTest(faiss::METRIC_L2); + copyFromTest(faiss::METRIC_L2, 0.98); } TEST(TestGpuIndexCagra, Float32_CopyFrom_IP) { - copyFromTest(faiss::METRIC_INNER_PRODUCT); + copyFromTest(faiss::METRIC_INNER_PRODUCT, 0.88); } int main(int argc, char** argv) { From 70b0ab8fab26f1751625165b1adf3a5c8e968997 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 9 May 2024 13:02:23 -0700 Subject: [PATCH 34/46] address review --- 
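Usage sketch (illustrative, not part of the commit): with `rand_xor_mask` renamed to `seed` in SearchParametersCagra by this patch, a caller sets the field as below. The helper name `search_with_seed` and the index/query buffers it takes are assumptions made for the example only, not APIs added by this series.

    // Illustrative only: assumes an already-trained GpuIndexCagra and a
    // row-major float query buffer of shape nq x d.
    #include <faiss/gpu/GpuIndexCagra.h>
    #include <vector>

    void search_with_seed(
            faiss::gpu::GpuIndexCagra& gpu_index, // assumed trained elsewhere
            const float* queries,                 // nq x d, row-major
            faiss::idx_t nq,
            faiss::idx_t k) {
        faiss::gpu::SearchParametersCagra params;
        params.itopk_size = 64;  // main accuracy/speed trade-off knob
        params.seed = 0x128394;  // renamed from `rand_xor_mask` in this patch
        std::vector<float> distances(nq * k);
        std::vector<faiss::idx_t> labels(nq * k);
        gpu_index.search(
                nq, queries, k, distances.data(), labels.data(), &params);
    }

All other fields keep the defaults declared in GpuIndexCagra.h; `seed` only influences the initial random entry-point selection.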
faiss/IndexHNSW.cpp | 2 -- faiss/IndexHNSW.h | 9 +++++++++ faiss/gpu/GpuIndexCagra.cu | 3 +-- faiss/gpu/GpuIndexCagra.h | 3 +-- faiss/gpu/test/TestGpuIndexCagra.cu | 6 +++--- faiss/impl/HNSW.cpp | 2 -- faiss/impl/index_read.cpp | 3 ++- faiss/impl/index_write.cpp | 3 ++- 8 files changed, 18 insertions(+), 13 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 7d5d37e838..f1e018fd4e 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -7,8 +7,6 @@ // -*- c++ -*- -#include - #include #include diff --git a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h index 12a90cabe4..a4675157de 100644 --- a/faiss/IndexHNSW.h +++ b/faiss/IndexHNSW.h @@ -34,7 +34,16 @@ struct IndexHNSW : Index { bool own_fields = false; Index* storage = nullptr; + // When set to false, level 0 in the knn graph is not initialized. + // This option is used by GpuIndexCagra::copyTo(IndexHNSWCagra*) + // as level 0 knn graph is copied over from the index built by + // GpuIndexCagra. bool init_level0 = true; + + // When set to true, all neighbors in level 0 are filled up + // to the maximum size allowed (2 * M). This option is used by + // IndexHHNSWCagra to create a full base layer graph that is + // used when GpuIndexCagra::copyFrom(IndexHNSWCagra*) is invoked. bool keep_max_size_level0 = false; explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2); diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index dcd2e6944b..634a6c1095 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -25,7 +25,6 @@ #include #include #include -#include "GpuIndexCagra.h" namespace faiss { namespace gpu { @@ -146,7 +145,7 @@ void GpuIndexCagra::searchImpl_( params->hashmap_min_bitlen, params->hashmap_max_fill_rate, params->num_random_samplings, - params->rand_xor_mask); + params->seed); if (not search_params) { delete params; diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index a97543f3d9..6ecee3ae03 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -25,7 +25,6 @@ #include #include #include -#include "GpuIndexIVFPQ.h" namespace faiss { struct IndexHNSWCagra; @@ -232,7 +231,7 @@ struct SearchParametersCagra : SearchParameters { uint32_t num_random_samplings = 1; /// Bit mask used for initial random seed node selection. 
- uint64_t rand_xor_mask = 0x128394; + uint64_t seed = 0x128394; }; struct GpuIndexCagra : public GpuIndex { diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index 228c9ac39e..b763c591ae 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -184,7 +184,7 @@ TEST(TestGpuIndexCagra, Float32_Query_L2) { } TEST(TestGpuIndexCagra, Float32_Query_IP) { - queryTest(faiss::METRIC_INNER_PRODUCT, 0.88); + queryTest(faiss::METRIC_INNER_PRODUCT, 0.85); } void copyToTest(faiss::MetricType metric, double expected_recall) { @@ -315,7 +315,7 @@ TEST(TestGpuIndexCagra, Float32_CopyTo_L2) { } TEST(TestGpuIndexCagra, Float32_CopyTo_IP) { - copyToTest(faiss::METRIC_INNER_PRODUCT, 0.88); + copyToTest(faiss::METRIC_INNER_PRODUCT, 0.85); } void copyFromTest(faiss::MetricType metric, double expected_recall) { @@ -423,7 +423,7 @@ TEST(TestGpuIndexCagra, Float32_CopyFrom_L2) { } TEST(TestGpuIndexCagra, Float32_CopyFrom_IP) { - copyFromTest(faiss::METRIC_INNER_PRODUCT, 0.88); + copyFromTest(faiss::METRIC_INNER_PRODUCT, 0.85); } int main(int argc, char** argv) { diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index fedfd801a8..a07d1556d5 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -#include - #include #include diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index 5f231997ff..5b08640295 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -961,7 +961,8 @@ Index* read_index(IOReader* f, int io_flags) { if (h == fourcc("IHNc")) idxhnsw = new IndexHNSWCagra(); read_index_header(idxhnsw, f); - READ1(idxhnsw->keep_max_size_level0); + if (h == fourcc("IHNc")) + READ1(idxhnsw->keep_max_size_level0); read_HNSW(&idxhnsw->hnsw, f); idxhnsw->storage = read_index(f, io_flags); idxhnsw->own_fields = true; diff --git a/faiss/impl/index_write.cpp b/faiss/impl/index_write.cpp index efdc488112..e9c6a23a64 100644 --- a/faiss/impl/index_write.cpp +++ b/faiss/impl/index_write.cpp @@ -765,7 +765,8 @@ void write_index(const Index* idx, IOWriter* f) { FAISS_THROW_IF_NOT(h != 0); WRITE1(h); write_index_header(idxhnsw, f); - WRITE1(idxhnsw->keep_max_size_level0); + if (h == fourcc("IHNc")) + WRITE1(idxhnsw->keep_max_size_level0); write_HNSW(&idxhnsw->hnsw, f); write_index(idxhnsw->storage, f); } else if (const IndexNSG* idxnsg = dynamic_cast(idx)) { From 4148feaef697b4974321d1225e7724be30049722 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 21 May 2024 13:05:01 -0700 Subject: [PATCH 35/46] base level only search --- faiss/IndexHNSW.cpp | 64 +++++++++++++++++++++++++++++ faiss/IndexHNSW.h | 24 +++++++++++ faiss/gpu/GpuIndexCagra.cu | 8 +++- faiss/gpu/test/TestGpuIndexCagra.cu | 54 ++++++++++++++++++++---- faiss/gpu/test/test_cagra.py | 44 ++++++++++++++++++++ 5 files changed, 186 insertions(+), 8 deletions(-) create mode 100644 faiss/gpu/test/test_cagra.py diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index f1e018fd4e..5d5a455ab6 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -17,8 +17,10 @@ #include #include +#include #include #include +#include #include #include @@ -946,4 +948,66 @@ IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric) keep_max_size_level0 = true; } +void IndexHNSWCagra::add(idx_t n, const float* x) { + FAISS_THROW_IF_NOT_MSG( + !base_level_only, + "Cannot add vectors when base_level_only is set to True"); + + IndexHNSW::add(n, x); +} + +void IndexHNSWCagra::search( + idx_t 
n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params) { + if (!base_level_only) { + IndexHNSW::search(n, x, k, distances, labels, params); + } else { + std::vector nearest(n); + std::vector nearest_d(n); + +#pragma omp for + for (idx_t i = 0; i < n; i++) { + std::unique_ptr dis( + storage_distance_computer(this->storage)); + dis->set_query(x + i * d); + storage_idx_t entrypoint = -1; + float entrypoint_d = std::numeric_limits::max(); + + std::random_device rd; + std::mt19937 gen(i); + std::uniform_int_distribution distrib(0, this->ntotal); + + for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) { + auto idx = distrib(gen); + auto distance = (*dis)(idx); + if (distance < entrypoint_d) { + entrypoint_d = distance; + entrypoint = idx; + } + } + + FAISS_THROW_IF_NOT_MSG( + entrypoint >= 0, "Could not find a valid entrypoint."); + + nearest[i] = entrypoint; + nearest_d[i] = entrypoint_d; + } + + if (params) { + const SearchParametersHNSW* params_hnsw = + dynamic_cast(params); + this->hnsw.efSearch = params_hnsw->efSearch; + this->hnsw.check_relative_distance = + params_hnsw->check_relative_distance; + } + + search_level_0( + n, x, k, nearest.data(), nearest_d.data(), distances, labels); + } +} + } // namespace faiss diff --git a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h index a4675157de..e1f1d68dee 100644 --- a/faiss/IndexHNSW.h +++ b/faiss/IndexHNSW.h @@ -163,6 +163,30 @@ struct IndexHNSW2Level : IndexHNSW { struct IndexHNSWCagra : IndexHNSW { IndexHNSWCagra(); IndexHNSWCagra(int d, int M, MetricType metric = METRIC_L2); + + /// When set to true, the index is immutable. + /// This option is used to copy the knn graph from GpuIndexCagra + /// to the base level of IndexHNSWCagra without adding upper levels. + /// Doing so enables to search the HNSW index, but removes the + /// ability to add vectors. + bool base_level_only = false; + + /// When `base_level_only` is set to `True`, the search function + /// searches only the base level knn graph of the HNSW index. + /// This parameter selects the entry point by randomly selecting + /// some points and using the best one. 
+ int num_base_level_search_entrypoints = 32; + + void add(idx_t n, const float* x); + + /// entry point for search + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params = nullptr); }; } // namespace faiss diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index 634a6c1095..4ae56df10d 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -229,7 +229,13 @@ void GpuIndexCagra::copyTo(faiss::IndexHNSWCagra* index) const { // turn off as level 0 is copied from CAGRA graph index->init_level0 = false; - index->add(n_train, train_dataset.data()); + if (!index->base_level_only) { + index->add(n_train, train_dataset.data()); + } else { + index->hnsw.prepare_level_tab(n_train, false); + index->storage->add(n_train, train_dataset.data()); + index->ntotal = n_train; + } auto graph = get_knngraph(); diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index b763c591ae..c77c7974d0 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -85,6 +86,9 @@ void queryTest(faiss::MetricType metric, double expected_recall) { std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + if (metric == faiss::METRIC_INNER_PRODUCT) { + faiss::fvec_renorm_L2(opt.numTrain, opt.dim, trainVecs.data()); + } // train cpu index faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2, metric); @@ -106,6 +110,9 @@ void queryTest(faiss::MetricType metric, double expected_recall) { // query auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); + if (metric == faiss::METRIC_INNER_PRODUCT) { + faiss::fvec_renorm_L2(opt.numQuery, opt.dim, queryVecs.data()); + } std::vector refDistance(opt.numQuery * opt.k, 0); std::vector refIndices(opt.numQuery * opt.k, -1); @@ -184,10 +191,13 @@ TEST(TestGpuIndexCagra, Float32_Query_L2) { } TEST(TestGpuIndexCagra, Float32_Query_IP) { - queryTest(faiss::METRIC_INNER_PRODUCT, 0.85); + queryTest(faiss::METRIC_INNER_PRODUCT, 0.98); } -void copyToTest(faiss::MetricType metric, double expected_recall) { +void copyToTest( + faiss::MetricType metric, + double expected_recall, + bool base_level_only) { for (int tries = 0; tries < 5; ++tries) { Options opt; if (opt.buildAlgo == faiss::gpu::graph_build_algo::NN_DESCENT && @@ -197,7 +207,13 @@ void copyToTest(faiss::MetricType metric, double expected_recall) { std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + if (metric == faiss::METRIC_INNER_PRODUCT) { + faiss::fvec_renorm_L2(opt.numTrain, opt.dim, trainVecs.data()); + } std::vector addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim); + if (metric == faiss::METRIC_INNER_PRODUCT) { + faiss::fvec_renorm_L2(opt.numAdd, opt.dim, addVecs.data()); + } faiss::gpu::StandardGpuResources res; res.noTempMemory(); @@ -214,11 +230,14 @@ void copyToTest(faiss::MetricType metric, double expected_recall) { faiss::IndexHNSWCagra copiedCpuIndex( opt.dim, opt.graphDegree / 2, metric); + copiedCpuIndex.base_level_only = base_level_only; gpuIndex.copyTo(&copiedCpuIndex); copiedCpuIndex.hnsw.efConstruction = opt.k * 2; // add more vecs to copied cpu index - copiedCpuIndex.add(opt.numAdd, addVecs.data()); + if (!base_level_only) { + copiedCpuIndex.add(opt.numAdd, addVecs.data()); + } // train cpu index faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2, metric); @@ -226,10 +245,15 @@ void copyToTest(faiss::MetricType 
metric, double expected_recall) { cpuIndex.add(opt.numTrain, trainVecs.data()); // add more vecs to cpu index - cpuIndex.add(opt.numAdd, addVecs.data()); + if (!base_level_only) { + cpuIndex.add(opt.numAdd, addVecs.data()); + } // query indexes auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); + if (metric == faiss::METRIC_INNER_PRODUCT) { + faiss::fvec_renorm_L2(opt.numQuery, opt.dim, queryVecs.data()); + } std::vector refDistance(opt.numQuery * opt.k, 0); std::vector refIndices(opt.numQuery * opt.k, -1); @@ -306,16 +330,26 @@ void copyToTest(faiss::MetricType metric, double expected_recall) { recall_score.view(), copy_ref_dis_mds_opt, ref_dis_mds_opt); + std::cout << "recall_score: " << *recall_score.data_handle() + << std::endl; ASSERT_TRUE(*recall_score.data_handle() > expected_recall); } } TEST(TestGpuIndexCagra, Float32_CopyTo_L2) { - copyToTest(faiss::METRIC_L2, 0.98); + copyToTest(faiss::METRIC_L2, 0.98, false); +} + +TEST(TestGpuIndexCagra, Float32_CopyTo_L2_BaseLevelOnly) { + copyToTest(faiss::METRIC_L2, 0.98, true); } TEST(TestGpuIndexCagra, Float32_CopyTo_IP) { - copyToTest(faiss::METRIC_INNER_PRODUCT, 0.85); + copyToTest(faiss::METRIC_INNER_PRODUCT, 0.98, false); +} + +TEST(TestGpuIndexCagra, Float32_CopyTo_IP_BaseLevelOnly) { + copyToTest(faiss::METRIC_INNER_PRODUCT, 0.8, true); } void copyFromTest(faiss::MetricType metric, double expected_recall) { @@ -328,6 +362,9 @@ void copyFromTest(faiss::MetricType metric, double expected_recall) { std::vector trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim); + if (metric == faiss::METRIC_INNER_PRODUCT) { + faiss::fvec_renorm_L2(opt.numTrain, opt.dim, trainVecs.data()); + } // train cpu index faiss::IndexHNSWCagra cpuIndex(opt.dim, opt.graphDegree / 2, metric); @@ -353,6 +390,9 @@ void copyFromTest(faiss::MetricType metric, double expected_recall) { // query auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim); + if (metric == faiss::METRIC_INNER_PRODUCT) { + faiss::fvec_renorm_L2(opt.numQuery, opt.dim, queryVecs.data()); + } auto gpuRes = res.getResources(); auto devAlloc = faiss::gpu::makeDevAlloc( @@ -423,7 +463,7 @@ TEST(TestGpuIndexCagra, Float32_CopyFrom_L2) { } TEST(TestGpuIndexCagra, Float32_CopyFrom_IP) { - copyFromTest(faiss::METRIC_INNER_PRODUCT, 0.85); + copyFromTest(faiss::METRIC_INNER_PRODUCT, 0.98); } int main(int argc, char** argv) { diff --git a/faiss/gpu/test/test_cagra.py b/faiss/gpu/test/test_cagra.py new file mode 100644 index 0000000000..d670a75c08 --- /dev/null +++ b/faiss/gpu/test/test_cagra.py @@ -0,0 +1,44 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import faiss +import numpy as np + +from common_faiss_tests import get_dataset_2 + +from faiss.contrib import datasets, evaluation, big_batch_search +from faiss.contrib.exhaustive_search import knn_ground_truth, \ + range_ground_truth, range_search_gpu, \ + range_search_max_results, exponential_query_iterator + + +class TestComputeGT(unittest.TestCase): + + def do_compute_GT(self, metric): + d = 64 + xt, xb, xq = get_dataset_2(d, 0, 10000, 100) + + index = faiss.GpuIndexCagra(d) + index.train(xb) + Dref, Iref = index.search(xq, 10) + + # iterator function on the matrix + + def matrix_iterator(xb, bs): + for i0 in range(0, xb.shape[0], bs): + yield xb[i0:i0 + bs] + + Dnew, Inew = knn_ground_truth(xq, matrix_iterator(xb, 1000), 10) + + np.testing.assert_array_equal(Iref, Inew) + np.testing.assert_almost_equal(Dref, Dnew, decimal=4) + + def test_compute_GT_L2(self): + self.do_compute_GT(faiss.METRIC_L2) + + def test_range_IP(self): + self.do_compute_GT(faiss.METRIC_INNER_PRODUCT) From 24a555d78a73e54bc0f2c90fe2153a8c08386ff1 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 22 May 2024 18:32:32 -0700 Subject: [PATCH 36/46] fix virtual functions and serialization --- faiss/IndexHNSW.cpp | 49 ++++++++++++++++------------- faiss/IndexHNSW.h | 7 +++-- faiss/gpu/test/TestGpuIndexCagra.cu | 1 + faiss/impl/HNSW.cpp | 20 +++++++++--- faiss/impl/HNSW.h | 3 +- faiss/impl/index_read.cpp | 6 +++- faiss/impl/index_write.cpp | 6 +++- 7 files changed, 59 insertions(+), 33 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 5d5a455ab6..35f18014ed 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -26,6 +26,7 @@ #include #include #include +#include "impl/HNSW.h" #include #include @@ -466,7 +467,8 @@ void IndexHNSW::search_level_0( float* distances, idx_t* labels, int nprobe, - int search_type) const { + int search_type, + const SearchParameters* params) const { FAISS_THROW_IF_NOT(k > 0); FAISS_THROW_IF_NOT(nprobe > 0); @@ -501,7 +503,9 @@ void IndexHNSW::search_level_0( vt.advance(); } #pragma omp critical - { hnsw_stats.combine(search_stats); } + { + hnsw_stats.combine(search_stats); + } } } @@ -962,7 +966,7 @@ void IndexHNSWCagra::search( idx_t k, float* distances, idx_t* labels, - const SearchParameters* params) { + const SearchParameters* params) const { if (!base_level_only) { IndexHNSW::search(n, x, k, distances, labels, params); } else { @@ -971,42 +975,43 @@ void IndexHNSWCagra::search( #pragma omp for for (idx_t i = 0; i < n; i++) { + // std::unique_ptr dis( + // this->storage->get_distance_computer()); std::unique_ptr dis( storage_distance_computer(this->storage)); dis->set_query(x + i * d); - storage_idx_t entrypoint = -1; - float entrypoint_d = std::numeric_limits::max(); + nearest[i] = -1; + nearest_d[i] = std::numeric_limits::max(); std::random_device rd; - std::mt19937 gen(i); + std::mt19937 gen(rd()); std::uniform_int_distribution distrib(0, this->ntotal); for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) { auto idx = distrib(gen); auto distance = (*dis)(idx); - if (distance < entrypoint_d) { - entrypoint_d = distance; - entrypoint = idx; + // std::cout << "distance: " << distance << std::endl; + if (distance > nearest_d[i]) { + nearest[i] = idx; + nearest_d[i] = distance; } } FAISS_THROW_IF_NOT_MSG( - entrypoint >= 0, "Could not find a valid entrypoint."); - - nearest[i] = entrypoint; - nearest_d[i] = entrypoint_d; - } - - if (params) { - const SearchParametersHNSW* params_hnsw = - dynamic_cast(params); - this->hnsw.efSearch 
= params_hnsw->efSearch; - this->hnsw.check_relative_distance = - params_hnsw->check_relative_distance; + nearest[i] >= 0, "Could not find a valid entrypoint."); } search_level_0( - n, x, k, nearest.data(), nearest_d.data(), distances, labels); + n, + x, + k, + nearest.data(), + nearest_d.data(), + distances, + labels, + 1, // n_probes + 1, // search_type + dynamic_cast(params)); } } diff --git a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h index e1f1d68dee..71807c6537 100644 --- a/faiss/IndexHNSW.h +++ b/faiss/IndexHNSW.h @@ -93,7 +93,8 @@ struct IndexHNSW : Index { float* distances, idx_t* labels, int nprobe = 1, - int search_type = 1) const; + int search_type = 1, + const SearchParameters* params = nullptr) const; /// alternative graph building void init_level_0_from_knngraph(int k, const float* D, const idx_t* I); @@ -177,7 +178,7 @@ struct IndexHNSWCagra : IndexHNSW { /// some points and using the best one. int num_base_level_search_entrypoints = 32; - void add(idx_t n, const float* x); + void add(idx_t n, const float* x) override; /// entry point for search void search( @@ -186,7 +187,7 @@ struct IndexHNSWCagra : IndexHNSW { idx_t k, float* distances, idx_t* labels, - const SearchParameters* params = nullptr); + const SearchParameters* params = nullptr) const override; }; } // namespace faiss diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index c77c7974d0..7ed08aaa9d 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -278,6 +278,7 @@ void copyToTest( copyRefDistance.data(), copyRefIndices.data(), &cpuSearchParamstwo); + std::cout << "copyRefIndices[0]: " << copyRefIndices[0] << std::endl; // test quality of search auto gpuRes = res.getResources(); diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index a07d1556d5..dc61e12152 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -937,8 +937,10 @@ void HNSW::search_level_0( const float* nearest_d, int search_type, HNSWStats& search_stats, - VisitedTable& vt) const { + VisitedTable& vt, + const SearchParametersHNSW* params) const { const HNSW& hnsw = *this; + auto efSearch = params ? 
params->efSearch : hnsw.efSearch; int k = extract_k_from_ResultHandler(res); if (search_type == 1) { int nres = 0; @@ -952,16 +954,24 @@ void HNSW::search_level_0( if (vt.get(cj)) continue; - int candidates_size = std::max(hnsw.efSearch, k); + int candidates_size = std::max(efSearch, k); MinimaxHeap candidates(candidates_size); candidates.push(cj, nearest_d[j]); nres = search_from_candidates( - hnsw, qdis, res, candidates, vt, search_stats, 0, nres); + hnsw, + qdis, + res, + candidates, + vt, + search_stats, + 0, + nres, + params); } } else if (search_type == 2) { - int candidates_size = std::max(hnsw.efSearch, int(k)); + int candidates_size = std::max(efSearch, int(k)); candidates_size = std::max(candidates_size, int(nprobe)); MinimaxHeap candidates(candidates_size); @@ -974,7 +984,7 @@ void HNSW::search_level_0( } search_from_candidates( - hnsw, qdis, res, candidates, vt, search_stats, 0); + hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params); } } diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index 7b08096a86..f3aacf8a5b 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -213,7 +213,8 @@ struct HNSW { const float* nearest_d, int search_type, HNSWStats& search_stats, - VisitedTable& vt) const; + VisitedTable& vt, + const SearchParametersHNSW* params = nullptr) const; void reset(); diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index 5b08640295..1085d3a0d1 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -961,8 +961,12 @@ Index* read_index(IOReader* f, int io_flags) { if (h == fourcc("IHNc")) idxhnsw = new IndexHNSWCagra(); read_index_header(idxhnsw, f); - if (h == fourcc("IHNc")) + if (h == fourcc("IHNc")) { READ1(idxhnsw->keep_max_size_level0); + auto idx_hnsw_cagra = dynamic_cast(idxhnsw); + READ1(idx_hnsw_cagra->base_level_only); + READ1(idx_hnsw_cagra->num_base_level_search_entrypoints); + } read_HNSW(&idxhnsw->hnsw, f); idxhnsw->storage = read_index(f, io_flags); idxhnsw->own_fields = true; diff --git a/faiss/impl/index_write.cpp b/faiss/impl/index_write.cpp index e9c6a23a64..24303ac376 100644 --- a/faiss/impl/index_write.cpp +++ b/faiss/impl/index_write.cpp @@ -765,8 +765,12 @@ void write_index(const Index* idx, IOWriter* f) { FAISS_THROW_IF_NOT(h != 0); WRITE1(h); write_index_header(idxhnsw, f); - if (h == fourcc("IHNc")) + if (h == fourcc("IHNc")) { WRITE1(idxhnsw->keep_max_size_level0); + auto idx_hnsw_cagra = dynamic_cast(idxhnsw); + WRITE1(idx_hnsw_cagra->base_level_only); + WRITE1(idx_hnsw_cagra->num_base_level_search_entrypoints); + } write_HNSW(&idxhnsw->hnsw, f); write_index(idxhnsw->storage, f); } else if (const IndexNSG* idxnsg = dynamic_cast(idx)) { From 51227b1df6cf97d8b71e73c38e89c421391e9e6a Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 22 May 2024 18:56:38 -0700 Subject: [PATCH 37/46] invert conditional --- faiss/IndexHNSW.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 35f18014ed..c20ea99b6e 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -975,8 +975,6 @@ void IndexHNSWCagra::search( #pragma omp for for (idx_t i = 0; i < n; i++) { - // std::unique_ptr dis( - // this->storage->get_distance_computer()); std::unique_ptr dis( storage_distance_computer(this->storage)); dis->set_query(x + i * d); @@ -990,13 +988,11 @@ void IndexHNSWCagra::search( for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) { auto idx = distrib(gen); auto distance = (*dis)(idx); - // std::cout << "distance: " << distance << 
std::endl; - if (distance > nearest_d[i]) { + if (distance < nearest_d[i]) { nearest[i] = idx; nearest_d[i] = distance; } } - FAISS_THROW_IF_NOT_MSG( nearest[i] >= 0, "Could not find a valid entrypoint."); } From 579a301be18282dfcbba1376e395e0b964c5e5dd Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 23 May 2024 09:59:36 -0700 Subject: [PATCH 38/46] debug msg --- faiss/IndexHNSW.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index c20ea99b6e..023e39a526 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -970,6 +970,7 @@ void IndexHNSWCagra::search( if (!base_level_only) { IndexHNSW::search(n, x, k, distances, labels, params); } else { + std::cout << "LEVEL 0 SEARCH" << std::endl; std::vector nearest(n); std::vector nearest_d(n); From ae0b8ba7a708b13c55ca0d887cedeeba9c8040ae Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 23 May 2024 10:46:36 -0700 Subject: [PATCH 39/46] more debug prints --- faiss/impl/HNSW.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index dc61e12152..65fc93cfac 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -942,6 +942,8 @@ void HNSW::search_level_0( const HNSW& hnsw = *this; auto efSearch = params ? params->efSearch : hnsw.efSearch; int k = extract_k_from_ResultHandler(res); + std::cout << "efSearch: " << efSearch << ", k: " << k << std::endl; + if (search_type == 1) { int nres = 0; From 4170a3e4a3d0a8284519c5bf1d529159e47e7e5e Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 23 May 2024 11:50:24 -0700 Subject: [PATCH 40/46] fix efSearch setting in base search --- faiss/IndexHNSW.cpp | 15 +++++++++++---- faiss/impl/HNSW.cpp | 1 - 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 023e39a526..02445989a4 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -468,10 +468,17 @@ void IndexHNSW::search_level_0( idx_t* labels, int nprobe, int search_type, - const SearchParameters* params) const { + const SearchParameters* params_in) const { FAISS_THROW_IF_NOT(k > 0); FAISS_THROW_IF_NOT(nprobe > 0); + const SearchParametersHNSW* params = nullptr; + + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); + } + storage_idx_t ntotal = hnsw.levels.size(); using RH = HeapBlockResultHandler; @@ -498,7 +505,8 @@ void IndexHNSW::search_level_0( nearest_d + i * nprobe, search_type, search_stats, - vt); + vt, + params); res.end(); vt.advance(); } @@ -970,7 +978,6 @@ void IndexHNSWCagra::search( if (!base_level_only) { IndexHNSW::search(n, x, k, distances, labels, params); } else { - std::cout << "LEVEL 0 SEARCH" << std::endl; std::vector nearest(n); std::vector nearest_d(n); @@ -1008,7 +1015,7 @@ void IndexHNSWCagra::search( labels, 1, // n_probes 1, // search_type - dynamic_cast(params)); + params); } } diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 65fc93cfac..3ba5f72f68 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -942,7 +942,6 @@ void HNSW::search_level_0( const HNSW& hnsw = *this; auto efSearch = params ? 
params->efSearch : hnsw.efSearch; int k = extract_k_from_ResultHandler(res); - std::cout << "efSearch: " << efSearch << ", k: " << k << std::endl; if (search_type == 1) { int nres = 0; From 75808b1ab5d427948797a26ad9739bb972db3787 Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 23 May 2024 15:42:57 -0700 Subject: [PATCH 41/46] re-negate ip distances in search_level --- faiss/IndexHNSW.cpp | 7 +++++++ faiss/gpu/test/TestGpuIndexCagra.cu | 3 --- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 02445989a4..19ebfcfe01 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -515,6 +515,13 @@ void IndexHNSW::search_level_0( hnsw_stats.combine(search_stats); } } + if (is_similarity_metric(this->metric_type)) { +// we need to revert the negated distances +#pragma omp parallel for + for (size_t i = 0; i < k * n; i++) { + distances[i] = -distances[i]; + } + } } void IndexHNSW::init_level_0_from_knngraph( diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index 7ed08aaa9d..62f4c7627a 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -278,7 +278,6 @@ void copyToTest( copyRefDistance.data(), copyRefIndices.data(), &cpuSearchParamstwo); - std::cout << "copyRefIndices[0]: " << copyRefIndices[0] << std::endl; // test quality of search auto gpuRes = res.getResources(); @@ -331,8 +330,6 @@ void copyToTest( recall_score.view(), copy_ref_dis_mds_opt, ref_dis_mds_opt); - std::cout << "recall_score: " << *recall_score.data_handle() - << std::endl; ASSERT_TRUE(*recall_score.data_handle() > expected_recall); } } From 9bd10399778e24ce570d1105fe82e41669af5c7d Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 23 May 2024 15:46:34 -0700 Subject: [PATCH 42/46] fix format --- faiss/IndexHNSW.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 35e2a9f286..efd8972477 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -465,9 +465,7 @@ void IndexHNSW::search_level_0( vt.advance(); } #pragma omp critical - { - hnsw_stats.combine(search_stats); - } + { hnsw_stats.combine(search_stats); } } if (is_similarity_metric(this->metric_type)) { // we need to revert the negated distances From ea8028dc559add56f60c9531464277ad0127a77d Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 23 May 2024 15:53:36 -0700 Subject: [PATCH 43/46] re-up minimum recall for base only IP distance --- faiss/gpu/test/TestGpuIndexCagra.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index 62f4c7627a..8d330a81cb 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -347,7 +347,7 @@ TEST(TestGpuIndexCagra, Float32_CopyTo_IP) { } TEST(TestGpuIndexCagra, Float32_CopyTo_IP_BaseLevelOnly) { - copyToTest(faiss::METRIC_INNER_PRODUCT, 0.8, true); + copyToTest(faiss::METRIC_INNER_PRODUCT, 0.98, true); } void copyFromTest(faiss::MetricType metric, double expected_recall) { From fc313514e21954353fd53995d8f790e34b74bf99 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 29 May 2024 17:17:04 -0700 Subject: [PATCH 44/46] add python tests --- faiss/gpu/GpuCloner.cpp | 13 +++++++ faiss/gpu/test/test_cagra.py | 69 +++++++++++++++++++++++++----------- 2 files changed, 61 insertions(+), 21 deletions(-) diff --git a/faiss/gpu/GpuCloner.cpp b/faiss/gpu/GpuCloner.cpp index 8f895ac9c7..de5041b8f6 100644 --- 
a/faiss/gpu/GpuCloner.cpp +++ b/faiss/gpu/GpuCloner.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -24,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -85,6 +87,10 @@ Index* ToCPUCloner::clone_Index(const Index* index) { // objective is to make a single component out of them // (inverse op of ToGpuClonerMultiple) + } else if (auto icg = dynamic_cast(index)) { + IndexHNSWCagra* res = new IndexHNSWCagra(); + icg->copyTo(res); + return res; } else if (auto ish = dynamic_cast(index)) { int nshard = ish->count(); FAISS_ASSERT(nshard > 0); @@ -214,6 +220,13 @@ Index* ToGpuCloner::clone_Index(const Index* index) { res->reserveMemory(reserveVecs); } + return res; + } else if (auto icg = dynamic_cast(index)) { + GpuIndexCagraConfig config; + config.device = device; + GpuIndexCagra* res = + new GpuIndexCagra(provider, icg->d, icg->metric_type, config); + res->copyFrom(icg); return res; } else { // use CPU cloner for IDMap and PreTransform diff --git a/faiss/gpu/test/test_cagra.py b/faiss/gpu/test/test_cagra.py index d670a75c08..dd7d09f2de 100644 --- a/faiss/gpu/test/test_cagra.py +++ b/faiss/gpu/test/test_cagra.py @@ -8,37 +8,64 @@ import faiss import numpy as np -from common_faiss_tests import get_dataset_2 - -from faiss.contrib import datasets, evaluation, big_batch_search -from faiss.contrib.exhaustive_search import knn_ground_truth, \ - range_ground_truth, range_search_gpu, \ - range_search_max_results, exponential_query_iterator +from faiss.contrib import datasets, evaluation +@unittest.skipIf( + "RAFT" not in faiss.get_compile_options(), + "only if RAFT is compiled in") class TestComputeGT(unittest.TestCase): def do_compute_GT(self, metric): d = 64 - xt, xb, xq = get_dataset_2(d, 0, 10000, 100) - - index = faiss.GpuIndexCagra(d) - index.train(xb) - Dref, Iref = index.search(xq, 10) - - # iterator function on the matrix - - def matrix_iterator(xb, bs): - for i0 in range(0, xb.shape[0], bs): - yield xb[i0:i0 + bs] + k = 12 + ds = datasets.SyntheticDataset(d, 0, 10000, 100) + Dref, Iref = faiss.knn(ds.get_queries(), ds.get_database(), k, metric) - Dnew, Inew = knn_ground_truth(xq, matrix_iterator(xb, 1000), 10) + res = faiss.StandardGpuResources() - np.testing.assert_array_equal(Iref, Inew) - np.testing.assert_almost_equal(Dref, Dnew, decimal=4) + index = faiss.GpuIndexCagra(res, d, metric) + index.train(ds.get_database()) + Dnew, Inew = index.search(ds.get_queries(), k) + + evaluation.check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, k) def test_compute_GT_L2(self): self.do_compute_GT(faiss.METRIC_L2) - def test_range_IP(self): + def test_compute_GT_IP(self): self.do_compute_GT(faiss.METRIC_INNER_PRODUCT) + +@unittest.skipIf( + "RAFT" not in faiss.get_compile_options(), + "only if RAFT is compiled in") +class TestInterop(unittest.TestCase): + + def do_interop(self, metric): + d = 64 + k = 12 + ds = datasets.SyntheticDataset(d, 0, 10000, 100) + + res = faiss.StandardGpuResources() + + index = faiss.GpuIndexCagra(res, d, metric) + index.train(ds.get_database()) + Dnew, Inew = index.search(ds.get_queries(), k) + + cpu_index = faiss.index_gpu_to_cpu(index) + Dref, Iref = cpu_index.search(ds.get_queries(), k) + + evaluation.check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, k) + + faiss.write_index(cpu_index, "index_hnsw_cagra.index") + deserialized_index = faiss.read_index("index_hnsw_cagra.index") + gpu_index = faiss.index_cpu_to_gpu(res, 0, deserialized_index) + Dnew2, Inew2 = gpu_index.search(ds.get_queries(), k) + + 
evaluation.check_ref_knn_with_draws(Dnew2, Inew2, Dnew, Inew, k) + + def test_interop_L2(self): + self.do_interop(faiss.METRIC_L2) + + def test_interop_IP(self): + self.do_interop(faiss.METRIC_INNER_PRODUCT) From 03ee1fb65bdc423859855af8d4aa042e94595b44 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 29 May 2024 19:17:07 -0700 Subject: [PATCH 45/46] ifdef guards in gpu cloner --- faiss/gpu/GpuCloner.cpp | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/faiss/gpu/GpuCloner.cpp b/faiss/gpu/GpuCloner.cpp index de5041b8f6..b6d55a47aa 100644 --- a/faiss/gpu/GpuCloner.cpp +++ b/faiss/gpu/GpuCloner.cpp @@ -14,7 +14,9 @@ #include #include +#if defined USE_NVIDIA_RAFT #include +#endif #include #include #include @@ -25,7 +27,9 @@ #include #include #include +#if defined USE_NVIDIA_RAFT #include +#endif #include #include #include @@ -87,11 +91,15 @@ Index* ToCPUCloner::clone_Index(const Index* index) { // objective is to make a single component out of them // (inverse op of ToGpuClonerMultiple) - } else if (auto icg = dynamic_cast(index)) { + } +#if defined USE_NVIDIA_RAFT + else if (auto icg = dynamic_cast(index)) { IndexHNSWCagra* res = new IndexHNSWCagra(); icg->copyTo(res); return res; - } else if (auto ish = dynamic_cast(index)) { + } +#endif + else if (auto ish = dynamic_cast(index)) { int nshard = ish->count(); FAISS_ASSERT(nshard > 0); Index* res = clone_Index(ish->at(0)); @@ -221,14 +229,18 @@ Index* ToGpuCloner::clone_Index(const Index* index) { } return res; - } else if (auto icg = dynamic_cast(index)) { + } +#if defined USE_NVIDIA_RAFT + else if (auto icg = dynamic_cast(index)) { GpuIndexCagraConfig config; config.device = device; GpuIndexCagra* res = new GpuIndexCagra(provider, icg->d, icg->metric_type, config); res->copyFrom(icg); return res; - } else { + } +#endif + else { // use CPU cloner for IDMap and PreTransform auto index_idmap = dynamic_cast(index); auto index_pt = dynamic_cast(index); From 2e9cbc839b1bb11f20e0eac19cca034a49fd0abc Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 6 Jun 2024 15:02:39 -0700 Subject: [PATCH 46/46] option to exclude dataset store on index --- faiss/gpu/GpuIndexCagra.cu | 26 ++++- faiss/gpu/GpuIndexCagra.h | 2 + faiss/gpu/impl/RaftCagra.cu | 157 +++++++++++----------------- faiss/gpu/impl/RaftCagra.cuh | 15 ++- faiss/gpu/test/TestGpuIndexCagra.cu | 15 +-- 5 files changed, 108 insertions(+), 107 deletions(-) diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu index 4ae56df10d..b183e74568 100644 --- a/faiss/gpu/GpuIndexCagra.cu +++ b/faiss/gpu/GpuIndexCagra.cu @@ -86,11 +86,13 @@ void GpuIndexCagra::train(idx_t n, const float* x) { cagraConfig_.graph_degree, static_cast(cagraConfig_.build_algo), cagraConfig_.nn_descent_niter, + cagraConfig_.store_dataset, this->metric_type, this->metric_arg, INDICES_64_BIT, ivf_pq_params, - ivf_pq_search_params); + ivf_pq_search_params, + cagraConfig_.refine_rate); index_->train(n, x); @@ -225,17 +227,33 @@ void GpuIndexCagra::copyTo(faiss::IndexHNSWCagra* index) const { index->hnsw.set_default_probas(M, 1.0 / log(M)); auto n_train = this->ntotal; - auto train_dataset = index_->get_training_dataset(); + float* train_dataset; + auto dataset = index_->get_training_dataset(); + bool allocation = false; + if (getDeviceForAddress(dataset) >= 0) { + train_dataset = new float[n_train * index->d]; + allocation = true; + raft::copy( + train_dataset, + dataset, + n_train * index->d, + this->resources_->getRaftHandleCurrentDevice().get_stream()); + } else { + 
train_dataset = const_cast(dataset); + } // turn off as level 0 is copied from CAGRA graph index->init_level0 = false; if (!index->base_level_only) { - index->add(n_train, train_dataset.data()); + index->add(n_train, train_dataset); } else { index->hnsw.prepare_level_tab(n_train, false); - index->storage->add(n_train, train_dataset.data()); + index->storage->add(n_train, train_dataset); index->ntotal = n_train; } + if (allocation) { + delete[] train_dataset; + } auto graph = get_knngraph(); diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h index 6ecee3ae03..62042d531f 100644 --- a/faiss/gpu/GpuIndexCagra.h +++ b/faiss/gpu/GpuIndexCagra.h @@ -174,6 +174,8 @@ struct GpuIndexCagraConfig : public GpuIndexConfig { IVFPQBuildCagraConfig* ivf_pq_params = nullptr; IVFPQSearchCagraConfig* ivf_pq_search_params = nullptr; + float refine_rate = 2.0f; + bool store_dataset = true; }; enum class search_algo { diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu index 292079321d..50903220df 100644 --- a/faiss/gpu/impl/RaftCagra.cu +++ b/faiss/gpu/impl/RaftCagra.cu @@ -42,19 +42,23 @@ RaftCagra::RaftCagra( idx_t graph_degree, faiss::cagra_build_algo graph_build_algo, size_t nn_descent_niter, + bool store_dataset, faiss::MetricType metric, float metricArg, IndicesOptions indicesOptions, std::optional ivf_pq_params, std::optional - ivf_pq_search_params) + ivf_pq_search_params, + float refine_rate) : resources_(resources), dim_(dim), + store_dataset_(store_dataset), metric_(metric), metricArg_(metricArg), index_params_(), ivf_pq_params_(ivf_pq_params), - ivf_pq_search_params_(ivf_pq_search_params) { + ivf_pq_search_params_(ivf_pq_search_params), + refine_rate_(refine_rate) { FAISS_THROW_IF_NOT_MSG( metric == faiss::METRIC_L2 || metric == faiss::METRIC_INNER_PRODUCT, "CAGRA currently only supports L2 or Inner Product metric."); @@ -113,6 +117,9 @@ RaftCagra::RaftCagra( FAISS_ASSERT(distances_on_gpu == knn_graph_on_gpu); + storage_ = distances; + n_ = n; + const raft::device_resources& raft_handle = resources_->getRaftHandleCurrentDevice(); @@ -164,81 +171,50 @@ RaftCagra::RaftCagra( } void RaftCagra::train(idx_t n, const float* x) { + storage_ = x; + n_ = n; + const raft::device_resources& raft_handle = resources_->getRaftHandleCurrentDevice(); + + auto nn_descent_params = std::make_optional< + raft::neighbors::experimental::nn_descent::index_params>(); + nn_descent_params->graph_degree = index_params_.intermediate_graph_degree; + nn_descent_params->intermediate_graph_degree = + 1.5 * index_params_.intermediate_graph_degree; + nn_descent_params->max_iterations = index_params_.nn_descent_niter; + if (index_params_.build_algo == - raft::neighbors::cagra::graph_build_algo::IVF_PQ) { - std::optional> knn_graph( - raft::make_host_matrix( - n, index_params_.intermediate_graph_degree)); - if (getDeviceForAddress(x) >= 0) { - auto dataset_d = - raft::make_device_matrix_view( - x, n, dim_); - raft::neighbors::cagra::build_knn_graph( - raft_handle, - dataset_d, - knn_graph->view(), - 1.0f, - ivf_pq_params_, - ivf_pq_search_params_); - } else { - auto dataset_h = raft::make_host_matrix_view( - x, n, dim_); - raft::neighbors::cagra::build_knn_graph( - raft_handle, - dataset_h, - knn_graph->view(), - 1.0f, - ivf_pq_params_, - ivf_pq_search_params_); - } - auto cagra_graph = raft::make_host_matrix( - n, index_params_.graph_degree); - - raft::neighbors::cagra::optimize( - raft_handle, knn_graph->view(), cagra_graph.view()); - - // free intermediate graph before trying to create the index - 
knn_graph.reset(); - - if (getDeviceForAddress(x) >= 0) { - auto dataset_d = - raft::make_device_matrix_view( - x, n, dim_); - raft_knn_index = raft::neighbors::cagra::index( - raft_handle, - metric_ == faiss::METRIC_L2 - ? raft::distance::DistanceType::L2Expanded - : raft::distance::DistanceType::InnerProduct, - dataset_d, - raft::make_const_mdspan(cagra_graph.view())); - } else { - auto dataset_h = raft::make_host_matrix_view( - x, n, dim_); - raft_knn_index = raft::neighbors::cagra::index( - raft_handle, - metric_ == faiss::METRIC_L2 - ? raft::distance::DistanceType::L2Expanded - : raft::distance::DistanceType::InnerProduct, - dataset_h, - raft::make_const_mdspan(cagra_graph.view())); - } + raft::neighbors::cagra::graph_build_algo::IVF_PQ && + index_params_.graph_degree == index_params_.intermediate_graph_degree) { + index_params_.intermediate_graph_degree = + 1.5 * index_params_.graph_degree; + } + if (getDeviceForAddress(x) >= 0) { + auto dataset = + raft::make_device_matrix_view(x, n, dim_); + raft_knn_index = raft::neighbors::cagra::detail::build( + raft_handle, + index_params_, + dataset, + nn_descent_params, + refine_rate_, + ivf_pq_params_, + ivf_pq_search_params_, + store_dataset_); } else { - if (getDeviceForAddress(x) >= 0) { - raft_knn_index = raft::runtime::neighbors::cagra::build( - raft_handle, - index_params_, - raft::make_device_matrix_view( - x, n, dim_)); - } else { - raft_knn_index = raft::runtime::neighbors::cagra::build( - raft_handle, - index_params_, - raft::make_host_matrix_view( - x, n, dim_)); - } + auto dataset = + raft::make_host_matrix_view(x, n, dim_); + raft_knn_index = raft::neighbors::cagra::detail::build( + raft_handle, + index_params_, + dataset, + nn_descent_params, + refine_rate_, + ivf_pq_params_, + ivf_pq_search_params_, + store_dataset_); } } @@ -270,6 +246,18 @@ void RaftCagra::search( FAISS_ASSERT(numQueries > 0); FAISS_ASSERT(cols == dim_); + if (!store_dataset_) { + if (getDeviceForAddress(storage_) >= 0) { + auto dataset = raft::make_device_matrix_view( + storage_, n_, dim_); + raft_knn_index.value().update_dataset(raft_handle, dataset); + } else { + auto dataset = raft::make_host_matrix_view( + storage_, n_, dim_); + raft_knn_index.value().update_dataset(raft_handle, dataset); + } + } + auto queries_view = raft::make_device_matrix_view( queries.data(), numQueries, cols); auto distances_view = raft::make_device_matrix_view( @@ -342,29 +330,8 @@ std::vector RaftCagra::get_knngraph() const { return host_graph; } -std::vector RaftCagra::get_training_dataset() const { - FAISS_ASSERT(raft_knn_index.has_value()); - const raft::device_resources& raft_handle = - resources_->getRaftHandleCurrentDevice(); - auto stream = raft_handle.get_stream(); - - auto device_dataset = raft_knn_index.value().dataset(); - - std::vector host_dataset( - device_dataset.extent(0) * device_dataset.extent(1)); - - RAFT_CUDA_TRY(cudaMemcpy2DAsync( - host_dataset.data(), - sizeof(float) * dim_, - device_dataset.data_handle(), - sizeof(float) * device_dataset.stride(0), - sizeof(float) * dim_, - device_dataset.extent(0), - cudaMemcpyDefault, - raft_handle.get_stream())); - raft_handle.sync_stream(); - - return host_dataset; +const float* RaftCagra::get_training_dataset() const { + return storage_; } } // namespace gpu diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh index 95f6c03fca..0913ba5947 100644 --- a/faiss/gpu/impl/RaftCagra.cuh +++ b/faiss/gpu/impl/RaftCagra.cuh @@ -53,13 +53,15 @@ class RaftCagra { idx_t graph_degree, faiss::cagra_build_algo 
graph_build_algo, size_t nn_descent_niter, + bool store_dataset, faiss::MetricType metric, float metricArg, IndicesOptions indicesOptions, std::optional ivf_pq_params = std::nullopt, std::optional - ivf_pq_search_params = std::nullopt); + ivf_pq_search_params = std::nullopt, + float refine_rate = 2.0f); RaftCagra( GpuResources* resources, @@ -101,15 +103,23 @@ class RaftCagra { std::vector get_knngraph() const; - std::vector get_training_dataset() const; + const float* get_training_dataset() const; private: /// Collection of GPU resources that we use GpuResources* resources_; + /// Training dataset + const float* storage_; + int n_; + /// Expected dimensionality of the vectors const int dim_; + /// Controls the underlying RAFT index if it should store the dataset in + /// device memory + bool store_dataset_; + /// Metric type of the index faiss::MetricType metric_; @@ -122,6 +132,7 @@ class RaftCagra { /// Parameters to build CAGRA graph using IVF PQ std::optional ivf_pq_params_; std::optional ivf_pq_search_params_; + std::optional refine_rate_; /// Instance of trained RAFT CAGRA index std::optional> diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu index 8d330a81cb..3d9e14ae34 100644 --- a/faiss/gpu/test/TestGpuIndexCagra.cu +++ b/faiss/gpu/test/TestGpuIndexCagra.cu @@ -38,7 +38,7 @@ struct Options { Options() { - numTrain = 2 * faiss::gpu::randVal(2000, 5000); + numTrain = 2 * faiss::gpu::randVal(4000, 10000); dim = faiss::gpu::randVal(4, 10); numAdd = faiss::gpu::randVal(1000, 3000); @@ -47,8 +47,9 @@ struct Options { buildAlgo = faiss::gpu::randSelect( {faiss::gpu::graph_build_algo::IVF_PQ, faiss::gpu::graph_build_algo::NN_DESCENT}); + storeDataset = faiss::gpu::randSelect({true, false}); - numQuery = faiss::gpu::randVal(32, 100); + numQuery = faiss::gpu::randVal(300, 600); k = faiss::gpu::randVal(10, 30); device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); @@ -71,6 +72,7 @@ struct Options { size_t graphDegree; size_t intermediateGraphDegree; faiss::gpu::graph_build_algo buildAlgo; + bool storeDataset; int numQuery; int k; int device; @@ -224,6 +226,7 @@ void copyToTest( config.graph_degree = opt.graphDegree; config.intermediate_graph_degree = opt.intermediateGraphDegree; config.build_algo = opt.buildAlgo; + config.store_dataset = opt.storeDataset; faiss::gpu::GpuIndexCagra gpuIndex(&res, opt.dim, metric, config); gpuIndex.train(opt.numTrain, trainVecs.data()); @@ -339,7 +342,7 @@ TEST(TestGpuIndexCagra, Float32_CopyTo_L2) { } TEST(TestGpuIndexCagra, Float32_CopyTo_L2_BaseLevelOnly) { - copyToTest(faiss::METRIC_L2, 0.98, true); + copyToTest(faiss::METRIC_L2, 0.95, true); } TEST(TestGpuIndexCagra, Float32_CopyTo_IP) { @@ -347,7 +350,7 @@ TEST(TestGpuIndexCagra, Float32_CopyTo_IP) { } TEST(TestGpuIndexCagra, Float32_CopyTo_IP_BaseLevelOnly) { - copyToTest(faiss::METRIC_INNER_PRODUCT, 0.98, true); + copyToTest(faiss::METRIC_INNER_PRODUCT, 0.95, true); } void copyFromTest(faiss::MetricType metric, double expected_recall) { @@ -457,11 +460,11 @@ void copyFromTest(faiss::MetricType metric, double expected_recall) { } TEST(TestGpuIndexCagra, Float32_CopyFrom_L2) { - copyFromTest(faiss::METRIC_L2, 0.98); + copyFromTest(faiss::METRIC_L2, 0.95); } TEST(TestGpuIndexCagra, Float32_CopyFrom_IP) { - copyFromTest(faiss::METRIC_INNER_PRODUCT, 0.98); + copyFromTest(faiss::METRIC_INNER_PRODUCT, 0.95); } int main(int argc, char** argv) {
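
Note on the store_dataset option added in this last patch (an illustrative
sketch, not part of the diffs above; `d`, `n`, `xb`, `nq`, `xq`, `k`, `D`, `I`
are placeholders): when GpuIndexCagraConfig::store_dataset is false, RaftCagra
keeps only the graph plus a raw pointer to the training vectors and re-attaches
them via update_dataset() at search time, so the training buffer has to stay
valid for as long as the index is searched (and for copyTo).

    faiss::gpu::StandardGpuResources res;
    faiss::gpu::GpuIndexCagraConfig config;
    config.store_dataset = false;   // keep the graph, not the vectors, on the index
    config.refine_rate = 2.0f;      // default refinement rate from this patch

    faiss::gpu::GpuIndexCagra index(&res, d, faiss::METRIC_L2, config);
    index.train(n, xb);             // xb is retained as a raw pointer
    index.search(nq, xq, k, D, I);  // dataset view re-attached internally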