Enable Intel®-AMX/oneDNN to accelerate IndexFlatIP search #3266

Open
wants to merge 4 commits into main
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -52,6 +52,7 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")

# Valid values are "generic", "avx2", "avx512".
option(FAISS_OPT_LEVEL "" "generic")
option(FAISS_ENABLE_DNNL "Enable support for oneDNN (Intel® AMX) to accelerate IndexFlat search." OFF)
option(FAISS_ENABLE_GPU "Enable support for GPU indexes." ON)
option(FAISS_ENABLE_RAFT "Enable RAFT for GPU indexes." OFF)
option(FAISS_ENABLE_PYTHON "Build Python extension." ON)
@@ -80,6 +81,7 @@ if(FAISS_ENABLE_C_API)
add_subdirectory(c_api)
endif()


add_subdirectory(demos)
add_subdirectory(benchs)
add_subdirectory(tutorial/cpp)
5 changes: 5 additions & 0 deletions INSTALL.md
@@ -84,6 +84,9 @@ The optional requirements are:
- for GPU indices:
- nvcc,
- the CUDA toolkit,
- for Intel®-AMX/oneDNN acceleration:
- oneDNN,
- 4th+ Gen Intel® Xeon® Scalable processor.
- for the python bindings:
- python 3,
- numpy,
@@ -105,6 +108,8 @@ Several options can be passed to CMake, among which:
- general options:
- `-DFAISS_ENABLE_GPU=OFF` in order to disable building GPU indices (possible
values are `ON` and `OFF`),
- `-DFAISS_ENABLE_DNNL=ON` in order to enable Intel®-AMX/oneDNN acceleration of IndexFlat (inner product) search (possible
values are `ON` and `OFF`); before setting this option to `ON`, see the oneDNN [build guide](https://oneapi-src.github.io/oneDNN/dev_guide_build.html) for installing oneDNN,
- `-DFAISS_ENABLE_PYTHON=OFF` in order to disable building python bindings
(possible values are `ON` and `OFF`),
- `-DFAISS_ENABLE_RAFT=ON` in order to enable building the RAFT implementations
8 changes: 8 additions & 0 deletions c_api/utils/distances_c.cpp
@@ -100,3 +100,11 @@ void faiss_set_distance_compute_min_k_reservoir(int value) {
int faiss_get_distance_compute_min_k_reservoir() {
return faiss::distance_compute_min_k_reservoir;
}

void faiss_set_distance_compute_dnnl_query_bs(int value) {
faiss::distance_compute_dnnl_query_bs = value;
}

int faiss_get_distance_compute_dnnl_query_bs() {
return faiss::distance_compute_dnnl_query_bs;
}
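The header below also declares accessors for the database block size; their definitions are not shown in this hunk. Assuming they mirror the query-block-size pair above, they would look like this (a sketch, not part of the shown diff):

void faiss_set_distance_compute_dnnl_database_bs(int value) {
    faiss::distance_compute_dnnl_database_bs = value;
}

int faiss_get_distance_compute_dnnl_database_bs() {
    return faiss::distance_compute_dnnl_database_bs;
}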
12 changes: 12 additions & 0 deletions c_api/utils/distances_c.h
@@ -103,6 +103,18 @@ void faiss_set_distance_compute_min_k_reservoir(int value);
/// rather than a heap
int faiss_get_distance_compute_min_k_reservoir();

/// Setter of the query block size for oneDNN/AMX distance computations
void faiss_set_distance_compute_dnnl_query_bs(int value);

/// Getter of the query block size for oneDNN/AMX distance computations
int faiss_get_distance_compute_dnnl_query_bs();

/// Setter of the database block size for oneDNN/AMX distance computations
void faiss_set_distance_compute_dnnl_database_bs(int value);

/// Getter of the database block size for oneDNN/AMX distance computations
int faiss_get_distance_compute_dnnl_database_bs();

#ifdef __cplusplus
}
#endif
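For illustration, a minimal sketch of tuning these block sizes through the C API before issuing searches; the include path and the chosen values are assumptions, not recommendations from this PR:

#include <faiss/c_api/utils/distances_c.h> // header path assumed; adjust to the install layout
#include <cstdio>

int main() {
    // Larger blocks amortize the f32 -> bf16 reorder cost of each oneDNN call.
    faiss_set_distance_compute_dnnl_query_bs(8192);
    faiss_set_distance_compute_dnnl_database_bs(16384);
    std::printf("dnnl query bs = %d, database bs = %d\n",
                faiss_get_distance_compute_dnnl_query_bs(),
                faiss_get_distance_compute_dnnl_database_bs());
    return 0;
}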
18 changes: 18 additions & 0 deletions faiss/CMakeLists.txt
@@ -226,6 +226,15 @@ if(NOT WIN32)
list(APPEND FAISS_HEADERS invlists/OnDiskInvertedLists.h)
endif()

if(FAISS_ENABLE_DNNL)
  list(APPEND FAISS_HEADERS utils/onednn/onednn_utils.h)
  add_compile_definitions(ENABLE_DNNL)
endif()

# Export FAISS_HEADERS variable to parent scope.
set(FAISS_HEADERS ${FAISS_HEADERS} PARENT_SCOPE)

@@ -294,6 +303,15 @@ target_compile_definitions(faiss PRIVATE FINTEGER=int)
target_compile_definitions(faiss_avx2 PRIVATE FINTEGER=int)
target_compile_definitions(faiss_avx512 PRIVATE FINTEGER=int)

if(FAISS_ENABLE_DNNL)
  find_library(RT_LIB rt)
  find_library(DNNL_LIB dnnl)
  message(STATUS "DNNL_LIB=${DNNL_LIB}")
  target_link_libraries(faiss PRIVATE ${RT_LIB} ${DNNL_LIB})
  target_link_libraries(faiss_avx2 PRIVATE ${RT_LIB} ${DNNL_LIB})
  target_link_libraries(faiss_avx512 PRIVATE ${RT_LIB} ${DNNL_LIB})
endif()

find_package(OpenMP REQUIRED)
target_link_libraries(faiss PRIVATE OpenMP::OpenMP_CXX)
target_link_libraries(faiss_avx2 PRIVATE OpenMP::OpenMP_CXX)
91 changes: 76 additions & 15 deletions faiss/utils/distances.cpp
@@ -20,6 +20,10 @@
#include <immintrin.h>
#endif

#ifdef ENABLE_DNNL
#include <faiss/utils/onednn/onednn_utils.h>
#endif

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
@@ -145,26 +149,60 @@ void exhaustive_inner_product_seq(

FAISS_ASSERT(use_sel == (sel != nullptr));

#ifdef ENABLE_DNNL
    // Use oneDNN/AMX to accelerate if available: compute the full nx x ny
    // inner-product matrix in one call, then scan it for the best results.
    // Note that the IDSelector filter is not applied on this path.
    if (is_amxbf16_supported()) {
        float* res_arr = new float[nx * ny]; // matches the delete[] below
        comput_f32bf16f32_inner_product(
                nx,
                d,
                ny,
                d,
                const_cast<float*>(x),
                const_cast<float*>(y),
                res_arr);

#pragma omp parallel num_threads(nt)
        {
            SingleResultHandler resi(res);
#pragma omp for
            for (int64_t i = 0; i < nx; i++) {
                resi.begin(i);
                for (size_t j = 0; j < ny; j++) {
                    float ip = res_arr[i * ny + j];
                    resi.add_result(ip, j);
                }
                resi.end();
            }
        }
        delete[] res_arr;
    } else {
#endif

#pragma omp parallel num_threads(nt)
        {
            SingleResultHandler resi(res);
#pragma omp for
            for (int64_t i = 0; i < nx; i++) {
                const float* x_i = x + i * d;
                const float* y_j = y;

                resi.begin(i);

                for (size_t j = 0; j < ny; j++, y_j += d) {
                    if (use_sel && !sel->is_member(j)) {
                        continue;
                    }
                    float ip = fvec_inner_product(x_i, y_j, d);
                    resi.add_result(ip, j);
                }
                resi.end();
            }
        }

#ifdef ENABLE_DNNL
    }
#endif
}

template <class BlockResultHandler, bool use_sel = false>
@@ -216,8 +254,16 @@ void exhaustive_inner_product_blas(
return;

/* block sizes */
size_t prov_bs_x = distance_compute_blas_query_bs;
size_t prov_bs_y = distance_compute_blas_database_bs;
#ifdef ENABLE_DNNL
if (is_amxbf16_supported()) {
prov_bs_x = distance_compute_dnnl_query_bs;
prov_bs_y = distance_compute_dnnl_database_bs;
}
#endif
const size_t bs_x = prov_bs_x;
const size_t bs_y = prov_bs_y;
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);

for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
@@ -231,7 +277,20 @@
size_t j1 = j0 + bs_y;
if (j1 > ny)
j1 = ny;
            /* compute the actual dot products */
#ifdef ENABLE_DNNL
if (is_amxbf16_supported()) {
FINTEGER nyi = j1 - j0, nxi = i1 - i0;
comput_f32bf16f32_inner_product(
nxi,
d,
nyi,
d,
const_cast<float*>(x + i0 * d),
const_cast<float*>(y + j0 * d),
ip_block.get());
} else
#endif
{
float one = 1, zero = 0;
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
@@ -650,6 +709,8 @@ int distance_compute_blas_threshold = 20;
int distance_compute_blas_query_bs = 4096;
int distance_compute_blas_database_bs = 1024;
int distance_compute_min_k_reservoir = 100;
int distance_compute_dnnl_query_bs = 10240;
int distance_compute_dnnl_database_bs = 10240;

void knn_inner_product(
const float* x,
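For context, a minimal end-to-end sketch of an IndexFlatIP search that exercises the code paths modified above, assuming Faiss was built with `-DFAISS_ENABLE_DNNL=ON` and runs on an AMX-BF16-capable CPU (otherwise the same code takes the existing BLAS/SIMD paths unchanged):

#include <faiss/IndexFlat.h>
#include <random>
#include <vector>

int main() {
    const int d = 128, nb = 100000, nq = 64, k = 10;
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> uni(0.0f, 1.0f);
    std::vector<float> xb(size_t(nb) * d), xq(size_t(nq) * d);
    for (auto& v : xb) v = uni(rng);
    for (auto& v : xq) v = uni(rng);

    faiss::IndexFlatIP index(d); // exhaustive inner-product index
    index.add(nb, xb.data());

    std::vector<faiss::idx_t> labels(size_t(nq) * k);
    std::vector<float> distances(size_t(nq) * k);
    // With the default distance_compute_blas_threshold (20), a 64-query batch
    // takes the blocked path (exhaustive_inner_product_blas); smaller batches
    // take exhaustive_inner_product_seq. Under ENABLE_DNNL on AMX hardware,
    // both paths call comput_f32bf16f32_inner_product.
    index.search(nq, xq.data(), k, distances.data(), labels.data());
    return 0;
}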
4 changes: 4 additions & 0 deletions faiss/utils/distances.h
@@ -281,6 +281,10 @@ FAISS_API extern int distance_compute_blas_threshold;
FAISS_API extern int distance_compute_blas_query_bs;
FAISS_API extern int distance_compute_blas_database_bs;

// block sizes for oneDNN/AMX distance computations
FAISS_API extern int distance_compute_dnnl_query_bs;
FAISS_API extern int distance_compute_dnnl_database_bs;

// above this number of results we switch to a reservoir to collect results
// rather than a heap
FAISS_API extern int distance_compute_min_k_reservoir;
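A sketch of how these knobs would be set from C++ (the values are illustrative assumptions, not tuned recommendations):

#include <faiss/utils/distances.h>

void configure_dnnl_block_sizes() {
    // Each oneDNN inner-product call then covers a 4096-query by
    // 20480-database-vector tile of the distance matrix.
    faiss::distance_compute_dnnl_query_bs = 4096;
    faiss::distance_compute_dnnl_database_bs = 20480;
}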
141 changes: 141 additions & 0 deletions faiss/utils/onednn/onednn_utils.h
@@ -0,0 +1,141 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

/* oneDNN-based helpers used to accelerate inner-product distance
 * computations on CPUs with Intel AMX support. Included from distances.cpp
 * when Faiss is built with FAISS_ENABLE_DNNL.
 */

#pragma once
#include <stdlib.h>
#include <mutex>
#include <shared_mutex>
#include "oneapi/dnnl/dnnl.hpp"

namespace faiss {
static dnnl::engine cpu_engine;
static dnnl::stream engine_stream;
static bool is_onednn_init = false;
static std::mutex init_mutex;

/* AMX-BF16 support is reported by CPUID leaf 7 (sub-leaf 0), EDX bit 22. */
static bool is_amxbf16_supported() {
unsigned int eax, ebx, ecx, edx;
__asm__ __volatile__("cpuid"
: "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
: "a"(7), "c"(0));
return edx & (1 << 22);
}

static void init_onednn() {
std::unique_lock<std::mutex> lock(init_mutex);

if (is_onednn_init) {
return;
}

// init dnnl engine
cpu_engine = dnnl::engine(dnnl::engine::kind::cpu, 0);
engine_stream = dnnl::stream(cpu_engine);

is_onednn_init = true;
}

__attribute__((constructor)) static void library_load() {
    // this function is called automatically when the library is loaded
// printf("Library loaded.\n");
init_onednn();
}

/**
 * @brief Compute a float32 matrix inner product (out = X * Y^T) using bf16
 * intermediates to accelerate the computation
* @details The main idea is:
* 1. Define float32 memory layout for input and output
* 2. Create low precision bf16 memory descriptors as inner product input
* 3. Generate inner product primitive descriptor
* 4. Execute float32 => (reorder) => bf16 => (inner product) => float32
* chain operation, isolate different precision data, accelerate inner
* product
* 5. Pipeline execution via streams for asynchronous scheduling
*
* @param xrow Row number of input matrix X
* @param xcol Column number of input matrix X
* @param yrow Row number of weight matrix Y
* @param ycol Column number of weight matrix Y
* @param in_f32_1 Input matrix pointer in float32 type
* @param in_f32_2 Weight matrix pointer in float32 type
* @param out_f32 Output matrix pointer for result in float32 type
* @return None
*/
static void comput_f32bf16f32_inner_product(
uint32_t xrow,
uint32_t xcol,
uint32_t yrow,
uint32_t ycol,
float* in_f32_1,
float* in_f32_2,
float* out_f32) {
dnnl::memory::desc f32_md1 = dnnl::memory::desc(
{xrow, xcol},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::ab);
dnnl::memory::desc f32_md2 = dnnl::memory::desc(
{yrow, ycol},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::ab);
dnnl::memory::desc f32_dst_md2 = dnnl::memory::desc(
{xrow, yrow},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::ab);

dnnl::memory f32_mem1 = dnnl::memory(f32_md1, cpu_engine, in_f32_1);
dnnl::memory f32_mem2 = dnnl::memory(f32_md2, cpu_engine, in_f32_2);
dnnl::memory f32_dst_mem = dnnl::memory(f32_dst_md2, cpu_engine, out_f32);

// inner memory bf16
dnnl::memory::desc bf16_md1 = dnnl::memory::desc(
{xrow, xcol},
dnnl::memory::data_type::bf16,
dnnl::memory::format_tag::any);
dnnl::memory::desc bf16_md2 = dnnl::memory::desc(
{yrow, ycol},
dnnl::memory::data_type::bf16,
dnnl::memory::format_tag::any);

dnnl::inner_product_forward::primitive_desc inner_product_pd =
dnnl::inner_product_forward::primitive_desc(
cpu_engine,
dnnl::prop_kind::forward_training,
bf16_md1,
bf16_md2,
f32_dst_md2);

dnnl::inner_product_forward inner_product_prim =
dnnl::inner_product_forward(inner_product_pd);

dnnl::memory bf16_mem1 =
dnnl::memory(inner_product_pd.src_desc(), cpu_engine);
dnnl::reorder(f32_mem1, bf16_mem1)
.execute(engine_stream, f32_mem1, bf16_mem1);

dnnl::memory bf16_mem2 =
dnnl::memory(inner_product_pd.weights_desc(), cpu_engine);
dnnl::reorder(f32_mem2, bf16_mem2)
.execute(engine_stream, f32_mem2, bf16_mem2);

inner_product_prim.execute(
engine_stream,
{{DNNL_ARG_SRC, bf16_mem1},
{DNNL_ARG_WEIGHTS, bf16_mem2},
{DNNL_ARG_DST, f32_dst_mem}});

// Wait for the computation to finalize.
engine_stream.wait();

// printf("comput_f32bf16f32_inner_product finished#######>\n");
}

} // namespace faiss
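To make the layout convention concrete, a small standalone sketch (assuming a build with ENABLE_DNNL defined and a CPU that passes is_amxbf16_supported()) that computes out = X * Y^T for two query vectors against three database vectors:

#include <faiss/utils/onednn/onednn_utils.h>
#include <cstdio>

int main() {
    // 2 query vectors and 3 database vectors, all of dimension 4.
    float x[2 * 4] = {1, 0, 0, 0,  0, 1, 0, 0};
    float y[3 * 4] = {1, 1, 0, 0,  0, 1, 1, 0,  1, 0, 0, 1};
    float out[2 * 3]; // out[i * 3 + j] = <x_i, y_j>
    faiss::comput_f32bf16f32_inner_product(2, 4, 3, 4, x, y, out);
    for (int i = 0; i < 2; i++) {
        std::printf("%.1f %.1f %.1f\n",
                    out[i * 3 + 0], out[i * 3 + 1], out[i * 3 + 2]);
    }
    return 0;
}

With these inputs the expected rows are 1.0 0.0 1.0 and 1.0 1.0 0.0; the small integer values are exactly representable in bf16, so no rounding error shows up in this toy case.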