cleanup: drop SWAR's 64-bit assumptions (#140)
The SWAR code now operates on one host-CPU register's worth of data at a time, as intended.

Note: this might actually not be faster on 32-bit; I would have to bench it, but in some cases 4 memory reads / lookup-table reads might be faster than block-wide operations. (A standalone sketch of the register-width check follows the commit metadata below.)
AaronO committed May 4, 2023
1 parent f34faf2 commit 1c5faf8
Showing 1 changed file with 41 additions and 38 deletions.
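For context, here is a minimal standalone sketch of the register-width SWAR check this patch generalizes: broadcast the bounds into every byte lane of a `usize`, then let borrow/carry high bits flag out-of-range lanes. The function name `first_out_of_range` and the demo values are mine, and only the basic `33 <= x <= 126` check from `match_uri_char_8_swar` is shown (the patched code also rejects `<` and `>`):

```rust
// Hypothetical standalone sketch -- mirrors the constants in the diff below,
// but is not the patched httparse code itself.
const BLOCK_SIZE: usize = core::mem::size_of::<usize>();

// Broadcast a byte into every byte lane of a usize (const-friendly).
const fn uniform_block(b: u8) -> usize {
    (b as u64 * 0x0101_0101_0101_0101) as usize // truncates to 4 lanes on 32-bit
}

// Returns the index of the first byte outside 33..=126, or BLOCK_SIZE if none.
fn first_out_of_range(block: [u8; BLOCK_SIZE]) -> usize {
    const BM: usize = uniform_block(0x21);       // lower bound 33 in every lane
    const BN: usize = uniform_block(127 - 0x7E); // wraps lanes holding bytes > 126
    const M128: usize = uniform_block(128);      // per-lane high-bit mask

    let x = usize::from_ne_bytes(block);
    let lt = x.wrapping_sub(BM) & !x; // high bit set in lanes where byte < 33
    let gt = x.wrapping_add(BN) | x;  // high bit set in lanes where byte > 126
    let bad = (lt | gt) & M128;

    // from_ne_bytes/to_ne_bytes keep lane order stable, so this is endian-safe
    bad.to_ne_bytes().iter().position(|&b| b != 0).unwrap_or(BLOCK_SIZE)
}

fn main() {
    assert_eq!(first_out_of_range([b'a'; BLOCK_SIZE]), BLOCK_SIZE); // all valid
    let mut blk = [b'a'; BLOCK_SIZE];
    blk[2] = b' '; // 0x20 sits just below the lower bound
    assert_eq!(first_out_of_range(blk), 2);
}
```

On a 32-bit target `BLOCK_SIZE` is 4 and the same code validates four bytes per iteration, which is exactly the case the commit message hedges on: four lookup-table reads might beat the block-wide arithmetic there.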
79 changes: 41 additions & 38 deletions src/simd/swar.rs
@@ -1,17 +1,20 @@
 /// SWAR: SIMD Within A Register
 /// SIMD validator backend that validates register-sized chunks of data at a time.
-// TODO: current impl assumes 64-bit registers, optimize for 32-bit
 use crate::{is_header_name_token, is_header_value_token, is_uri_token, Bytes};
 
+// Adapt block-size to match native register size, i.e: 32bit => 4, 64bit => 8
+const BLOCK_SIZE: usize = core::mem::size_of::<usize>();
+type ByteBlock = [u8; BLOCK_SIZE];
+
 #[inline]
 pub fn match_uri_vectored(bytes: &mut Bytes) {
     loop {
-        if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) {
+        if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
             let n = match_uri_char_8_swar(bytes8);
             unsafe {
                 bytes.advance(n);
             }
-            if n == 8 {
+            if n == BLOCK_SIZE {
                 continue;
             }
         }
@@ -28,12 +31,12 @@ pub fn match_uri_vectored(bytes: &mut Bytes) {
 #[inline]
 pub fn match_header_value_vectored(bytes: &mut Bytes) {
     loop {
-        if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) {
+        if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
             let n = match_header_value_char_8_swar(bytes8);
             unsafe {
                 bytes.advance(n);
             }
-            if n == 8 {
+            if n == BLOCK_SIZE {
                 continue;
             }
         }
@@ -49,19 +52,19 @@ pub fn match_header_value_vectored(bytes: &mut Bytes) {
 
 #[inline]
 pub fn match_header_name_vectored(bytes: &mut Bytes) {
-    while let Some(block) = bytes.peek_n::<[u8; 8]>(8) {
+    while let Some(block) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
         let n = match_block(is_header_name_token, block);
         unsafe {
             bytes.advance(n);
         }
-        if n != 8 {
+        if n != BLOCK_SIZE {
             return;
         }
     }
     unsafe { bytes.advance(match_tail(is_header_name_token, bytes.as_ref())) };
 }
 
-// Matches "tail", i.e: when we have <8 bytes in the buffer, should be uncommon
+// Matches "tail", i.e: when we have <BLOCK_SIZE bytes in the buffer, should be uncommon
 #[cold]
 #[inline]
 fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize {
@@ -75,35 +78,35 @@ fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize {
 
 // Naive fallback block matcher
 #[inline(always)]
-fn match_block(f: impl Fn(u8) -> bool, block: [u8; 8]) -> usize {
+fn match_block(f: impl Fn(u8) -> bool, block: ByteBlock) -> usize {
     for (i, &b) in block.iter().enumerate() {
         if !f(b) {
             return i;
         }
     }
-    8
+    BLOCK_SIZE
 }
 
-/// // A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44)
+// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44)
 // creates a u64 whose bytes are each equal to b
-const fn uniform_block(b: u8) -> u64 {
-    b as u64 * 0x01_01_01_01_01_01_01_01 // [1_u8; 8]
+const fn uniform_block(b: u8) -> usize {
+    (b as u64 * 0x01_01_01_01_01_01_01_01 /* [1_u8; 8] */) as usize
 }
 
 // A byte-wise range-check on an entire word/block,
 // ensuring all bytes in the word satisfy
 // `33 <= x <= 126 && x != '>' && x != '<'`
 // IMPORTANT: it false negatives if the block contains '?'
 #[inline]
-fn match_uri_char_8_swar(block: [u8; 8]) -> usize {
+fn match_uri_char_8_swar(block: ByteBlock) -> usize {
     // 33 <= x <= 126
     const M: u8 = 0x21;
     const N: u8 = 0x7E;
-    const BM: u64 = uniform_block(M);
-    const BN: u64 = uniform_block(127 - N);
-    const M128: u64 = uniform_block(128);
+    const BM: usize = uniform_block(M);
+    const BN: usize = uniform_block(127 - N);
+    const M128: usize = uniform_block(128);
 
-    let x = u64::from_ne_bytes(block); // Really just a transmute
+    let x = usize::from_ne_bytes(block); // Really just a transmute
     let lt = x.wrapping_sub(BM) & !x; // <= m
     let gt = x.wrapping_add(BN) | x; // >= n
 
@@ -130,8 +133,8 @@ fn match_uri_char_8_swar(block: [u8; 8]) -> usize {
     // }
     // (xordist(b'<', 2), xordist(b'>', 2))
     // ```
-    const B3: u64 = uniform_block(3); // (dist <= 2) + 1 to wrap
-    const BGT: u64 = uniform_block(b'>');
+    const B3: usize = uniform_block(3); // (dist <= 2) + 1 to wrap
+    const BGT: usize = uniform_block(b'>');
 
     let xgt = x ^ BGT;
     let ltgtq = xgt.wrapping_sub(B3) & !xgt;
@@ -143,15 +146,15 @@ fn match_uri_char_8_swar(block: [u8; 8]) -> usize {
 // ensuring all bytes in the word satisfy `32 <= x <= 126`
 // IMPORTANT: false negatives if obs-text is present (0x80..=0xFF)
 #[inline]
-fn match_header_value_char_8_swar(block: [u8; 8]) -> usize {
+fn match_header_value_char_8_swar(block: ByteBlock) -> usize {
     // 32 <= x <= 126
     const M: u8 = 0x20;
     const N: u8 = 0x7E;
-    const BM: u64 = uniform_block(M);
-    const BN: u64 = uniform_block(127 - N);
-    const M128: u64 = uniform_block(128);
+    const BM: usize = uniform_block(M);
+    const BN: usize = uniform_block(127 - N);
+    const M128: usize = uniform_block(128);
 
-    let x = u64::from_ne_bytes(block); // Really just a transmute
+    let x = usize::from_ne_bytes(block); // Really just a transmute
     let lt = x.wrapping_sub(BM) & !x; // <= m
     let gt = x.wrapping_add(BN) | x; // >= n
     offsetnz((lt | gt) & M128)
@@ -160,10 +163,10 @@ fn match_header_value_char_8_swar(block: [u8; 8]) -> usize {
 /// Check block to find offset of first non-zero byte
 // NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit
 #[inline]
-fn offsetnz(block: u64) -> usize {
+fn offsetnz(block: usize) -> usize {
     // fast path optimistic case (common for long valid sequences)
     if block == 0 {
-        return 8;
+        return BLOCK_SIZE;
     }
 
     // perf: rust will unroll this loop
@@ -177,19 +180,19 @@ fn offsetnz(block: u64) -> usize {
 
 #[test]
 fn test_is_header_value_block() {
-    let is_header_value_block = |b| match_header_value_char_8_swar(b) == 8;
+    let is_header_value_block = |b| match_header_value_char_8_swar(b) == BLOCK_SIZE;
 
     // 0..32 => false
     for b in 0..32_u8 {
-        assert_eq!(is_header_value_block([b; 8]), false, "b={}", b);
+        assert_eq!(is_header_value_block([b; BLOCK_SIZE]), false, "b={}", b);
     }
     // 32..127 => true
     for b in 32..127_u8 {
-        assert_eq!(is_header_value_block([b; 8]), true, "b={}", b);
+        assert_eq!(is_header_value_block([b; BLOCK_SIZE]), true, "b={}", b);
     }
     // 127..=255 => false
     for b in 127..=255_u8 {
-        assert_eq!(is_header_value_block([b; 8]), false, "b={}", b);
+        assert_eq!(is_header_value_block([b; BLOCK_SIZE]), false, "b={}", b);
     }
 
     // A few sanity checks on non-uniform bytes for safe-measure
@@ -199,30 +202,30 @@ fn test_is_header_value_block() {
 
 #[test]
 fn test_is_uri_block() {
-    let is_uri_block = |b| match_uri_char_8_swar(b) == 8;
+    let is_uri_block = |b| match_uri_char_8_swar(b) == BLOCK_SIZE;
 
     // 0..33 => false
     for b in 0..33_u8 {
-        assert_eq!(is_uri_block([b; 8]), false, "b={}", b);
+        assert_eq!(is_uri_block([b; BLOCK_SIZE]), false, "b={}", b);
     }
     // 33..127 => true if b not in { '<', '?', '>' }
     let falsy = |b| b"<?>".contains(&b);
     for b in 33..127_u8 {
-        assert_eq!(is_uri_block([b; 8]), !falsy(b), "b={}", b);
+        assert_eq!(is_uri_block([b; BLOCK_SIZE]), !falsy(b), "b={}", b);
    }
     // 127..=255 => false
     for b in 127..=255_u8 {
-        assert_eq!(is_uri_block([b; 8]), false, "b={}", b);
+        assert_eq!(is_uri_block([b; BLOCK_SIZE]), false, "b={}", b);
     }
 }
 
 #[test]
 fn test_offsetnz() {
-    let seq = [0_u8; 8];
-    for i in 0..8 {
+    let seq = [0_u8; BLOCK_SIZE];
+    for i in 0..BLOCK_SIZE {
         let mut seq = seq.clone();
         seq[i] = 1;
-        let x = u64::from_ne_bytes(seq);
+        let x = usize::from_ne_bytes(seq);
         assert_eq!(offsetnz(x), i);
     }
 }
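One detail in the diff worth unpacking: `match_uri_char_8_swar` rejects `<` and `>` with an XOR-distance trick. XOR-ing every lane with `b'>'` maps `>`, `?`, `<` to 0, 1, 2, so the single lane check `xgt.wrapping_sub(B3) & !xgt` (high bit set when the lane is below 3) catches both brackets but also sweeps in `?` -- the false negative the comments warn about, which presumably just drops `?` onto the slower non-SWAR path. A scalar illustration of the distances (standalone, not part of the commit):

```rust
// Scalar view of the xordist trick from match_uri_char_8_swar:
// a byte is flagged as invalid when (b ^ b'>') < 3.
fn main() {
    for b in [b'<', b'=', b'>', b'?'] {
        let dist = b ^ b'>'; // '<' => 2, '=' => 3, '>' => 0, '?' => 1
        let flagged = dist < 3;
        println!("{}: xor-dist {} flagged {}", b as char, dist, flagged);
        // only '=' escapes the flag; '?' is flagged despite being a valid URI byte
        assert_eq!(flagged, b != b'=');
    }
}
```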