diff --git a/CMakeLists.txt b/CMakeLists.txt
index cedee9c456..1a468fb247 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -46,6 +46,8 @@ project(faiss LANGUAGES ${FAISS_LANGUAGES})
 
 include(GNUInstallDirs)
 
+set(CMAKE_INSTALL_PREFIX "$ENV{CONDA_PREFIX}")
+
 set(CMAKE_CXX_STANDARD 17)
 
 list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp
index 0686917211..efd8972477 100644
--- a/faiss/IndexHNSW.cpp
+++ b/faiss/IndexHNSW.cpp
@@ -17,12 +17,16 @@
 #include
 #include
+#include <limits>
+#include <random>
 #include
+#include
 #include
 #include
 #include
 #include
+#include "impl/HNSW.h"
 #include
 #include
@@ -146,7 +150,9 @@ void hnsw_add_vertices(
 
     int i1 = n;
 
-    for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
+    for (int pt_level = hist.size() - 1;
+         pt_level >= !index_hnsw.init_level0;
+         pt_level--) {
         int i0 = i1 - hist[pt_level];
 
         if (verbose) {
@@ -182,7 +188,13 @@ void hnsw_add_vertices(
                     continue;
                 }
 
-                hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);
+                hnsw.add_with_locks(
+                        *dis,
+                        pt_level,
+                        pt_id,
+                        locks,
+                        vt,
+                        index_hnsw.keep_max_size_level0 && (pt_level == 0));
 
                 if (prev_display >= 0 && i - i0 > prev_display + 10000) {
                     prev_display = i - i0;
@@ -202,7 +214,11 @@ void hnsw_add_vertices(
             }
             i1 = i0;
         }
-        FAISS_ASSERT(i1 == 0);
+        if (index_hnsw.init_level0) {
+            FAISS_ASSERT(i1 == 0);
+        } else {
+            FAISS_ASSERT((i1 - hist[0]) == 0);
+        }
     }
     if (verbose) {
         printf("Done in %.3f ms\n", getmillisecs() - t0);
@@ -405,10 +421,18 @@ void IndexHNSW::search_level_0(
         float* distances,
         idx_t* labels,
         int nprobe,
-        int search_type) const {
+        int search_type,
+        const SearchParameters* params_in) const {
     FAISS_THROW_IF_NOT(k > 0);
     FAISS_THROW_IF_NOT(nprobe > 0);
 
+    const SearchParametersHNSW* params = nullptr;
+
+    if (params_in) {
+        params = dynamic_cast<const SearchParametersHNSW*>(params_in);
+        FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
+    }
+
     storage_idx_t ntotal = hnsw.levels.size();
 
     using RH = HeapBlockResultHandler<HNSW::C>;
@@ -435,13 +459,21 @@ void IndexHNSW::search_level_0(
                     nearest_d + i * nprobe,
                     search_type,
                     search_stats,
-                    vt);
+                    vt,
+                    params);
             res.end();
             vt.advance();
         }
 #pragma omp critical
         { hnsw_stats.combine(search_stats); }
     }
+    if (is_similarity_metric(this->metric_type)) {
+        // we need to revert the negated distances
+#pragma omp parallel for
+        for (size_t i = 0; i < k * n; i++) {
+            distances[i] = -distances[i];
+        }
+    }
 }
 
 void IndexHNSW::init_level_0_from_knngraph(
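Usage sketch (not part of the patch): `search_level_0` now accepts an optional `SearchParameters*`, so `efSearch` can be overridden per call instead of by mutating index state, and for similarity metrics the internally negated scores are flipped back before returning. A minimal hedged sketch, assuming an `IndexHNSWFlat` named `index` with data already added and `nq`, `k`, `queries` defined:

```cpp
faiss::SearchParametersHNSW params;
params.efSearch = 256; // takes precedence over index.hnsw.efSearch

std::vector<float> D(nq * k);
std::vector<faiss::idx_t> I(nq * k);
index.search(nq, queries, k, D.data(), I.data(), &params);
// With METRIC_INNER_PRODUCT, D holds true similarities on return: the
// negated values used internally (smaller == better) have been reverted.
```
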
@@ -864,4 +896,86 @@ void IndexHNSW2Level::flip_to_ivf() {
     delete storage2l;
 }
 
+/**************************************************************
+ * IndexHNSWCagra implementation
+ **************************************************************/
+
+IndexHNSWCagra::IndexHNSWCagra() {
+    is_trained = true;
+}
+
+IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric)
+        : IndexHNSW(
+                  (metric == METRIC_L2)
+                          ? static_cast<IndexFlat*>(new IndexFlatL2(d))
+                          : static_cast<IndexFlat*>(new IndexFlatIP(d)),
+                  M) {
+    FAISS_THROW_IF_NOT_MSG(
+            ((metric == METRIC_L2) || (metric == METRIC_INNER_PRODUCT)),
+            "unsupported metric type for IndexHNSWCagra");
+    own_fields = true;
+    is_trained = true;
+    init_level0 = true;
+    keep_max_size_level0 = true;
+}
+
+void IndexHNSWCagra::add(idx_t n, const float* x) {
+    FAISS_THROW_IF_NOT_MSG(
+            !base_level_only,
+            "Cannot add vectors when base_level_only is set to True");
+
+    IndexHNSW::add(n, x);
+}
+
+void IndexHNSWCagra::search(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        const SearchParameters* params) const {
+    if (!base_level_only) {
+        IndexHNSW::search(n, x, k, distances, labels, params);
+    } else {
+        std::vector<storage_idx_t> nearest(n);
+        std::vector<float> nearest_d(n);
+
+#pragma omp for
+        for (idx_t i = 0; i < n; i++) {
+            std::unique_ptr<DistanceComputer> dis(
+                    storage_distance_computer(this->storage));
+            dis->set_query(x + i * d);
+            nearest[i] = -1;
+            nearest_d[i] = std::numeric_limits<float>::max();
+
+            std::random_device rd;
+            std::mt19937 gen(rd());
+            // inclusive bounds: draw candidate ids in [0, ntotal - 1]
+            std::uniform_int_distribution<idx_t> distrib(0, this->ntotal - 1);
+
+            for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) {
+                auto idx = distrib(gen);
+                auto distance = (*dis)(idx);
+                if (distance < nearest_d[i]) {
+                    nearest[i] = idx;
+                    nearest_d[i] = distance;
+                }
+            }
+            FAISS_THROW_IF_NOT_MSG(
+                    nearest[i] >= 0, "Could not find a valid entrypoint.");
+        }
+
+        search_level_0(
+                n,
+                x,
+                k,
+                nearest.data(),
+                nearest_d.data(),
+                distances,
+                labels,
+                1, // n_probes
+                1, // search_type
+                params);
+    }
+}
+
 } // namespace faiss
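For orientation, a hedged sketch of how the new CPU-side class is meant to be used (the data and sizes are hypothetical):

```cpp
#include <faiss/IndexHNSW.h>

#include <vector>

int main() {
    int d = 64, M = 32, n = 1000;
    faiss::IndexHNSWCagra index(d, M, faiss::METRIC_L2);

    std::vector<float> xb(n * d, 0.5f); // placeholder data
    index.add(n, xb.data()); // allowed while base_level_only == false

    // After GpuIndexCagra::copyTo into an index with base_level_only ==
    // true, add() throws and search() starts from
    // num_base_level_search_entrypoints random entry points on level 0.
    std::vector<float> D(5);
    std::vector<faiss::idx_t> I(5);
    index.search(1, xb.data(), 5, D.data(), I.data());
    return 0;
}
```
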
diff --git a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h
index e0b65fca9d..71807c6537 100644
--- a/faiss/IndexHNSW.h
+++ b/faiss/IndexHNSW.h
@@ -34,6 +34,18 @@ struct IndexHNSW : Index {
     bool own_fields = false;
     Index* storage = nullptr;
 
+    // When set to false, level 0 in the knn graph is not initialized.
+    // This option is used by GpuIndexCagra::copyTo(IndexHNSWCagra*)
+    // as the level 0 knn graph is copied over from the index built by
+    // GpuIndexCagra.
+    bool init_level0 = true;
+
+    // When set to true, all neighbors in level 0 are filled up
+    // to the maximum size allowed (2 * M). This option is used by
+    // IndexHNSWCagra to create a full base layer graph that is
+    // used when GpuIndexCagra::copyFrom(IndexHNSWCagra*) is invoked.
+    bool keep_max_size_level0 = false;
+
     explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2);
     explicit IndexHNSW(Index* storage, int M = 32);
 
@@ -81,7 +93,8 @@ struct IndexHNSW : Index {
             float* distances,
             idx_t* labels,
             int nprobe = 1,
-            int search_type = 1) const;
+            int search_type = 1,
+            const SearchParameters* params = nullptr) const;
 
     /// alternative graph building
     void init_level_0_from_knngraph(int k, const float* D, const idx_t* I);
@@ -148,4 +161,33 @@ struct IndexHNSW2Level : IndexHNSW {
             const SearchParameters* params = nullptr) const override;
 };
 
+struct IndexHNSWCagra : IndexHNSW {
+    IndexHNSWCagra();
+    IndexHNSWCagra(int d, int M, MetricType metric = METRIC_L2);
+
+    /// When set to true, the index is immutable.
+    /// This option is used to copy the knn graph from GpuIndexCagra
+    /// to the base level of IndexHNSWCagra without adding upper levels.
+    /// Doing so makes it possible to search the HNSW index, but removes
+    /// the ability to add vectors.
+    bool base_level_only = false;
+
+    /// When `base_level_only` is set to `True`, the search function
+    /// searches only the base level knn graph of the HNSW index.
+    /// This parameter selects the entry point by randomly sampling
+    /// this many points and keeping the best one.
+    int num_base_level_search_entrypoints = 32;
+
+    void add(idx_t n, const float* x) override;
+
+    /// entry point for search
+    void search(
+            idx_t n,
+            const float* x,
+            idx_t k,
+            float* distances,
+            idx_t* labels,
+            const SearchParameters* params = nullptr) const override;
+};
+
 } // namespace faiss
diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt
index 126cbe5044..d20f3b7f8e 100644
--- a/faiss/gpu/CMakeLists.txt
+++ b/faiss/gpu/CMakeLists.txt
@@ -238,11 +238,15 @@ generate_ivf_interleaved_code()
 
 if(FAISS_ENABLE_RAFT)
   list(APPEND FAISS_GPU_HEADERS
+          GpuIndexCagra.h
+          impl/RaftCagra.cuh
           impl/RaftFlatIndex.cuh
           impl/RaftIVFFlat.cuh
           impl/RaftIVFPQ.cuh
           utils/RaftUtils.h)
   list(APPEND FAISS_GPU_SRC
+          GpuIndexCagra.cu
+          impl/RaftCagra.cu
           impl/RaftFlatIndex.cu
           impl/RaftIVFFlat.cu
           impl/RaftIVFPQ.cu
@@ -316,5 +320,5 @@ __nv_relfatbin : { *(__nv_relfatbin) }
 target_link_options(faiss_gpu PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld")
 
 find_package(CUDAToolkit REQUIRED)
-target_link_libraries(faiss_gpu PRIVATE CUDA::cudart CUDA::cublas $<$<BOOL:${FAISS_ENABLE_RAFT}>:raft::raft> $<$<BOOL:${FAISS_ENABLE_RAFT}>:raft::compiled> $<$<BOOL:${FAISS_ENABLE_RAFT}>:nvidia::cutlass::cutlass>)
-target_compile_options(faiss_gpu PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr>)
+target_link_libraries(faiss_gpu PRIVATE CUDA::cudart CUDA::cublas $<$<BOOL:${FAISS_ENABLE_RAFT}>:raft::raft> $<$<BOOL:${FAISS_ENABLE_RAFT}>:raft::compiled> $<$<BOOL:${FAISS_ENABLE_RAFT}>:nvidia::cutlass::cutlass> $<$<BOOL:${FAISS_ENABLE_RAFT}>:OpenMP::OpenMP_CXX>)
+target_compile_options(faiss_gpu PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr $<$<BOOL:${FAISS_ENABLE_RAFT}>:-Xcompiler=${OpenMP_CXX_FLAGS}>>)
diff --git a/faiss/gpu/GpuCloner.cpp b/faiss/gpu/GpuCloner.cpp
index 8f895ac9c7..b6d55a47aa 100644
--- a/faiss/gpu/GpuCloner.cpp
+++ b/faiss/gpu/GpuCloner.cpp
@@ -14,6 +14,9 @@
 #include
 #include
+#if defined USE_NVIDIA_RAFT
+#include <faiss/gpu/GpuIndexCagra.h>
+#endif
 #include
 #include
 #include
@@ -24,6 +27,9 @@
 #include
 #include
 #include
+#if defined USE_NVIDIA_RAFT
+#include <faiss/IndexHNSW.h>
+#endif
 #include
 #include
 #include
@@ -85,7 +91,15 @@ Index* ToCPUCloner::clone_Index(const Index* index) {
         // objective is to make a single component out of them
         // (inverse op of ToGpuClonerMultiple)
 
-    } else if (auto ish = dynamic_cast<const IndexShards*>(index)) {
+    }
+#if defined USE_NVIDIA_RAFT
+    else if (auto icg = dynamic_cast<const GpuIndexCagra*>(index)) {
+        IndexHNSWCagra* res = new IndexHNSWCagra();
+        icg->copyTo(res);
+        return res;
+    }
+#endif
+    else if (auto ish = dynamic_cast<const IndexShards*>(index)) {
         int nshard = ish->count();
         FAISS_ASSERT(nshard > 0);
         Index* res = clone_Index(ish->at(0));
@@ -215,7 +229,18 @@ Index* ToGpuCloner::clone_Index(const Index* index) {
         }
 
         return res;
-    } else {
+    }
+#if defined USE_NVIDIA_RAFT
+    else if (auto icg = dynamic_cast<const faiss::IndexHNSWCagra*>(index)) {
+        GpuIndexCagraConfig config;
+        config.device = device;
+        GpuIndexCagra* res =
+                new GpuIndexCagra(provider, icg->d, icg->metric_type, config);
+        res->copyFrom(icg);
+        return res;
+    }
+#endif
+    else {
         // use CPU cloner for IDMap and PreTransform
         auto index_idmap = dynamic_cast<const IndexIDMap*>(index);
         auto index_pt = dynamic_cast<const IndexPreTransform*>(index);
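These cloner paths mean the generic conversion helpers now understand CAGRA. A hedged sketch (RAFT-enabled build assumed; `d`, `n`, `xb` are hypothetical):

```cpp
faiss::gpu::StandardGpuResources res;

faiss::gpu::GpuIndexCagra gpu_index(&res, d);
gpu_index.train(n, xb);

// GPU -> CPU: dispatches to GpuIndexCagra::copyTo, yielding IndexHNSWCagra
std::unique_ptr<faiss::Index> cpu_index(
        faiss::gpu::index_gpu_to_cpu(&gpu_index));

// CPU -> GPU: dispatches to GpuIndexCagra::copyFrom
std::unique_ptr<faiss::Index> gpu_again(
        faiss::gpu::index_cpu_to_gpu(&res, /*device=*/0, cpu_index.get()));
```
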
diff --git a/faiss/gpu/GpuIndexCagra.cu b/faiss/gpu/GpuIndexCagra.cu
new file mode 100644
index 0000000000..4ae56df10d
--- /dev/null
+++ b/faiss/gpu/GpuIndexCagra.cu
@@ -0,0 +1,274 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <faiss/IndexFlat.h>
+#include <faiss/IndexHNSW.h>
+#include <faiss/gpu/GpuIndexCagra.h>
+#include <faiss/gpu/impl/RaftCagra.cuh>
+#include <faiss/gpu/utils/DeviceUtils.h>
+
+namespace faiss {
+namespace gpu {
+
+GpuIndexCagra::GpuIndexCagra(
+        GpuResourcesProvider* provider,
+        int dims,
+        faiss::MetricType metric,
+        GpuIndexCagraConfig config)
+        : GpuIndex(provider->getResources(), dims, metric, 0.0f, config),
+          cagraConfig_(config) {
+    this->is_trained = false;
+}
+
+void GpuIndexCagra::train(idx_t n, const float* x) {
+    if (this->is_trained) {
+        FAISS_ASSERT(index_);
+        return;
+    }
+
+    FAISS_ASSERT(!index_);
+
+    std::optional<raft::neighbors::ivf_pq::index_params> ivf_pq_params =
+            std::nullopt;
+    std::optional<raft::neighbors::ivf_pq::search_params>
+            ivf_pq_search_params = std::nullopt;
+    if (cagraConfig_.ivf_pq_params != nullptr) {
+        ivf_pq_params =
+                std::make_optional<raft::neighbors::ivf_pq::index_params>();
+        ivf_pq_params->n_lists = cagraConfig_.ivf_pq_params->n_lists;
+        ivf_pq_params->kmeans_n_iters =
+                cagraConfig_.ivf_pq_params->kmeans_n_iters;
+        ivf_pq_params->kmeans_trainset_fraction =
+                cagraConfig_.ivf_pq_params->kmeans_trainset_fraction;
+        ivf_pq_params->pq_bits = cagraConfig_.ivf_pq_params->pq_bits;
+        ivf_pq_params->pq_dim = cagraConfig_.ivf_pq_params->pq_dim;
+        ivf_pq_params->codebook_kind =
+                static_cast<raft::neighbors::ivf_pq::codebook_gen>(
+                        cagraConfig_.ivf_pq_params->codebook_kind);
+        ivf_pq_params->force_random_rotation =
+                cagraConfig_.ivf_pq_params->force_random_rotation;
+        ivf_pq_params->conservative_memory_allocation =
+                cagraConfig_.ivf_pq_params->conservative_memory_allocation;
+    }
+    if (cagraConfig_.ivf_pq_search_params != nullptr) {
+        ivf_pq_search_params =
+                std::make_optional<raft::neighbors::ivf_pq::search_params>();
+        ivf_pq_search_params->n_probes =
+                cagraConfig_.ivf_pq_search_params->n_probes;
+        ivf_pq_search_params->lut_dtype =
+                cagraConfig_.ivf_pq_search_params->lut_dtype;
+        ivf_pq_search_params->preferred_shmem_carveout =
+                cagraConfig_.ivf_pq_search_params->preferred_shmem_carveout;
+    }
+    index_ = std::make_shared<RaftCagra>(
+            this->resources_.get(),
+            this->d,
+            cagraConfig_.intermediate_graph_degree,
+            cagraConfig_.graph_degree,
+            static_cast<faiss::cagra_build_algo>(cagraConfig_.build_algo),
+            cagraConfig_.nn_descent_niter,
+            this->metric_type,
+            this->metric_arg,
+            INDICES_64_BIT,
+            ivf_pq_params,
+            ivf_pq_search_params);
+
+    index_->train(n, x);
+
+    this->is_trained = true;
+    this->ntotal = n;
+}
+
+bool GpuIndexCagra::addImplRequiresIDs_() const {
+    return false;
+};
+
+void GpuIndexCagra::addImpl_(idx_t n, const float* x, const idx_t* ids) {
+    FAISS_THROW_MSG("adding vectors is not supported by GpuIndexCagra.");
+};
+
+void GpuIndexCagra::searchImpl_(
+        idx_t n,
+        const float* x,
+        int k,
+        float* distances,
+        idx_t* labels,
+        const SearchParameters* search_params) const {
+    FAISS_ASSERT(this->is_trained && index_);
+    FAISS_ASSERT(n > 0);
+
+    Tensor<float, 2, true> queries(const_cast<float*>(x), {n, this->d});
+    Tensor<float, 2, true> outDistances(distances, {n, k});
+    Tensor<idx_t, 2, true> outLabels(const_cast<idx_t*>(labels), {n, k});
+
+    SearchParametersCagra* params;
+    if (search_params) {
+        params = dynamic_cast<SearchParametersCagra*>(
+                const_cast<SearchParameters*>(search_params));
+    } else {
+        params = new SearchParametersCagra{};
+    }
+
+    index_->search(
+            queries,
+            k,
+            outDistances,
+            outLabels,
+            params->max_queries,
+            params->itopk_size,
+            params->max_iterations,
+            static_cast<faiss::cagra_search_algo>(params->algo),
+            params->team_size,
+            params->search_width,
+            params->min_iterations,
+            params->thread_block_size,
+            static_cast<faiss::cagra_hash_mode>(params->hashmap_mode),
+            params->hashmap_min_bitlen,
+            params->hashmap_max_fill_rate,
+            params->num_random_samplings,
+            params->seed);
+
+    if (not search_params) {
+        delete params;
+    }
+}
+
+void GpuIndexCagra::copyFrom(const faiss::IndexHNSWCagra* index) {
+    FAISS_ASSERT(index);
+
+    DeviceScope scope(config_.device);
+
+    GpuIndex::copyFrom(index);
+
+    auto base_index = dynamic_cast<IndexFlat*>(index->storage);
+    FAISS_ASSERT(base_index);
+    auto distances = base_index->get_xb();
+
+    auto hnsw = index->hnsw;
+    // copy level 0 to a dense knn graph matrix
+    std::vector<idx_t> knn_graph;
+    knn_graph.resize(index->ntotal * hnsw.nb_neighbors(0));
+
+#pragma omp parallel for
+    for (size_t i = 0; i < index->ntotal; ++i) {
+        size_t begin, end;
+        hnsw.neighbor_range(i, 0, &begin, &end);
+        for (size_t j = begin; j < end; j++) {
+            knn_graph[i * hnsw.nb_neighbors(0) + (j - begin)] =
+                    hnsw.neighbors[j];
+        }
+    }
+
+    index_ = std::make_shared<RaftCagra>(
+            this->resources_.get(),
+            this->d,
+            index->ntotal,
+            hnsw.nb_neighbors(0),
+            distances,
+            knn_graph.data(),
+            this->metric_type,
+            this->metric_arg,
+            INDICES_64_BIT);
+
+    this->is_trained = true;
+}
+
+void GpuIndexCagra::copyTo(faiss::IndexHNSWCagra* index) const {
+    FAISS_ASSERT(index_ && this->is_trained && index);
+
+    DeviceScope scope(config_.device);
+
+    //
+    // Index information
+    //
+    GpuIndex::copyTo(index);
+    // This needs to be zeroed out as this implementation adds vectors to the
+    // cpuIndex instead of copying fields
+    index->ntotal = 0;
+
+    auto graph_degree = index_->get_knngraph_degree();
+    auto M = graph_degree / 2;
+    if (index->storage and index->own_fields) {
+        delete index->storage;
+    }
+
+    if (this->metric_type == METRIC_L2) {
+        index->storage = new IndexFlatL2(index->d);
+    } else if (this->metric_type == METRIC_INNER_PRODUCT) {
+        index->storage = new IndexFlatIP(index->d);
+    }
+    index->own_fields = true;
+    index->keep_max_size_level0 = true;
+    index->hnsw.reset();
+    index->hnsw.assign_probas.clear();
+    index->hnsw.cum_nneighbor_per_level.clear();
+    index->hnsw.set_default_probas(M, 1.0 / log(M));
+
+    auto n_train = this->ntotal;
+    auto train_dataset = index_->get_training_dataset();
+
+    // turn off as level 0 is copied from CAGRA graph
+    index->init_level0 = false;
+    if (!index->base_level_only) {
+        index->add(n_train, train_dataset.data());
+    } else {
+        index->hnsw.prepare_level_tab(n_train, false);
+        index->storage->add(n_train, train_dataset.data());
+        index->ntotal = n_train;
+    }
+
+    auto graph = get_knngraph();
+
+#pragma omp parallel for
+    for (idx_t i = 0; i < n_train; i++) {
+        size_t begin, end;
+        index->hnsw.neighbor_range(i, 0, &begin, &end);
+        for (size_t j = begin; j < end; j++) {
+            index->hnsw.neighbors[j] = graph[i * graph_degree + (j - begin)];
+        }
+    }
+
+    // turn back on to allow new vectors to be added to level 0
+    index->init_level0 = true;
+}
+
+void GpuIndexCagra::reset() {
+    DeviceScope scope(config_.device);
+
+    if (index_) {
+        index_->reset();
+        this->ntotal = 0;
+        this->is_trained = false;
+    } else {
+        FAISS_ASSERT(this->ntotal == 0);
+    }
+}
+
+std::vector<idx_t> GpuIndexCagra::get_knngraph() const {
+    FAISS_ASSERT(index_ && this->is_trained);
+
+    return index_->get_knngraph();
+}
+
+} // namespace gpu
+} // namespace faiss
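End to end, the class above is used like any other GPU index, except that `train()` both trains and ingests the dataset. A hedged sketch (`d`, `n`, `nq`, `k`, `xb`, `xq` are hypothetical):

```cpp
faiss::gpu::StandardGpuResources res;

faiss::gpu::GpuIndexCagraConfig config;
config.graph_degree = 64;
config.intermediate_graph_degree = 128;
config.build_algo = faiss::gpu::graph_build_algo::IVF_PQ;

faiss::gpu::GpuIndexCagra index(&res, d, faiss::METRIC_L2, config);
index.train(n, xb); // builds the CAGRA graph over xb; add() is unsupported

std::vector<float> D(nq * k);
std::vector<faiss::idx_t> I(nq * k);
index.search(nq, xq, k, D.data(), I.data());
```
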
diff --git a/faiss/gpu/GpuIndexCagra.h b/faiss/gpu/GpuIndexCagra.h
new file mode 100644
index 0000000000..6ecee3ae03
--- /dev/null
+++ b/faiss/gpu/GpuIndexCagra.h
@@ -0,0 +1,282 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <faiss/gpu/GpuIndex.h>
+#include <memory>
+#include <vector>
+
+namespace faiss {
+struct IndexHNSWCagra;
+}
+
+namespace faiss {
+namespace gpu {
+
+class RaftCagra;
+
+enum class graph_build_algo {
+    /// Use IVF-PQ to build all-neighbors knn graph
+    IVF_PQ,
+    /// Experimental, use NN-Descent to build all-neighbors knn graph
+    NN_DESCENT
+};
+
+/// A type for specifying how PQ codebooks are created.
+enum class codebook_gen { // NOLINT
+    PER_SUBSPACE = 0, // NOLINT
+    PER_CLUSTER = 1, // NOLINT
+};
+
+struct IVFPQBuildCagraConfig {
+    ///
+    /// The number of inverted lists (clusters)
+    ///
+    /// Hint: the number of vectors per cluster (`n_rows/n_lists`) should be
+    /// approximately 1,000 to 10,000.
+
+    uint32_t n_lists = 1024;
+    /// The number of iterations searching for kmeans centers (index building).
+    uint32_t kmeans_n_iters = 20;
+    /// The fraction of data to use during iterative kmeans building.
+    double kmeans_trainset_fraction = 0.5;
+    ///
+    /// The bit length of the vector element after compression by PQ.
+    ///
+    /// Possible values: [4, 5, 6, 7, 8].
+    ///
+    /// Hint: the smaller the 'pq_bits', the smaller the index size and the
+    /// better the search performance, but the lower the recall.
+
+    uint32_t pq_bits = 8;
+    ///
+    /// The dimensionality of the vector after compression by PQ. When zero, an
+    /// optimal value is selected using a heuristic.
+    ///
+    /// NB: `pq_dim * pq_bits` must be a multiple of 8.
+    ///
+    /// Hint: a smaller 'pq_dim' results in a smaller index size and better
+    /// search performance, but lower recall. If 'pq_bits' is 8, 'pq_dim' can
+    /// be set to any number, but multiples of 8 are desirable for good
+    /// performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8.
+    /// For good performance, it is desirable that 'pq_dim' is a multiple of
+    /// 32. Ideally, 'pq_dim' should be also a divisor of the dataset dim.
+
+    uint32_t pq_dim = 0;
+    /// How PQ codebooks are created.
+    codebook_gen codebook_kind = codebook_gen::PER_SUBSPACE;
+    ///
+    /// Apply a random rotation matrix on the input data and queries even if
+    /// `dim % pq_dim == 0`.
+    ///
+    /// Note: if `dim` is not multiple of `pq_dim`, a random rotation is always
+    /// applied to the input data and queries to transform the working space
+    /// from `dim` to `rot_dim`, which may be slightly larger than the original
+    /// space and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`).
+    /// However, this transform is not necessary when `dim` is a multiple of
+    /// `pq_dim` (`dim == rot_dim`, hence no need in adding "extra" data
+    /// columns / features).
+    ///
+    /// By default, if `dim == rot_dim`, the rotation transform is initialized
+    /// with the identity matrix. When `force_random_rotation == true`, a
+    /// random orthogonal transform matrix is generated regardless of the
+    /// values of `dim` and `pq_dim`.
+
+    bool force_random_rotation = false;
+    ///
+    /// By default, the algorithm allocates more space than necessary for
+    /// individual clusters (`list_data`). This allows amortizing the cost of
+    /// memory allocation and reduces the number of data copies during
+    /// repeated calls to `extend` (extending the database).
+    ///
+    /// The alternative is the conservative allocation behavior; when enabled,
+    /// the algorithm always allocates the minimum amount of memory required
+    /// to store the given number of records. Set this flag to `true` if you
+    /// prefer to use as little GPU memory for the database as possible.
+
+    bool conservative_memory_allocation = false;
+};
+
+struct IVFPQSearchCagraConfig {
+    /// The number of clusters to search.
+    uint32_t n_probes = 20;
+    ///
+    /// Data type of look up table to be created dynamically at search time.
+    ///
+    /// Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U]
+    ///
+    /// The use of low-precision types reduces the amount of shared memory
+    /// required at search time, so fast shared memory kernels can be used
+    /// even for datasets with large dimensionality. Note that the recall is
+    /// slightly degraded when a low-precision type is selected.
+
+    cudaDataType_t lut_dtype = CUDA_R_32F;
+    ///
+    /// Storage data type for distance/similarity computed at search time.
+    ///
+    /// Possible values: [CUDA_R_16F, CUDA_R_32F]
+    ///
+    /// If the performance limiter at search time is device memory access,
+    /// selecting FP16 will improve performance slightly.
+
+    cudaDataType_t internal_distance_dtype = CUDA_R_32F;
+    ///
+    /// Preferred fraction of SM's unified memory / L1 cache to be used as
+    /// shared memory.
+    ///
+    /// Possible values: [0.0 - 1.0] as a fraction of the
+    /// `sharedMemPerMultiprocessor`.
+    ///
+    /// One wants to increase the carveout to make sure a good GPU occupancy
+    /// for the main search kernel, but not to keep it too high to leave some
+    /// memory to be used as L1 cache. Note, this value is interpreted only as
+    /// a hint. Moreover, a GPU usually allows only a fixed set of cache
+    /// configurations, so the provided value is rounded up to the nearest
+    /// configuration. Refer to the NVIDIA tuning guide for the target GPU
+    /// architecture.
+    ///
+    /// Note, this is a low-level tuning parameter that can have drastic
+    /// negative effects on the search performance if tweaked incorrectly.
+
+    double preferred_shmem_carveout = 1.0;
+};
+
+struct GpuIndexCagraConfig : public GpuIndexConfig {
+    /// Degree of input graph for pruning.
+    size_t intermediate_graph_degree = 128;
+    /// Degree of output graph.
+    size_t graph_degree = 64;
+    /// ANN algorithm to build knn graph.
+    graph_build_algo build_algo = graph_build_algo::IVF_PQ;
+    /// Number of iterations to run if building with NN_DESCENT.
+    size_t nn_descent_niter = 20;
+
+    IVFPQBuildCagraConfig* ivf_pq_params = nullptr;
+    IVFPQSearchCagraConfig* ivf_pq_search_params = nullptr;
+};
+
+enum class search_algo {
+    /// For large batch sizes.
+    SINGLE_CTA,
+    /// For small batch sizes.
+    MULTI_CTA,
+    MULTI_KERNEL,
+    AUTO
+};
+
+enum class hash_mode { HASH, SMALL, AUTO };
+
+struct SearchParametersCagra : SearchParameters {
+    /// Maximum number of queries to search at the same time (batch size).
+    /// Auto select when 0.
+    size_t max_queries = 0;
+
+    /// Number of intermediate search results retained during the search.
+    ///
+    /// This is the main knob to adjust the trade off between accuracy and
+    /// search speed. Higher values improve the search accuracy.
+
+    size_t itopk_size = 64;
+
+    /// Upper limit of search iterations. Auto select when 0.
+    size_t max_iterations = 0;
+
+    // In the following we list additional search parameters for fine tuning.
+    // Reasonable default values are automatically chosen.
+
+    /// Which search implementation to use.
+    search_algo algo = search_algo::AUTO;
+
+    /// Number of threads used to calculate a single distance. 4, 8, 16, or 32.
+
+    size_t team_size = 0;
+
+    /// Number of graph nodes to select as the starting point for the search
+    /// in each iteration (also known as the search width).
+    size_t search_width = 1;
+    /// Lower limit of search iterations.
+    size_t min_iterations = 0;
+
+    /// Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0.
+    size_t thread_block_size = 0;
+    /// Hashmap type. Auto selection when AUTO.
+    hash_mode hashmap_mode = hash_mode::AUTO;
+    /// Lower limit of hashmap bit length. More than 8.
+    size_t hashmap_min_bitlen = 0;
+    /// Upper limit of hashmap fill rate. More than 0.1, less than 0.9.
+    float hashmap_max_fill_rate = 0.5;
+
+    /// Number of iterations of initial random seed node selection. 1 or more.
+
+    uint32_t num_random_samplings = 1;
+    /// Bit mask used for initial random seed node selection.
+    uint64_t seed = 0x128394;
+};
+
+struct GpuIndexCagra : public GpuIndex {
+   public:
+    GpuIndexCagra(
+            GpuResourcesProvider* provider,
+            int dims,
+            faiss::MetricType metric = faiss::METRIC_L2,
+            GpuIndexCagraConfig config = GpuIndexCagraConfig());
+
+    /// Trains CAGRA based on the given vector data
+    void train(idx_t n, const float* x) override;
+
+    /// Initialize ourselves from the given CPU index; will overwrite
+    /// all data in ourselves
+    void copyFrom(const faiss::IndexHNSWCagra* index);
+
+    /// Copy ourselves to the given CPU index; will overwrite all data
+    /// in the index instance
+    void copyTo(faiss::IndexHNSWCagra* index) const;
+
+    void reset() override;
+
+    std::vector<idx_t> get_knngraph() const;
+
+   protected:
+    bool addImplRequiresIDs_() const override;
+
+    void addImpl_(idx_t n, const float* x, const idx_t* ids) override;
+
+    /// Called from GpuIndex for search
+    void searchImpl_(
+            idx_t n,
+            const float* x,
+            int k,
+            float* distances,
+            idx_t* labels,
+            const SearchParameters* search_params) const override;
+
+    /// Our configuration options
+    const GpuIndexCagraConfig cagraConfig_;
+
+    /// Instance that we own; contains the RAFT CAGRA index
+    std::shared_ptr<RaftCagra> index_;
+};
+
+} // namespace gpu
+} // namespace faiss
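Per-call tuning goes through `SearchParametersCagra`; `itopk_size` is the main recall/latency knob. A hedged sketch with illustrative values, reusing the `index`, `nq`, `xq`, `k`, `D`, `I` names from the sketch above:

```cpp
faiss::gpu::SearchParametersCagra params;
params.itopk_size = 128;       // retain more intermediate results
params.max_iterations = 0;     // 0 = auto
params.algo = faiss::gpu::search_algo::AUTO;

index.search(nq, xq, k, D.data(), I.data(), &params);
```
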
diff --git a/faiss/gpu/impl/RaftCagra.cu b/faiss/gpu/impl/RaftCagra.cu
new file mode 100644
index 0000000000..292079321d
--- /dev/null
+++ b/faiss/gpu/impl/RaftCagra.cu
@@ -0,0 +1,371 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace faiss {
+namespace gpu {
+
+RaftCagra::RaftCagra(
+        GpuResources* resources,
+        int dim,
+        idx_t intermediate_graph_degree,
+        idx_t graph_degree,
+        faiss::cagra_build_algo graph_build_algo,
+        size_t nn_descent_niter,
+        faiss::MetricType metric,
+        float metricArg,
+        IndicesOptions indicesOptions,
+        std::optional<raft::neighbors::ivf_pq::index_params> ivf_pq_params,
+        std::optional<raft::neighbors::ivf_pq::search_params>
+                ivf_pq_search_params)
+        : resources_(resources),
+          dim_(dim),
+          metric_(metric),
+          metricArg_(metricArg),
+          index_params_(),
+          ivf_pq_params_(ivf_pq_params),
+          ivf_pq_search_params_(ivf_pq_search_params) {
+    FAISS_THROW_IF_NOT_MSG(
+            metric == faiss::METRIC_L2 ||
+                    metric == faiss::METRIC_INNER_PRODUCT,
+            "CAGRA currently only supports L2 or Inner Product metric.");
+    FAISS_THROW_IF_NOT_MSG(
+            indicesOptions == faiss::gpu::INDICES_64_BIT,
+            "only INDICES_64_BIT is supported for RAFT CAGRA index");
+
+    index_params_.intermediate_graph_degree = intermediate_graph_degree;
+    index_params_.graph_degree = graph_degree;
+    index_params_.build_algo =
+            static_cast<raft::neighbors::cagra::graph_build_algo>(
+                    graph_build_algo);
+    index_params_.nn_descent_niter = nn_descent_niter;
+
+    if (!ivf_pq_params_) {
+        ivf_pq_params_ =
+                std::make_optional<raft::neighbors::ivf_pq::index_params>();
+    }
+    if (!ivf_pq_search_params_) {
+        ivf_pq_search_params_ =
+                std::make_optional<raft::neighbors::ivf_pq::search_params>();
+    }
+    index_params_.metric = metric_ == faiss::METRIC_L2
+            ? raft::distance::DistanceType::L2Expanded
+            : raft::distance::DistanceType::InnerProduct;
+    ivf_pq_params_->metric = metric_ == faiss::METRIC_L2
+            ? raft::distance::DistanceType::L2Expanded
+            : raft::distance::DistanceType::InnerProduct;
+
+    reset();
+}
+
+RaftCagra::RaftCagra(
+        GpuResources* resources,
+        int dim,
+        idx_t n,
+        int graph_degree,
+        const float* distances,
+        const idx_t* knn_graph,
+        faiss::MetricType metric,
+        float metricArg,
+        IndicesOptions indicesOptions)
+        : resources_(resources),
+          dim_(dim),
+          metric_(metric),
+          metricArg_(metricArg) {
+    FAISS_THROW_IF_NOT_MSG(
+            metric == faiss::METRIC_L2 ||
+                    metric == faiss::METRIC_INNER_PRODUCT,
+            "CAGRA currently only supports L2 or Inner Product metric.");
+    FAISS_THROW_IF_NOT_MSG(
+            indicesOptions == faiss::gpu::INDICES_64_BIT,
+            "only INDICES_64_BIT is supported for RAFT CAGRA index");
+
+    auto distances_on_gpu = getDeviceForAddress(distances) >= 0;
+    auto knn_graph_on_gpu = getDeviceForAddress(knn_graph) >= 0;
+
+    FAISS_ASSERT(distances_on_gpu == knn_graph_on_gpu);
+
+    const raft::device_resources& raft_handle =
+            resources_->getRaftHandleCurrentDevice();
+
+    if (distances_on_gpu && knn_graph_on_gpu) {
+        raft_handle.sync_stream();
+        // Copying to host so that raft::neighbors::cagra::index
+        // creates an owning copy of the knn graph on device
+        auto knn_graph_copy =
+                raft::make_host_matrix<uint32_t, int64_t>(n, graph_degree);
+        thrust::copy(
+                thrust::device_ptr<const idx_t>(knn_graph),
+                thrust::device_ptr<const idx_t>(
+                        knn_graph + (n * graph_degree)),
+                knn_graph_copy.data_handle());
+
+        auto distances_mds =
+                raft::make_device_matrix_view<const float, int64_t>(
+                        distances, n, dim);
+
+        raft_knn_index = raft::neighbors::cagra::index<float, uint32_t>(
+                raft_handle,
+                metric_ == faiss::METRIC_L2
+                        ? raft::distance::DistanceType::L2Expanded
+                        : raft::distance::DistanceType::InnerProduct,
+                distances_mds,
+                raft::make_const_mdspan(knn_graph_copy.view()));
+    } else if (!distances_on_gpu && !knn_graph_on_gpu) {
+        // copy idx_t (int64_t) host knn_graph to uint32_t host knn_graph
+        auto knn_graph_copy =
+                raft::make_host_matrix<uint32_t, int64_t>(n, graph_degree);
+        std::copy(
+                knn_graph,
+                knn_graph + (n * graph_degree),
+                knn_graph_copy.data_handle());
+
+        auto distances_mds =
+                raft::make_host_matrix_view<const float, int64_t>(
+                        distances, n, dim);
+
+        raft_knn_index = raft::neighbors::cagra::index<float, uint32_t>(
+                raft_handle,
+                metric_ == faiss::METRIC_L2
+                        ? raft::distance::DistanceType::L2Expanded
+                        : raft::distance::DistanceType::InnerProduct,
+                distances_mds,
+                raft::make_const_mdspan(knn_graph_copy.view()));
+    } else {
+        FAISS_THROW_MSG(
+                "distances and knn_graph must both be in device or host memory");
+    }
+}
+
+void RaftCagra::train(idx_t n, const float* x) {
+    const raft::device_resources& raft_handle =
+            resources_->getRaftHandleCurrentDevice();
+    if (index_params_.build_algo ==
+        raft::neighbors::cagra::graph_build_algo::IVF_PQ) {
+        std::optional<raft::host_matrix<uint32_t, int64_t>> knn_graph(
+                raft::make_host_matrix<uint32_t, int64_t>(
+                        n, index_params_.intermediate_graph_degree));
+        if (getDeviceForAddress(x) >= 0) {
+            auto dataset_d =
+                    raft::make_device_matrix_view<const float, int64_t>(
+                            x, n, dim_);
+            raft::neighbors::cagra::build_knn_graph(
+                    raft_handle,
+                    dataset_d,
+                    knn_graph->view(),
+                    1.0f,
+                    ivf_pq_params_,
+                    ivf_pq_search_params_);
+        } else {
+            auto dataset_h =
+                    raft::make_host_matrix_view<const float, int64_t>(
+                            x, n, dim_);
+            raft::neighbors::cagra::build_knn_graph(
+                    raft_handle,
+                    dataset_h,
+                    knn_graph->view(),
+                    1.0f,
+                    ivf_pq_params_,
+                    ivf_pq_search_params_);
+        }
+        auto cagra_graph = raft::make_host_matrix<uint32_t, int64_t>(
+                n, index_params_.graph_degree);
+
+        raft::neighbors::cagra::optimize(
+                raft_handle, knn_graph->view(), cagra_graph.view());
+
+        // free intermediate graph before trying to create the index
+        knn_graph.reset();
+
+        if (getDeviceForAddress(x) >= 0) {
+            auto dataset_d =
+                    raft::make_device_matrix_view<const float, int64_t>(
+                            x, n, dim_);
+            raft_knn_index = raft::neighbors::cagra::index<float, uint32_t>(
+                    raft_handle,
+                    metric_ == faiss::METRIC_L2
+                            ? raft::distance::DistanceType::L2Expanded
+                            : raft::distance::DistanceType::InnerProduct,
+                    dataset_d,
+                    raft::make_const_mdspan(cagra_graph.view()));
+        } else {
+            auto dataset_h =
+                    raft::make_host_matrix_view<const float, int64_t>(
+                            x, n, dim_);
+            raft_knn_index = raft::neighbors::cagra::index<float, uint32_t>(
+                    raft_handle,
+                    metric_ == faiss::METRIC_L2
+                            ? raft::distance::DistanceType::L2Expanded
+                            : raft::distance::DistanceType::InnerProduct,
+                    dataset_h,
+                    raft::make_const_mdspan(cagra_graph.view()));
+        }
+
+    } else {
+        if (getDeviceForAddress(x) >= 0) {
+            raft_knn_index = raft::runtime::neighbors::cagra::build(
+                    raft_handle,
+                    index_params_,
+                    raft::make_device_matrix_view<const float, int64_t>(
+                            x, n, dim_));
+        } else {
+            raft_knn_index = raft::runtime::neighbors::cagra::build(
+                    raft_handle,
+                    index_params_,
+                    raft::make_host_matrix_view<const float, int64_t>(
+                            x, n, dim_));
+        }
+    }
+}
+
+void RaftCagra::search(
+        Tensor<float, 2, true>& queries,
+        int k,
+        Tensor<float, 2, true>& outDistances,
+        Tensor<idx_t, 2, true>& outIndices,
+        idx_t max_queries,
+        idx_t itopk_size,
+        idx_t max_iterations,
+        faiss::cagra_search_algo graph_search_algo,
+        idx_t team_size,
+        idx_t search_width,
+        idx_t min_iterations,
+        idx_t thread_block_size,
+        faiss::cagra_hash_mode hash_mode,
+        idx_t hashmap_min_bitlen,
+        float hashmap_max_fill_rate,
+        idx_t num_random_samplings,
+        idx_t rand_xor_mask) {
+    const raft::device_resources& raft_handle =
+            resources_->getRaftHandleCurrentDevice();
+    idx_t numQueries = queries.getSize(0);
+    idx_t cols = queries.getSize(1);
+    idx_t k_ = k;
+
+    FAISS_ASSERT(raft_knn_index.has_value());
+    FAISS_ASSERT(numQueries > 0);
+    FAISS_ASSERT(cols == dim_);
+
+    auto queries_view = raft::make_device_matrix_view<const float, int64_t>(
+            queries.data(), numQueries, cols);
+    auto distances_view = raft::make_device_matrix_view<float, int64_t>(
+            outDistances.data(), numQueries, k_);
+    auto indices_view = raft::make_device_matrix_view<idx_t, int64_t>(
+            outIndices.data(), numQueries, k_);
+
+    raft::neighbors::cagra::search_params search_pams;
+    search_pams.max_queries = max_queries;
+    search_pams.itopk_size = itopk_size;
+    search_pams.max_iterations = max_iterations;
+    search_pams.algo =
+            static_cast<raft::neighbors::cagra::search_algo>(
+                    graph_search_algo);
+    search_pams.team_size = team_size;
+    search_pams.search_width = search_width;
+    search_pams.min_iterations = min_iterations;
+    search_pams.thread_block_size = thread_block_size;
+    search_pams.hashmap_mode =
+            static_cast<raft::neighbors::cagra::hash_mode>(hash_mode);
+    search_pams.hashmap_min_bitlen = hashmap_min_bitlen;
+    search_pams.hashmap_max_fill_rate = hashmap_max_fill_rate;
+    search_pams.num_random_samplings = num_random_samplings;
+    search_pams.rand_xor_mask = rand_xor_mask;
+
+    auto indices_copy = raft::make_device_matrix<uint32_t, int64_t>(
+            raft_handle, numQueries, k_);
+
+    raft::runtime::neighbors::cagra::search(
+            raft_handle,
+            search_pams,
+            raft_knn_index.value(),
+            queries_view,
+            indices_copy.view(),
+            distances_view);
+    thrust::copy(
+            raft::resource::get_thrust_policy(raft_handle),
+            indices_copy.data_handle(),
+            indices_copy.data_handle() + indices_copy.size(),
+            indices_view.data_handle());
+}
+
+void RaftCagra::reset() {
+    raft_knn_index.reset();
+}
+
+idx_t RaftCagra::get_knngraph_degree() const {
+    FAISS_ASSERT(raft_knn_index.has_value());
+    return static_cast<idx_t>(raft_knn_index.value().graph_degree());
+}
+
+std::vector<idx_t> RaftCagra::get_knngraph() const {
+    FAISS_ASSERT(raft_knn_index.has_value());
+    const raft::device_resources& raft_handle =
+            resources_->getRaftHandleCurrentDevice();
+    auto stream = raft_handle.get_stream();
+
+    auto device_graph = raft_knn_index.value().graph();
+
+    std::vector<idx_t> host_graph(
+            device_graph.extent(0) * device_graph.extent(1));
+
+    raft_handle.sync_stream();
+
+    thrust::copy(
+            thrust::device_ptr<const uint32_t>(device_graph.data_handle()),
+            thrust::device_ptr<const uint32_t>(
+                    device_graph.data_handle() + device_graph.size()),
+            host_graph.data());
+
+    return host_graph;
+}
+
+std::vector<float> RaftCagra::get_training_dataset() const {
+    FAISS_ASSERT(raft_knn_index.has_value());
+    const raft::device_resources& raft_handle =
+            resources_->getRaftHandleCurrentDevice();
+    auto stream = raft_handle.get_stream();
+
+    auto device_dataset = raft_knn_index.value().dataset();
+
+    std::vector<float> host_dataset(
+            device_dataset.extent(0) * device_dataset.extent(1));
+
+    RAFT_CUDA_TRY(cudaMemcpy2DAsync(
+            host_dataset.data(),
+            sizeof(float) * dim_,
+            device_dataset.data_handle(),
+            sizeof(float) * device_dataset.stride(0),
+            sizeof(float) * dim_,
+            device_dataset.extent(0),
+            cudaMemcpyDefault,
+            raft_handle.get_stream()));
+    raft_handle.sync_stream();
+
+    return host_dataset;
+}
+
+} // namespace gpu
+} // namespace faiss
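`get_training_dataset` copies a pitched (row-padded) device matrix into a dense host vector, which is why it uses a 2-D strided copy rather than a flat `cudaMemcpy`. The same pattern in isolation (a hedged sketch; `device_ptr`, `stream`, and the sizes are assumed):

```cpp
// Copy an n x dim row-major matrix whose device rows are padded to
// `stride` elements into a dense host buffer (hypothetical values).
size_t n = 1000, dim = 64, stride = 80;
std::vector<float> host(n * dim);
cudaMemcpy2DAsync(
        host.data(), sizeof(float) * dim,   // dst, dst pitch (dense rows)
        device_ptr, sizeof(float) * stride, // src, src pitch (padded rows)
        sizeof(float) * dim,                // width: bytes copied per row
        n,                                  // height: number of rows
        cudaMemcpyDefault,
        stream);
cudaStreamSynchronize(stream);
```
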
diff --git a/faiss/gpu/impl/RaftCagra.cuh b/faiss/gpu/impl/RaftCagra.cuh
new file mode 100644
index 0000000000..95f6c03fca
--- /dev/null
+++ b/faiss/gpu/impl/RaftCagra.cuh
@@ -0,0 +1,132 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+
+namespace faiss {
+
+/// Algorithm used to build underlying CAGRA graph
+enum class cagra_build_algo { IVF_PQ, NN_DESCENT };
+
+enum class cagra_search_algo { SINGLE_CTA, MULTI_CTA };
+
+enum class cagra_hash_mode { HASH, SMALL, AUTO };
+
+namespace gpu {
+
+class RaftCagra {
+   public:
+    RaftCagra(
+            GpuResources* resources,
+            int dim,
+            idx_t intermediate_graph_degree,
+            idx_t graph_degree,
+            faiss::cagra_build_algo graph_build_algo,
+            size_t nn_descent_niter,
+            faiss::MetricType metric,
+            float metricArg,
+            IndicesOptions indicesOptions,
+            std::optional<raft::neighbors::ivf_pq::index_params>
+                    ivf_pq_params = std::nullopt,
+            std::optional<raft::neighbors::ivf_pq::search_params>
+                    ivf_pq_search_params = std::nullopt);
+
+    RaftCagra(
+            GpuResources* resources,
+            int dim,
+            idx_t n,
+            int graph_degree,
+            const float* distances,
+            const idx_t* knn_graph,
+            faiss::MetricType metric,
+            float metricArg,
+            IndicesOptions indicesOptions);
+
+    ~RaftCagra() = default;
+
+    void train(idx_t n, const float* x);
+
+    void search(
+            Tensor<float, 2, true>& queries,
+            int k,
+            Tensor<float, 2, true>& outDistances,
+            Tensor<idx_t, 2, true>& outIndices,
+            idx_t max_queries,
+            idx_t itopk_size,
+            idx_t max_iterations,
+            faiss::cagra_search_algo graph_search_algo,
+            idx_t team_size,
+            idx_t search_width,
+            idx_t min_iterations,
+            idx_t thread_block_size,
+            faiss::cagra_hash_mode hash_mode,
+            idx_t hashmap_min_bitlen,
+            float hashmap_max_fill_rate,
+            idx_t num_random_samplings,
+            idx_t rand_xor_mask);
+
+    void reset();
+
+    idx_t get_knngraph_degree() const;
+
+    std::vector<idx_t> get_knngraph() const;
+
+    std::vector<float> get_training_dataset() const;
+
+   private:
+    /// Collection of GPU resources that we use
+    GpuResources* resources_;
+
+    /// Expected dimensionality of the vectors
+    const int dim_;
+
+    /// Metric type of the index
+    faiss::MetricType metric_;
+
+    /// Metric arg
+    float metricArg_;
+
+    /// Parameters to build RAFT CAGRA index
+    raft::neighbors::cagra::index_params index_params_;
+
+    /// Parameters to build CAGRA graph using IVF PQ
+    std::optional<raft::neighbors::ivf_pq::index_params> ivf_pq_params_;
+    std::optional<raft::neighbors::ivf_pq::search_params>
+            ivf_pq_search_params_;
+
+    /// Instance of trained RAFT CAGRA index
+    std::optional<raft::neighbors::cagra::index<float, uint32_t>>
+            raft_knn_index{std::nullopt};
+};
+
+} // namespace gpu
+} // namespace faiss
diff --git a/faiss/gpu/test/CMakeLists.txt b/faiss/gpu/test/CMakeLists.txt
index 9300deead9..60f78ef74f 100644
--- a/faiss/gpu/test/CMakeLists.txt
+++ b/faiss/gpu/test/CMakeLists.txt
@@ -21,7 +21,6 @@ find_package(CUDAToolkit REQUIRED)
 
 # Defines `gtest_discover_tests()`.
 include(GoogleTest)
-
 add_library(faiss_gpu_test_helper TestUtils.cpp)
 target_link_libraries(faiss_gpu_test_helper PUBLIC faiss gtest CUDA::cudart $<$<BOOL:${FAISS_ENABLE_RAFT}>:raft::raft> $<$<BOOL:${FAISS_ENABLE_RAFT}>:raft::compiled>)
@@ -42,6 +41,9 @@ faiss_gpu_test(TestGpuIndexIVFPQ.cpp)
 faiss_gpu_test(TestGpuIndexIVFScalarQuantizer.cpp)
 faiss_gpu_test(TestGpuDistance.cu)
 faiss_gpu_test(TestGpuSelect.cu)
+if(FAISS_ENABLE_RAFT)
+  faiss_gpu_test(TestGpuIndexCagra.cu)
+endif()
 
 add_executable(demo_ivfpq_indexing_gpu EXCLUDE_FROM_ALL
   demo_ivfpq_indexing_gpu.cpp)
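The tests below compare the GPU index against a CPU HNSW reference by recall@k (via `raft::stats::neighborhood_recall`) rather than by exact result equality, since both indexes are approximate. What gets asserted is, in effect (hypothetical host-side helper, not part of the patch):

```cpp
#include <cstdint>

// Fraction of reference neighbors recovered by the test results.
double recall_at_k(const int64_t* ref, const int64_t* test, int nq, int k) {
    int hits = 0;
    for (int q = 0; q < nq; ++q) {
        for (int i = 0; i < k; ++i) {
            for (int j = 0; j < k; ++j) {
                if (test[q * k + i] == ref[q * k + j]) {
                    hits++;
                    break;
                }
            }
        }
    }
    return double(hits) / (double(nq) * k);
}
```
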
diff --git a/faiss/gpu/test/TestGpuIndexCagra.cu b/faiss/gpu/test/TestGpuIndexCagra.cu
new file mode 100644
index 0000000000..8d330a81cb
--- /dev/null
+++ b/faiss/gpu/test/TestGpuIndexCagra.cu
@@ -0,0 +1,474 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+struct Options {
+    Options() {
+        numTrain = 2 * faiss::gpu::randVal(2000, 5000);
+        dim = faiss::gpu::randVal(4, 10);
+        numAdd = faiss::gpu::randVal(1000, 3000);
+
+        graphDegree = faiss::gpu::randSelect({32, 64});
+        intermediateGraphDegree = faiss::gpu::randSelect({64, 98});
+        buildAlgo = faiss::gpu::randSelect(
+                {faiss::gpu::graph_build_algo::IVF_PQ,
+                 faiss::gpu::graph_build_algo::NN_DESCENT});
+
+        numQuery = faiss::gpu::randVal(32, 100);
+        k = faiss::gpu::randVal(10, 30);
+
+        device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
+    }
+
+    std::string toString() const {
+        std::stringstream str;
+        str << "CAGRA device " << device << " numVecs " << numTrain << " dim "
+            << dim << " graphDegree " << graphDegree
+            << " intermediateGraphDegree " << intermediateGraphDegree
+            << " buildAlgo " << static_cast<int>(buildAlgo) << " numQuery "
+            << numQuery << " k " << k;
+
+        return str.str();
+    }
+
+    int numTrain;
+    int numAdd;
+    int dim;
+    size_t graphDegree;
+    size_t intermediateGraphDegree;
+    faiss::gpu::graph_build_algo buildAlgo;
+    int numQuery;
+    int k;
+    int device;
+};
+
+void queryTest(faiss::MetricType metric, double expected_recall) {
+    for (int tries = 0; tries < 5; ++tries) {
+        Options opt;
+        if (opt.buildAlgo == faiss::gpu::graph_build_algo::NN_DESCENT &&
+            metric == faiss::METRIC_INNER_PRODUCT) {
+            continue;
+        }
+
+        std::vector<float> trainVecs =
+                faiss::gpu::randVecs(opt.numTrain, opt.dim);
+        if (metric == faiss::METRIC_INNER_PRODUCT) {
+            faiss::fvec_renorm_L2(opt.numTrain, opt.dim, trainVecs.data());
+        }
+
+        // train cpu index
+        faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2, metric);
+        cpuIndex.hnsw.efConstruction = opt.k * 2;
+        cpuIndex.add(opt.numTrain, trainVecs.data());
+
+        // train gpu index
+        faiss::gpu::StandardGpuResources res;
+        res.noTempMemory();
+
+        faiss::gpu::GpuIndexCagraConfig config;
+        config.device = opt.device;
+        config.graph_degree = opt.graphDegree;
+        config.intermediate_graph_degree = opt.intermediateGraphDegree;
+        config.build_algo = opt.buildAlgo;
+
+        faiss::gpu::GpuIndexCagra gpuIndex(&res, cpuIndex.d, metric, config);
+        gpuIndex.train(opt.numTrain, trainVecs.data());
+
+        // query
+        auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim);
+        if (metric == faiss::METRIC_INNER_PRODUCT) {
+            faiss::fvec_renorm_L2(opt.numQuery, opt.dim, queryVecs.data());
+        }
+
+        std::vector<float> refDistance(opt.numQuery * opt.k, 0);
+        std::vector<faiss::idx_t> refIndices(opt.numQuery * opt.k, -1);
+        faiss::SearchParametersHNSW cpuSearchParams;
+        cpuSearchParams.efSearch = opt.k * 2;
+        cpuIndex.search(
+                opt.numQuery,
+                queryVecs.data(),
+                opt.k,
+                refDistance.data(),
+                refIndices.data(),
+                &cpuSearchParams);
+
+        // test quality of searches
+        auto gpuRes = res.getResources();
+        auto devAlloc = faiss::gpu::makeDevAlloc(
+                faiss::gpu::AllocType::FlatData,
+                gpuRes->getDefaultStreamCurrentDevice());
+        faiss::gpu::DeviceTensor<float, 2, true> testDistance(
+                gpuRes.get(), devAlloc, {opt.numQuery, opt.k});
+        faiss::gpu::DeviceTensor<faiss::idx_t, 2, true> testIndices(
+                gpuRes.get(), devAlloc, {opt.numQuery, opt.k});
+        gpuIndex.search(
+                opt.numQuery,
+                queryVecs.data(),
+                opt.k,
+                testDistance.data(),
+                testIndices.data());
+
+        auto refDistanceDev = faiss::gpu::toDeviceTemporary(
+                gpuRes.get(),
+                refDistance,
+                gpuRes->getDefaultStreamCurrentDevice());
+        auto refIndicesDev = faiss::gpu::toDeviceTemporary(
+                gpuRes.get(),
+                refIndices,
+                gpuRes->getDefaultStreamCurrentDevice());
+
+        auto raft_handle = gpuRes->getRaftHandleCurrentDevice();
+
+        auto ref_dis_mds = raft::make_device_matrix_view<const float, int>(
+                refDistanceDev.data(), opt.numQuery, opt.k);
+        auto ref_dis_mds_opt =
+                std::optional<raft::device_matrix_view<const float, int>>(
+                        ref_dis_mds);
+        auto ref_ind_mds =
+                raft::make_device_matrix_view<const faiss::idx_t, int>(
+                        refIndicesDev.data(), opt.numQuery, opt.k);
+
+        auto test_dis_mds = raft::make_device_matrix_view<const float, int>(
+                testDistance.data(), opt.numQuery, opt.k);
+        auto test_dis_mds_opt =
+                std::optional<raft::device_matrix_view<const float, int>>(
+                        test_dis_mds);
+
+        auto test_ind_mds =
+                raft::make_device_matrix_view<const faiss::idx_t, int>(
+                        testIndices.data(), opt.numQuery, opt.k);
+
+        double scalar_init = 0;
+        auto recall_score = raft::make_host_scalar(scalar_init);
+
+        raft::stats::neighborhood_recall(
+                raft_handle,
+                test_ind_mds,
+                ref_ind_mds,
+                recall_score.view(),
+                test_dis_mds_opt,
+                ref_dis_mds_opt);
+        ASSERT_TRUE(*recall_score.data_handle() > expected_recall);
+    }
+}
+
+TEST(TestGpuIndexCagra, Float32_Query_L2) {
+    queryTest(faiss::METRIC_L2, 0.98);
+}
+
+TEST(TestGpuIndexCagra, Float32_Query_IP) {
+    queryTest(faiss::METRIC_INNER_PRODUCT, 0.98);
+}
+
+void copyToTest(
+        faiss::MetricType metric,
+        double expected_recall,
+        bool base_level_only) {
+    for (int tries = 0; tries < 5; ++tries) {
+        Options opt;
+        if (opt.buildAlgo == faiss::gpu::graph_build_algo::NN_DESCENT &&
+            metric == faiss::METRIC_INNER_PRODUCT) {
+            continue;
+        }
+
+        std::vector<float> trainVecs =
+                faiss::gpu::randVecs(opt.numTrain, opt.dim);
+        if (metric == faiss::METRIC_INNER_PRODUCT) {
+            faiss::fvec_renorm_L2(opt.numTrain, opt.dim, trainVecs.data());
+        }
+        std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);
+        if (metric == faiss::METRIC_INNER_PRODUCT) {
+            faiss::fvec_renorm_L2(opt.numAdd, opt.dim, addVecs.data());
+        }
+
+        faiss::gpu::StandardGpuResources res;
+        res.noTempMemory();
+
+        // train gpu index and copy to cpu index
+        faiss::gpu::GpuIndexCagraConfig config;
+        config.device = opt.device;
+        config.graph_degree = opt.graphDegree;
+        config.intermediate_graph_degree = opt.intermediateGraphDegree;
+        config.build_algo = opt.buildAlgo;
+
+        faiss::gpu::GpuIndexCagra gpuIndex(&res, opt.dim, metric, config);
+        gpuIndex.train(opt.numTrain, trainVecs.data());
+
+        faiss::IndexHNSWCagra copiedCpuIndex(
+                opt.dim, opt.graphDegree / 2, metric);
+        copiedCpuIndex.base_level_only = base_level_only;
+        gpuIndex.copyTo(&copiedCpuIndex);
+        copiedCpuIndex.hnsw.efConstruction = opt.k * 2;
+
+        // add more vecs to copied cpu index
+        if (!base_level_only) {
+            copiedCpuIndex.add(opt.numAdd, addVecs.data());
+        }
+
+        // train cpu index
+        faiss::IndexHNSWFlat cpuIndex(opt.dim, opt.graphDegree / 2, metric);
+        cpuIndex.hnsw.efConstruction = opt.k * 2;
+        cpuIndex.add(opt.numTrain, trainVecs.data());
+
+        // add more vecs to cpu index
+        if (!base_level_only) {
+            cpuIndex.add(opt.numAdd, addVecs.data());
+        }
+
+        // query indexes
+        auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim);
+        if (metric == faiss::METRIC_INNER_PRODUCT) {
+            faiss::fvec_renorm_L2(opt.numQuery, opt.dim, queryVecs.data());
+        }
+
+        std::vector<float> refDistance(opt.numQuery * opt.k, 0);
+        std::vector<faiss::idx_t> refIndices(opt.numQuery * opt.k, -1);
+        faiss::SearchParametersHNSW cpuSearchParams;
+        cpuSearchParams.efSearch = opt.k * 2;
+        cpuIndex.search(
+                opt.numQuery,
+                queryVecs.data(),
+                opt.k,
+                refDistance.data(),
+                refIndices.data(),
+                &cpuSearchParams);
+
+        std::vector<float> copyRefDistance(opt.numQuery * opt.k, 0);
+        std::vector<faiss::idx_t> copyRefIndices(opt.numQuery * opt.k, -1);
+        faiss::SearchParametersHNSW cpuSearchParamstwo;
+        cpuSearchParamstwo.efSearch = opt.k * 2;
+        copiedCpuIndex.search(
+                opt.numQuery,
+                queryVecs.data(),
+                opt.k,
+                copyRefDistance.data(),
+                copyRefIndices.data(),
+                &cpuSearchParamstwo);
+
+        // test quality of search
+        auto gpuRes = res.getResources();
+
+        auto refDistanceDev = faiss::gpu::toDeviceTemporary(
+                gpuRes.get(),
+                refDistance,
+                gpuRes->getDefaultStreamCurrentDevice());
+        auto refIndicesDev = faiss::gpu::toDeviceTemporary(
+                gpuRes.get(),
+                refIndices,
+                gpuRes->getDefaultStreamCurrentDevice());
+
+        auto copyRefDistanceDev = faiss::gpu::toDeviceTemporary(
+                gpuRes.get(),
+                copyRefDistance,
+                gpuRes->getDefaultStreamCurrentDevice());
+        auto copyRefIndicesDev = faiss::gpu::toDeviceTemporary(
+                gpuRes.get(),
+                copyRefIndices,
+                gpuRes->getDefaultStreamCurrentDevice());
+
+        auto raft_handle = gpuRes->getRaftHandleCurrentDevice();
+
+        auto ref_dis_mds = raft::make_device_matrix_view<const float, int>(
+                refDistanceDev.data(), opt.numQuery, opt.k);
+        auto ref_dis_mds_opt =
+                std::optional<raft::device_matrix_view<const float, int>>(
+                        ref_dis_mds);
+        auto ref_ind_mds =
+                raft::make_device_matrix_view<const faiss::idx_t, int>(
+                        refIndicesDev.data(), opt.numQuery, opt.k);
+
+        auto copy_ref_dis_mds =
+                raft::make_device_matrix_view<const float, int>(
+                        copyRefDistanceDev.data(), opt.numQuery, opt.k);
+        auto copy_ref_dis_mds_opt =
+                std::optional<raft::device_matrix_view<const float, int>>(
+                        copy_ref_dis_mds);
+        auto copy_ref_ind_mds =
+                raft::make_device_matrix_view<const faiss::idx_t, int>(
+                        copyRefIndicesDev.data(), opt.numQuery, opt.k);
+
+        double scalar_init = 0;
+        auto recall_score = raft::make_host_scalar(scalar_init);
+
+        raft::stats::neighborhood_recall(
+                raft_handle,
+                copy_ref_ind_mds,
+                ref_ind_mds,
+                recall_score.view(),
+                copy_ref_dis_mds_opt,
+                ref_dis_mds_opt);
+        ASSERT_TRUE(*recall_score.data_handle() > expected_recall);
+    }
+}
+
+TEST(TestGpuIndexCagra, Float32_CopyTo_L2) {
+    copyToTest(faiss::METRIC_L2, 0.98, false);
+}
+
+TEST(TestGpuIndexCagra, Float32_CopyTo_L2_BaseLevelOnly) {
+    copyToTest(faiss::METRIC_L2, 0.98, true);
+}
+
+TEST(TestGpuIndexCagra, Float32_CopyTo_IP) {
+    copyToTest(faiss::METRIC_INNER_PRODUCT, 0.98, false);
+}
+
+TEST(TestGpuIndexCagra, Float32_CopyTo_IP_BaseLevelOnly) {
+    copyToTest(faiss::METRIC_INNER_PRODUCT, 0.98, true);
+}
+
+void copyFromTest(faiss::MetricType metric, double expected_recall) {
+    for (int tries = 0; tries < 5; ++tries) {
+        Options opt;
+        if (opt.buildAlgo == faiss::gpu::graph_build_algo::NN_DESCENT &&
+            metric == faiss::METRIC_INNER_PRODUCT) {
+            continue;
+        }
+
+        std::vector<float> trainVecs =
+                faiss::gpu::randVecs(opt.numTrain, opt.dim);
+        if (metric == faiss::METRIC_INNER_PRODUCT) {
+            faiss::fvec_renorm_L2(opt.numTrain, opt.dim, trainVecs.data());
+        }
+
+        // train cpu index
+        faiss::IndexHNSWCagra cpuIndex(opt.dim, opt.graphDegree / 2, metric);
+        cpuIndex.hnsw.efConstruction = opt.k * 2;
+        cpuIndex.add(opt.numTrain, trainVecs.data());
+
+        faiss::gpu::StandardGpuResources res;
+        res.noTempMemory();
+
+        // convert to gpu index
+        faiss::gpu::GpuIndexCagra copiedGpuIndex(&res, cpuIndex.d, metric);
+        copiedGpuIndex.copyFrom(&cpuIndex);
+
+        // train gpu index
+        faiss::gpu::GpuIndexCagraConfig config;
+        config.device = opt.device;
+        config.graph_degree = opt.graphDegree;
+        config.intermediate_graph_degree = opt.intermediateGraphDegree;
+        config.build_algo = opt.buildAlgo;
+
+        faiss::gpu::GpuIndexCagra gpuIndex(&res, opt.dim, metric, config);
+        gpuIndex.train(opt.numTrain, trainVecs.data());
+
+        // query
+        auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim);
+        if (metric == faiss::METRIC_INNER_PRODUCT) {
+            faiss::fvec_renorm_L2(opt.numQuery, opt.dim, queryVecs.data());
+        }
+
+        auto gpuRes = res.getResources();
+        auto devAlloc = faiss::gpu::makeDevAlloc(
+                faiss::gpu::AllocType::FlatData,
+                gpuRes->getDefaultStreamCurrentDevice());
+        faiss::gpu::DeviceTensor<float, 2, true> copyTestDistance(
+                gpuRes.get(), devAlloc, {opt.numQuery, opt.k});
+        faiss::gpu::DeviceTensor<faiss::idx_t, 2, true> copyTestIndices(
+                gpuRes.get(), devAlloc, {opt.numQuery, opt.k});
+        copiedGpuIndex.search(
+                opt.numQuery,
+                queryVecs.data(),
+                opt.k,
+                copyTestDistance.data(),
+                copyTestIndices.data());
+
+        faiss::gpu::DeviceTensor<float, 2, true> testDistance(
+                gpuRes.get(), devAlloc, {opt.numQuery, opt.k});
+        faiss::gpu::DeviceTensor<faiss::idx_t, 2, true> testIndices(
+                gpuRes.get(), devAlloc, {opt.numQuery, opt.k});
+        gpuIndex.search(
+                opt.numQuery,
+                queryVecs.data(),
+                opt.k,
+                testDistance.data(),
+                testIndices.data());
+
+        // test quality of searches
+        auto raft_handle = gpuRes->getRaftHandleCurrentDevice();
+
+        auto test_dis_mds = raft::make_device_matrix_view<const float, int>(
+                testDistance.data(), opt.numQuery, opt.k);
+        auto test_dis_mds_opt =
+                std::optional<raft::device_matrix_view<const float, int>>(
+                        test_dis_mds);
+
+        auto test_ind_mds =
+                raft::make_device_matrix_view<const faiss::idx_t, int>(
+                        testIndices.data(), opt.numQuery, opt.k);
+
+        auto copy_test_dis_mds =
+                raft::make_device_matrix_view<const float, int>(
+                        copyTestDistance.data(), opt.numQuery, opt.k);
+        auto copy_test_dis_mds_opt =
+                std::optional<raft::device_matrix_view<const float, int>>(
+                        copy_test_dis_mds);
+
+        auto copy_test_ind_mds =
+                raft::make_device_matrix_view<const faiss::idx_t, int>(
+                        copyTestIndices.data(), opt.numQuery, opt.k);
+
+        double scalar_init = 0;
+        auto recall_score = raft::make_host_scalar(scalar_init);
+
+        raft::stats::neighborhood_recall(
+                raft_handle,
+                copy_test_ind_mds,
+                test_ind_mds,
+                recall_score.view(),
+                copy_test_dis_mds_opt,
+                test_dis_mds_opt);
+        ASSERT_TRUE(*recall_score.data_handle() > expected_recall);
+    }
+}
+
+TEST(TestGpuIndexCagra, Float32_CopyFrom_L2) {
+    copyFromTest(faiss::METRIC_L2, 0.98);
+}
+
+TEST(TestGpuIndexCagra, Float32_CopyFrom_IP) {
+    copyFromTest(faiss::METRIC_INNER_PRODUCT, 0.98);
+}
+
+int main(int argc, char** argv) {
+    testing::InitGoogleTest(&argc, argv);
+
+    // just run with a fixed test seed
+    faiss::gpu::setTestSeed(100);
+
+    return RUN_ALL_TESTS();
+}
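The Python test below also round-trips the index through serialization. The C++ equivalent of that step would be (hedged sketch; `cpu_index` is the `IndexHNSWCagra*` produced by `index_gpu_to_cpu`):

```cpp
faiss::write_index(cpu_index, "index_hnsw_cagra.index");
std::unique_ptr<faiss::Index> loaded(
        faiss::read_index("index_hnsw_cagra.index"));
// `loaded` can go back to the GPU via index_cpu_to_gpu, which relies on the
// "IHNc" fourcc handling added in index_read.cpp / index_write.cpp below.
```
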
diff --git a/faiss/gpu/test/test_cagra.py b/faiss/gpu/test/test_cagra.py
new file mode 100644
index 0000000000..dd7d09f2de
--- /dev/null
+++ b/faiss/gpu/test/test_cagra.py
@@ -0,0 +1,71 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import faiss
+import numpy as np
+
+from faiss.contrib import datasets, evaluation
+
+
+@unittest.skipIf(
+    "RAFT" not in faiss.get_compile_options(),
+    "only if RAFT is compiled in")
+class TestComputeGT(unittest.TestCase):
+
+    def do_compute_GT(self, metric):
+        d = 64
+        k = 12
+        ds = datasets.SyntheticDataset(d, 0, 10000, 100)
+        Dref, Iref = faiss.knn(ds.get_queries(), ds.get_database(), k, metric)
+
+        res = faiss.StandardGpuResources()
+
+        index = faiss.GpuIndexCagra(res, d, metric)
+        index.train(ds.get_database())
+        Dnew, Inew = index.search(ds.get_queries(), k)
+
+        evaluation.check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, k)
+
+    def test_compute_GT_L2(self):
+        self.do_compute_GT(faiss.METRIC_L2)
+
+    def test_compute_GT_IP(self):
+        self.do_compute_GT(faiss.METRIC_INNER_PRODUCT)
+
+
+@unittest.skipIf(
+    "RAFT" not in faiss.get_compile_options(),
+    "only if RAFT is compiled in")
+class TestInterop(unittest.TestCase):
+
+    def do_interop(self, metric):
+        d = 64
+        k = 12
+        ds = datasets.SyntheticDataset(d, 0, 10000, 100)
+
+        res = faiss.StandardGpuResources()
+
+        index = faiss.GpuIndexCagra(res, d, metric)
+        index.train(ds.get_database())
+        Dnew, Inew = index.search(ds.get_queries(), k)
+
+        cpu_index = faiss.index_gpu_to_cpu(index)
+        Dref, Iref = cpu_index.search(ds.get_queries(), k)
+
+        evaluation.check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, k)
+
+        faiss.write_index(cpu_index, "index_hnsw_cagra.index")
+        deserialized_index = faiss.read_index("index_hnsw_cagra.index")
+        gpu_index = faiss.index_cpu_to_gpu(res, 0, deserialized_index)
+        Dnew2, Inew2 = gpu_index.search(ds.get_queries(), k)
+
+        evaluation.check_ref_knn_with_draws(Dnew2, Inew2, Dnew, Inew, k)
+
+    def test_interop_L2(self):
+        self.do_interop(faiss.METRIC_L2)
+
+    def test_interop_IP(self):
+        self.do_interop(faiss.METRIC_INNER_PRODUCT)
diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp
index d8c8225968..3ba5f72f68 100644
--- a/faiss/impl/HNSW.cpp
+++ b/faiss/impl/HNSW.cpp
@@ -7,6 +7,7 @@
 #include
+#include
 #include
 #include
@@ -215,8 +216,8 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
         if (pt_level > max_level)
             max_level = pt_level;
         offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
-        neighbors.resize(offsets.back(), -1);
     }
+    neighbors.resize(offsets.back(), -1);
 
     return max_level;
 }
@@ -229,7 +230,14 @@ void HNSW::shrink_neighbor_list(
         DistanceComputer& qdis,
         std::priority_queue<NodeDistFarther>& input,
        std::vector<NodeDistFarther>& output,
-        int max_size) {
+        int max_size,
+        bool keep_max_size_level0) {
+    // This prevents the number of neighbors at
+    // level 0 from being shrunk to less than 2 * M.
@@ -268,7 +283,8 @@ using NodeDistFarther = HNSW::NodeDistFarther;
 void shrink_neighbor_list(
         DistanceComputer& qdis,
         std::priority_queue<NodeDistCloser>& resultSet1,
-        int max_size) {
+        int max_size,
+        bool keep_max_size_level0 = false) {
     if (resultSet1.size() < max_size) {
         return;
     }
@@ -280,7 +296,8 @@
         resultSet1.pop();
     }
 
-    HNSW::shrink_neighbor_list(qdis, resultSet, returnlist, max_size);
+    HNSW::shrink_neighbor_list(
+            qdis, resultSet, returnlist, max_size, keep_max_size_level0);
 
     for (NodeDistFarther curen2 : returnlist) {
         resultSet1.emplace(curen2.d, curen2.id);
@@ -294,7 +311,8 @@ void add_link(
         DistanceComputer& qdis,
         storage_idx_t src,
         storage_idx_t dest,
-        int level) {
+        int level,
+        bool keep_max_size_level0 = false) {
     size_t begin, end;
     hnsw.neighbor_range(src, level, &begin, &end);
     if (hnsw.neighbors[end - 1] == -1) {
@@ -319,7 +337,7 @@ void add_link(
         resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
     }
 
-    shrink_neighbor_list(qdis, resultSet, end - begin);
+    shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
 
     // ...and back
     size_t i = begin;
@@ -429,7 +447,8 @@ void HNSW::add_links_starting_from(
         float d_nearest,
         int level,
         omp_lock_t* locks,
-        VisitedTable& vt) {
+        VisitedTable& vt,
+        bool keep_max_size_level0) {
     std::priority_queue<NodeDistCloser> link_targets;
 
     search_neighbors_to_add(
@@ -438,13 +457,13 @@
     // but we can afford only this many neighbors
     int M = nb_neighbors(level);
 
-    ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
+    ::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
 
     std::vector<storage_idx_t> neighbors;
     neighbors.reserve(link_targets.size());
     while (!link_targets.empty()) {
         storage_idx_t other_id = link_targets.top().id;
-        add_link(*this, ptdis, pt_id, other_id, level);
+        add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
         neighbors.push_back(other_id);
         link_targets.pop();
     }
@@ -452,7 +471,7 @@
     omp_unset_lock(&locks[pt_id]);
     for (storage_idx_t other_id : neighbors) {
         omp_set_lock(&locks[other_id]);
-        add_link(*this, ptdis, other_id, pt_id, level);
+        add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
         omp_unset_lock(&locks[other_id]);
     }
     omp_set_lock(&locks[pt_id]);
@@ -467,7 +486,8 @@ void HNSW::add_with_locks(
         int pt_level,
         int pt_id,
         std::vector<omp_lock_t>& locks,
-        VisitedTable& vt) {
+        VisitedTable& vt,
+        bool keep_max_size_level0) {
     // greedy search on upper levels
 
     storage_idx_t nearest;
@@ -496,7 +516,14 @@ void HNSW::add_with_locks(
 
     for (; level >= 0; level--) {
         add_links_starting_from(
-                ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt);
+                ptdis,
+                pt_id,
+                nearest,
+                d_nearest,
+                level,
+                locks.data(),
+                vt,
+                keep_max_size_level0);
     }
 
     omp_unset_lock(&locks[pt_id]);
@@ -910,9 +937,12 @@ void HNSW::search_level_0(
         const float* nearest_d,
         int search_type,
         HNSWStats& search_stats,
-        VisitedTable& vt) const {
+        VisitedTable& vt,
+        const SearchParametersHNSW* params) const {
     const HNSW& hnsw = *this;
+    auto efSearch = params ? params->efSearch : hnsw.efSearch;
     int k = extract_k_from_ResultHandler(res);
+
     if (search_type == 1) {
         int nres = 0;
@@ -925,16 +955,24 @@
             if (vt.get(cj))
                 continue;
 
-            int candidates_size = std::max(hnsw.efSearch, k);
+            int candidates_size = std::max(efSearch, k);
             MinimaxHeap candidates(candidates_size);
 
             candidates.push(cj, nearest_d[j]);
 
             nres = search_from_candidates(
-                    hnsw, qdis, res, candidates, vt, search_stats, 0, nres);
+                    hnsw,
+                    qdis,
+                    res,
+                    candidates,
+                    vt,
+                    search_stats,
+                    0,
+                    nres,
+                    params);
         }
     } else if (search_type == 2) {
-        int candidates_size = std::max(hnsw.efSearch, int(k));
+        int candidates_size = std::max(efSearch, int(k));
         candidates_size = std::max(candidates_size, int(nprobe));
 
         MinimaxHeap candidates(candidates_size);
@@ -947,7 +985,7 @@
         }
 
         search_from_candidates(
-                hnsw, qdis, res, candidates, vt, search_stats, 0);
+                hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
     }
 }
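With `params` plumbed through search_level_0, a per-call SearchParametersHNSW can override the index-wide hnsw.efSearch, including on the base-layer-only search path used by IndexHNSWCagra when base_level_only is set. From Python this is the usual search-parameters pattern, e.g. (sketch; `index` and `xq` are placeholders for an HNSW-family index and a query matrix):

    import faiss

    params = faiss.SearchParametersHNSW(efSearch=128)
    D, I = index.search(xq, 10, params=params)
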
diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h
index 8261423cdd..f3aacf8a5b 100644
--- a/faiss/impl/HNSW.h
+++ b/faiss/impl/HNSW.h
@@ -184,7 +184,8 @@ struct HNSW {
             float d_nearest,
             int level,
             omp_lock_t* locks,
-            VisitedTable& vt);
+            VisitedTable& vt,
+            bool keep_max_size_level0 = false);
 
     /** add point pt_id on all levels <= pt_level and build the link
      * structure for them. */
@@ -193,7 +194,8 @@
             int pt_level,
             int pt_id,
             std::vector<omp_lock_t>& locks,
-            VisitedTable& vt);
+            VisitedTable& vt,
+            bool keep_max_size_level0 = false);
 
     /// search interface for 1 point, single thread
     HNSWStats search(
@@ -211,7 +213,8 @@
             const float* nearest_d,
             int search_type,
             HNSWStats& search_stats,
-            VisitedTable& vt) const;
+            VisitedTable& vt,
+            const SearchParametersHNSW* params = nullptr) const;
 
     void reset();
 
@@ -224,7 +227,8 @@
             DistanceComputer& qdis,
             std::priority_queue<NodeDistFarther>& input,
             std::vector<NodeDistFarther>& output,
-            int max_size);
+            int max_size,
+            bool keep_max_size_level0 = false);
 
     void permute_entries(const idx_t* map);
 };
diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp
index 8d80329bf9..1085d3a0d1 100644
--- a/faiss/impl/index_read.cpp
+++ b/faiss/impl/index_read.cpp
@@ -948,7 +948,7 @@ Index* read_index(IOReader* f, int io_flags) {
         idx = idxp;
     } else if (
             h == fourcc("IHNf") || h == fourcc("IHNp") || h == fourcc("IHNs") ||
-            h == fourcc("IHN2")) {
+            h == fourcc("IHN2") || h == fourcc("IHNc")) {
         IndexHNSW* idxhnsw = nullptr;
         if (h == fourcc("IHNf"))
             idxhnsw = new IndexHNSWFlat();
@@ -958,7 +958,15 @@ Index* read_index(IOReader* f, int io_flags) {
             idxhnsw = new IndexHNSWSQ();
         if (h == fourcc("IHN2"))
             idxhnsw = new IndexHNSW2Level();
+        if (h == fourcc("IHNc"))
+            idxhnsw = new IndexHNSWCagra();
         read_index_header(idxhnsw, f);
+        if (h == fourcc("IHNc")) {
+            READ1(idxhnsw->keep_max_size_level0);
+            auto idx_hnsw_cagra = dynamic_cast<IndexHNSWCagra*>(idxhnsw);
+            READ1(idx_hnsw_cagra->base_level_only);
+            READ1(idx_hnsw_cagra->num_base_level_search_entrypoints);
+        }
         read_HNSW(&idxhnsw->hnsw, f);
         idxhnsw->storage = read_index(f, io_flags);
         idxhnsw->own_fields = true;
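Because the reader above dispatches on the new "IHNc" fourcc and restores the three CAGRA-specific fields before the graph itself, an IndexHNSWCagra survives a full serialization round trip (the matching writer follows below). For example, echoing the Python test earlier in this patch (`cpu_index` is a placeholder for an IndexHNSWCagra, e.g. obtained via faiss.index_gpu_to_cpu):

    import faiss

    faiss.write_index(cpu_index, "index_hnsw_cagra.index")
    index = faiss.read_index("index_hnsw_cagra.index")  # recognized via "IHNc"
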
fourcc("IHNc") : 0; FAISS_THROW_IF_NOT(h != 0); WRITE1(h); write_index_header(idxhnsw, f); + if (h == fourcc("IHNc")) { + WRITE1(idxhnsw->keep_max_size_level0); + auto idx_hnsw_cagra = dynamic_cast(idxhnsw); + WRITE1(idx_hnsw_cagra->base_level_only); + WRITE1(idx_hnsw_cagra->num_base_level_search_entrypoints); + } write_HNSW(&idxhnsw->hnsw, f); write_index(idxhnsw->storage, f); } else if (const IndexNSG* idxnsg = dynamic_cast(idx)) { diff --git a/faiss/python/CMakeLists.txt b/faiss/python/CMakeLists.txt index dee8c7762e..0073c20e04 100644 --- a/faiss/python/CMakeLists.txt +++ b/faiss/python/CMakeLists.txt @@ -38,6 +38,11 @@ macro(configure_swigfaiss source) set_source_files_properties(${source} PROPERTIES COMPILE_DEFINITIONS GPU_WRAPPER ) + if (FAISS_ENABLE_RAFT) + set_property(SOURCE ${source} APPEND PROPERTY + COMPILE_DEFINITIONS FAISS_ENABLE_RAFT + ) + endif() endif() endmacro() diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index 85e04d322c..74a371f6cd 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -304,6 +304,7 @@ void gpu_sync_all_devices(); #include #include #include +#include #include #include #include @@ -557,6 +558,9 @@ struct faiss::simd16uint16 {}; %include %include %include +#ifdef FAISS_ENABLE_RAFT +%include +#endif %include %include %include @@ -673,6 +677,9 @@ struct faiss::simd16uint16 {}; DOWNCAST ( IndexRowwiseMinMax ) DOWNCAST ( IndexRowwiseMinMaxFP16 ) #ifdef GPU_WRAPPER +#ifdef FAISS_ENABLE_RAFT + DOWNCAST_GPU ( GpuIndexCagra ) +#endif DOWNCAST_GPU ( GpuIndexIVFPQ ) DOWNCAST_GPU ( GpuIndexIVFFlat ) DOWNCAST_GPU ( GpuIndexIVFScalarQuantizer )