Skip to content

Commit

Permalink
formatted distances.cpp and onednn_utils.h
Browse files Browse the repository at this point in the history
  • Loading branch information
guangzegu committed Mar 11, 2024
1 parent 7db8ce6 commit b35a0f2
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 86 deletions.
61 changes: 34 additions & 27 deletions faiss/utils/distances.cpp
Expand Up @@ -146,54 +146,61 @@ void exhaustive_inner_product_seq(
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
int nt = std::min(int(nx), omp_get_max_threads());

FAISS_ASSERT(use_sel == (sel != nullptr));

#ifdef ENABLE_DNNL
// use AMX to accelerate if available
if (is_amxbf16_supported()) {
float *res_arr = (float *)malloc(nx * ny * sizeof(float));
comput_f32bf16f32_inner_product(nx, d, ny, d, const_cast<float*>(x), const_cast<float*>(y), res_arr);
float* res_arr = (float*)malloc(nx * ny * sizeof(float));
comput_f32bf16f32_inner_product(
nx,
d,
ny,
d,
const_cast<float*>(x),
const_cast<float*>(y),
res_arr);

#pragma omp parallel num_threads(nt)
{
SingleResultHandler resi(res);
{
SingleResultHandler resi(res);
#pragma omp for
for(size_t i = 0; i < nx; i++) {
resi.begin(i);
for(size_t j = 0; j < ny; j++){
float ip = res_arr[i*ny + j];
resi.add_result(ip , j);
for (size_t i = 0; i < nx; i++) {
resi.begin(i);
for (size_t j = 0; j < ny; j++) {
float ip = res_arr[i * ny + j];
resi.add_result(ip, j);
}
resi.end();
}
resi.end();
}
}
delete[] res_arr;
} else {
#endif

#pragma omp parallel num_threads(nt)
{
SingleResultHandler resi(res);
{
SingleResultHandler resi(res);
#pragma omp for
for (int64_t i = 0; i < nx; i++) {
const float* x_i = x + i * d;
const float* y_j = y;
for (int64_t i = 0; i < nx; i++) {
const float* x_i = x + i * d;
const float* y_j = y;

resi.begin(i);
resi.begin(i);

for (size_t j = 0; j < ny; j++, y_j += d) {
if (use_sel && !sel->is_member(j)) {
continue;
for (size_t j = 0; j < ny; j++, y_j += d) {
if (use_sel && !sel->is_member(j)) {
continue;
}
float ip = fvec_inner_product(x_i, y_j, d);
resi.add_result(ip, j);
}
float ip = fvec_inner_product(x_i, y_j, d);
resi.add_result(ip, j);
resi.end();
}
resi.end();
}
}

#ifdef ENABLE_DNNL

#ifdef ENABLE_DNNL
}
#endif
}
Expand Down
147 changes: 88 additions & 59 deletions faiss/utils/onednn/onednn_utils.h
Expand Up @@ -22,91 +22,120 @@ static bool is_onednn_init = false;
static std::mutex init_mutex;

static bool is_amxbf16_supported() {
    // Query CPUID leaf 7, sub-leaf 0; AMX-BF16 support is reported in
    // EDX bit 22 (see Intel SDM, "Structured Extended Feature Flags").
    unsigned int eax, ebx, ecx, edx;
    __asm__ __volatile__("cpuid"
                         : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
                         : "a"(7), "c"(0));
    return edx & (1 << 22);
}

static void init_onednn() {
    // Serialize initialization so concurrent callers cannot race on the
    // engine/stream globals; the flag check below runs under the lock.
    std::unique_lock<std::mutex> lock(init_mutex);

    if (is_onednn_init) {
        return;
    }

    // init dnnl engine: CPU engine index 0 and its (in-order) stream
    cpu_engine = dnnl::engine(dnnl::engine::kind::cpu, 0);
    engine_stream = dnnl::stream(cpu_engine);

    is_onednn_init = true;
}

__attribute__((constructor))
static void library_load() {
__attribute__((constructor)) static void library_load() {
// this functionn will be automatically called when the library is loaded
// printf("Library loaded.\n");
init_onednn();
}

/**
* @brief Compute float32 matrix inner product with bf16 intermediate results to accelerate
* @details The main idea is:
* @brief Compute float32 matrix inner product with bf16 intermediate results to
* accelerate
* @details The main idea is:
* 1. Define float32 memory layout for input and output
* 2. Create low precision bf16 memory descriptors as inner product input
* 3. Generate inner product primitive descriptor
* 4. Execute float32 => (reorder) => bf16 => (inner product) => float32
* chain operation, isolate different precision data, accelerate inner product
* chain operation, isolate different precision data, accelerate inner
* product
* 5. Pipeline execution via streams for asynchronous scheduling
*
* @param xrow Row number of input matrix X
* @param xrow Row number of input matrix X
* @param xcol Column number of input matrix X
* @param yrow Row number of weight matrix Y
* @param ycol Column number of weight matrix Y
* @param in_f32_1 Input matrix pointer in float32 type
* @param in_f32_1 Input matrix pointer in float32 type
* @param in_f32_2 Weight matrix pointer in float32 type
* @param out_f32 Output matrix pointer for result in float32 type
* @return None
*/
static void comput_f32bf16f32_inner_product(uint32_t xrow, uint32_t xcol, uint32_t yrow, uint32_t ycol,
float* in_f32_1, float* in_f32_2, float* out_f32) {

dnnl::memory::desc f32_md1 = dnnl::memory::desc({xrow, xcol}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
dnnl::memory::desc f32_md2 = dnnl::memory::desc({yrow, ycol}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
dnnl::memory::desc f32_dst_md2 = dnnl::memory::desc({xrow, yrow}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);

dnnl::memory f32_mem1 = dnnl::memory(f32_md1, cpu_engine, in_f32_1);
dnnl::memory f32_mem2 = dnnl::memory(f32_md2, cpu_engine, in_f32_2);
dnnl::memory f32_dst_mem = dnnl::memory(f32_dst_md2, cpu_engine, out_f32);

// inner memory bf16
dnnl::memory::desc bf16_md1 = dnnl::memory::desc({xrow, xcol}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::any);
dnnl::memory::desc bf16_md2 = dnnl::memory::desc({yrow, ycol}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::any);


dnnl::inner_product_forward::primitive_desc inner_product_pd = dnnl::inner_product_forward::primitive_desc(
cpu_engine, dnnl::prop_kind::forward_training,
bf16_md1, bf16_md2, f32_dst_md2);

dnnl::inner_product_forward inner_product_prim = dnnl::inner_product_forward(inner_product_pd);

dnnl::memory bf16_mem1 = dnnl::memory(inner_product_pd.src_desc(), cpu_engine);
dnnl::reorder(f32_mem1, bf16_mem1).execute(engine_stream, f32_mem1, bf16_mem1);

dnnl::memory bf16_mem2 = dnnl::memory(inner_product_pd.weights_desc(), cpu_engine);
dnnl::reorder(f32_mem2, bf16_mem2).execute(engine_stream, f32_mem2, bf16_mem2);

inner_product_prim.execute(engine_stream, {{DNNL_ARG_SRC, bf16_mem1},
{DNNL_ARG_WEIGHTS, bf16_mem2},
{DNNL_ARG_DST, f32_dst_mem}});

// Wait for the computation to finalize.
engine_stream.wait();

// printf("comput_f32bf16f32_inner_product finished#######>\n");
static void comput_f32bf16f32_inner_product(
uint32_t xrow,
uint32_t xcol,
uint32_t yrow,
uint32_t ycol,
float* in_f32_1,
float* in_f32_2,
float* out_f32) {
dnnl::memory::desc f32_md1 = dnnl::memory::desc(
{xrow, xcol},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::ab);
dnnl::memory::desc f32_md2 = dnnl::memory::desc(
{yrow, ycol},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::ab);
dnnl::memory::desc f32_dst_md2 = dnnl::memory::desc(
{xrow, yrow},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::ab);

dnnl::memory f32_mem1 = dnnl::memory(f32_md1, cpu_engine, in_f32_1);
dnnl::memory f32_mem2 = dnnl::memory(f32_md2, cpu_engine, in_f32_2);
dnnl::memory f32_dst_mem = dnnl::memory(f32_dst_md2, cpu_engine, out_f32);

// inner memory bf16
dnnl::memory::desc bf16_md1 = dnnl::memory::desc(
{xrow, xcol},
dnnl::memory::data_type::bf16,
dnnl::memory::format_tag::any);
dnnl::memory::desc bf16_md2 = dnnl::memory::desc(
{yrow, ycol},
dnnl::memory::data_type::bf16,
dnnl::memory::format_tag::any);

dnnl::inner_product_forward::primitive_desc inner_product_pd =
dnnl::inner_product_forward::primitive_desc(
cpu_engine,
dnnl::prop_kind::forward_training,
bf16_md1,
bf16_md2,
f32_dst_md2);

dnnl::inner_product_forward inner_product_prim =
dnnl::inner_product_forward(inner_product_pd);

dnnl::memory bf16_mem1 =
dnnl::memory(inner_product_pd.src_desc(), cpu_engine);
dnnl::reorder(f32_mem1, bf16_mem1)
.execute(engine_stream, f32_mem1, bf16_mem1);

dnnl::memory bf16_mem2 =
dnnl::memory(inner_product_pd.weights_desc(), cpu_engine);
dnnl::reorder(f32_mem2, bf16_mem2)
.execute(engine_stream, f32_mem2, bf16_mem2);

inner_product_prim.execute(
engine_stream,
{{DNNL_ARG_SRC, bf16_mem1},
{DNNL_ARG_WEIGHTS, bf16_mem2},
{DNNL_ARG_DST, f32_dst_mem}});

// Wait for the computation to finalize.
engine_stream.wait();

// printf("comput_f32bf16f32_inner_product finished#######>\n");
}

}//namespace faiss
} // namespace faiss

0 comments on commit b35a0f2

Please sign in to comment.