Skip to content

Commit

Permalink
Some changes to simdlib (#2885)
Browse files Browse the repository at this point in the history
Summary:
- Use elementwise operation and reduction once instead of across-vector comparing operation twice
- Use already implemented supporting functions
- Unify semantics of `operator==` as same as `simd16uint16`
    - `operator==` of `simd8uint32` and `simd8float32` had been implemented on #2568, but these has not same semantics as `simd16uint16` (which had been implemented in a long time ago). For getting the vector equality as `bool` , now we should use `is_same_as` member function.
- Change `is_same_as` to accept any vector type as argument for `simdlib_neon`
    - `is_same_as` has supported any vector type on `simdlib_avx2` and `simdlib_emulated` already
- Remove unused function `simd16uint16::is_same` on `simdlib_avx2`
    - Is it typo of `is_same_as` ? Anyway it seems to be used unlikely

Pull Request resolved: #2885

Reviewed By: mdouze

Differential Revision: D46330666

Pulled By: alexanderguzhva

fbshipit-source-id: 0ea14f8e9a8bda78f24a655219dffe3e07fc110f
  • Loading branch information
wx257osn2 authored and facebook-github-bot committed Jun 1, 2023
1 parent bbc95b1 commit 9c88422
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 83 deletions.
6 changes: 0 additions & 6 deletions faiss/utils/simdlib_avx2.h
Expand Up @@ -202,12 +202,6 @@ struct simd16uint16 : simd256bit {
return simd16uint16(_mm256_cmpeq_epi16(lhs.i, rhs.i));
}

bool is_same(simd16uint16 other) const {
const __m256i pcmp = _mm256_cmpeq_epi16(i, other.i);
unsigned bitmask = _mm256_movemask_epi8(pcmp);
return (bitmask == 0xffffffffU);
}

simd16uint16 operator~() const {
return simd16uint16(_mm256_xor_si256(i, _mm256_set1_epi32(-1)));
}
Expand Down
149 changes: 72 additions & 77 deletions faiss/utils/simdlib_neon.h
Expand Up @@ -559,15 +559,13 @@ struct simd16uint16 {
}

// Checks whether the other holds exactly the same bytes.
bool is_same_as(simd16uint16 other) const {
const bool equal0 =
(vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) ==
0xffff);
const bool equal1 =
(vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) ==
0xffff);

return equal0 && equal1;
template <typename T>
bool is_same_as(T other) const {
const auto o = detail::simdlib::reinterpret_u16(other.data);
const auto equals = detail::simdlib::binary_func(data, o)
.template call<&vceqq_u16>();
const auto equal = vandq_u16(equals.val[0], equals.val[1]);
return vminvq_u16(equal) == 0xffffu;
}

simd16uint16 operator~() const {
Expand Down Expand Up @@ -689,13 +687,12 @@ inline void cmplt_min_max_fast(
simd16uint16& minIndices,
simd16uint16& maxValues,
simd16uint16& maxIndices) {
const uint16x8x2_t comparison = uint16x8x2_t{
vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
const uint16x8x2_t comparison =
detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vcltq_u16>();

minValues.data = uint16x8x2_t{
vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
minValues = min(candidateValues, currentValues);
minIndices.data = uint16x8x2_t{
vbslq_u16(
comparison.val[0],
Expand All @@ -706,9 +703,7 @@ inline void cmplt_min_max_fast(
candidateIndices.data.val[1],
currentIndices.data.val[1])};

maxValues.data = uint16x8x2_t{
vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
maxValues = max(candidateValues, currentValues);
maxIndices.data = uint16x8x2_t{
vbslq_u16(
comparison.val[0],
Expand Down Expand Up @@ -869,13 +864,13 @@ struct simd32uint8 {
}

// Checks whether the other holds exactly the same bytes.
bool is_same_as(simd32uint8 other) const {
const bool equal0 =
(vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff);
const bool equal1 =
(vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff);

return equal0 && equal1;
template <typename T>
bool is_same_as(T other) const {
const auto o = detail::simdlib::reinterpret_u8(other.data);
const auto equals = detail::simdlib::binary_func(data, o)
.template call<&vceqq_u8>();
const auto equal = vandq_u8(equals.val[0], equals.val[1]);
return vminvq_u8(equal) == 0xffu;
}
};

Expand Down Expand Up @@ -960,27 +955,28 @@ struct simd8uint32 {
return *this;
}

bool operator==(simd8uint32 other) const {
const auto equals = detail::simdlib::binary_func(data, other.data)
.call<&vceqq_u32>();
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
return vminvq_u32(equal) == 0xffffffff;
simd8uint32 operator==(simd8uint32 other) const {
return simd8uint32{detail::simdlib::binary_func(data, other.data)
.call<&vceqq_u32>()};
}

bool operator!=(simd8uint32 other) const {
return !(*this == other);
simd8uint32 operator~() const {
return simd8uint32{
detail::simdlib::unary_func(data).call<&vmvnq_u32>()};
}

// Checks whether the other holds exactly the same bytes.
bool is_same_as(simd8uint32 other) const {
const bool equal0 =
(vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) ==
0xffffffff);
const bool equal1 =
(vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) ==
0xffffffff);
simd8uint32 operator!=(simd8uint32 other) const {
return ~(*this == other);
}

return equal0 && equal1;
// Checks whether the other holds exactly the same bytes.
template <typename T>
bool is_same_as(T other) const {
const auto o = detail::simdlib::reinterpret_u32(other.data);
const auto equals = detail::simdlib::binary_func(data, o)
.template call<&vceqq_u32>();
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
return vminvq_u32(equal) == 0xffffffffu;
}

void clear() {
Expand Down Expand Up @@ -1053,13 +1049,14 @@ inline void cmplt_min_max_fast(
simd8uint32& minIndices,
simd8uint32& maxValues,
simd8uint32& maxIndices) {
const uint32x4x2_t comparison = uint32x4x2_t{
vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])};

minValues.data = uint32x4x2_t{
vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
const uint32x4x2_t comparison =
detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vcltq_u32>();

minValues.data = detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vminq_u32>();
minIndices.data = uint32x4x2_t{
vbslq_u32(
comparison.val[0],
Expand All @@ -1070,9 +1067,9 @@ inline void cmplt_min_max_fast(
candidateIndices.data.val[1],
currentIndices.data.val[1])};

maxValues.data = uint32x4x2_t{
vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
maxValues.data = detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vmaxq_u32>();
maxIndices.data = uint32x4x2_t{
vbslq_u32(
comparison.val[0],
Expand Down Expand Up @@ -1167,28 +1164,25 @@ struct simd8float32 {
return *this;
}

bool operator==(simd8float32 other) const {
const auto equals =
simd8uint32 operator==(simd8float32 other) const {
return simd8uint32{
detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data)
.call<&vceqq_f32>();
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
return vminvq_u32(equal) == 0xffffffff;
.call<&vceqq_f32>()};
}

bool operator!=(simd8float32 other) const {
return !(*this == other);
simd8uint32 operator!=(simd8float32 other) const {
return ~(*this == other);
}

// Checks whether the other holds exactly the same bytes.
bool is_same_as(simd8float32 other) const {
const bool equal0 =
(vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) ==
0xffffffff);
const bool equal1 =
(vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) ==
0xffffffff);

return equal0 && equal1;
template <typename T>
bool is_same_as(T other) const {
const auto o = detail::simdlib::reinterpret_f32(other.data);
const auto equals =
detail::simdlib::binary_func<::uint32x4x2_t>(data, o)
.template call<&vceqq_f32>();
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
return vminvq_u32(equal) == 0xffffffffu;
}

std::string tostring() const {
Expand Down Expand Up @@ -1302,13 +1296,14 @@ inline void cmplt_min_max_fast(
simd8uint32& minIndices,
simd8float32& maxValues,
simd8uint32& maxIndices) {
const uint32x4x2_t comparison = uint32x4x2_t{
vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])};

minValues.data = float32x4x2_t{
vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
const uint32x4x2_t comparison =
detail::simdlib::binary_func<::uint32x4x2_t>(
candidateValues.data, currentValues.data)
.call<&vcltq_f32>();

minValues.data = detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vminq_f32>();
minIndices.data = uint32x4x2_t{
vbslq_u32(
comparison.val[0],
Expand All @@ -1319,9 +1314,9 @@ inline void cmplt_min_max_fast(
candidateIndices.data.val[1],
currentIndices.data.val[1])};

maxValues.data = float32x4x2_t{
vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
vmaxq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
maxValues.data = detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vmaxq_f32>();
maxIndices.data = uint32x4x2_t{
vbslq_u32(
comparison.val[0],
Expand Down

0 comments on commit 9c88422

Please sign in to comment.