Skip to content

Commit

Permalink
refactor: simd swar (#134)
Browse files Browse the repository at this point in the history
Moves the block-wise validators to a "swar" SIMD backend

The core logic of validate => extract => chain is now more evident
  • Loading branch information
AaronO committed Apr 25, 2023
1 parent 6e7ba52 commit 58a6293
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 237 deletions.
234 changes: 18 additions & 216 deletions src/lib.rs
Expand Up @@ -87,65 +87,10 @@ static URI_MAP: [bool; 256] = byte_map![
];

#[inline]
fn is_uri_token(b: u8) -> bool {
pub(crate) fn is_uri_token(b: u8) -> bool {
URI_MAP[b as usize]
}

// Const-compatible substitute for `u64::from_ne_bytes([b; 8])`, kept as a
// `const fn` to avoid bumping MSRV (1.36 => 1.44).
// Broadcasts `byte` into every byte lane of the returned u64.
const fn uniform_block(byte: u8) -> u64 {
    (byte as u64) * 0x0101_0101_0101_0101 // == [byte; 8] reinterpreted
}

// A byte-wise range-check on an entire word/block,
// ensuring all bytes in the word satisfy
// `33 <= x <= 126 && x != '>' && x != '<'`
// it false negatives if the block contains '?'
// Returns the offset (0..8) of the first offending byte, or 8 if all pass.
#[inline]
fn validate_uri_block(block: [u8; 8]) -> usize {
    // 33 <= x <= 126
    const M: u8 = 0x21;
    const N: u8 = 0x7E;
    const BM: u64 = uniform_block(M);
    const BN: u64 = uniform_block(127-N);
    const M128: u64 = uniform_block(128);

    let x = u64::from_ne_bytes(block); // Really just a transmute
    // SWAR range check: these leave the 0x80 bit of a lane set
    // iff that byte falls outside [M, N].
    let lt = x.wrapping_sub(BM) & !x; // <= m
    let gt = x.wrapping_add(BN) | x; // >= n

    // XOR checks to catch '<' & '>' for correctness
    //
    // XOR can be thought of as a "distance function"
    // (somewhat extrapolating from the `xor(x, x) = 0` identity
    // and `∀ x != y: xor(x, y) != 0`; each u8 "xor key" provides a
    // unique total ordering of u8)
    // '<' and '>' have a "xor distance" of 2 (`xor('<', '>') = 2`)
    // xor(x, '>') <= 2 => {'>', '?', '<'}
    // xor(x, '<') <= 2 => {'<', '=', '>'}
    //
    // We assume P('=') > P('?'),
    // given well/commonly-formatted URLs with querystrings contain
    // a single '?' but possibly many '='
    //
    // Thus it's preferable/near-optimal to "xor distance" on '>',
    // since we'll slowpath at most one block per URL
    // (the lone '?' block takes the scalar path; '=' blocks stay fast).
    //
    // Some rust code to sanity check this yourself:
    // ```rs
    // fn xordist(x: u8, n: u8) -> Vec<(char, u8)> {
    //     (0..=255).into_iter().map(|c| (c as char, c ^ x)).filter(|(_c, y)| *y <= n).collect()
    // }
    // (xordist(b'<', 2), xordist(b'>', 2))
    // ```
    const B3: u64 = uniform_block(3); // (dist <= 2) + 1 to wrap
    const BGT: u64 = uniform_block(b'>');

    // Same wrapping-sub trick as the range check, applied to the xor-distance
    // from '>': flags lanes holding '<', '>' (or the false-negative '?').
    let xgt = x ^ BGT;
    let ltgtq = xgt.wrapping_sub(B3) & !xgt;

    // Merge all failure masks and report the first flagged lane.
    offsetnz((ltgtq | lt | gt) & M128)
}

static HEADER_NAME_MAP: [bool; 256] = byte_map![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
Expand All @@ -166,7 +111,7 @@ static HEADER_NAME_MAP: [bool; 256] = byte_map![
];

#[inline]
fn is_header_name_token(b: u8) -> bool {
pub(crate) fn is_header_name_token(b: u8) -> bool {
HEADER_NAME_MAP[b as usize]
}

Expand All @@ -191,45 +136,10 @@ static HEADER_VALUE_MAP: [bool; 256] = byte_map![


#[inline]
fn is_header_value_token(b: u8) -> bool {
pub(crate) fn is_header_value_token(b: u8) -> bool {
HEADER_VALUE_MAP[b as usize]
}

// A byte-wise range-check on an entire word/block: every byte must be a
// printable-or-space octet, i.e. `32 <= x <= 126`.
// Returns the offset (0..8) of the first byte that fails, or 8 if all pass.
#[inline]
fn validate_header_value_block(block: [u8; 8]) -> usize {
    // Valid range is [0x20, 0x7E].
    const LOW: u8 = 0x20;
    const HIGH: u8 = 0x7E;
    const BLOCK_LOW: u64 = uniform_block(LOW);
    const BLOCK_HIGH: u64 = uniform_block(127 - HIGH);
    const MSB_MASK: u64 = uniform_block(128);

    // Reinterpret the 8 bytes as one word (native endianness).
    let word = u64::from_ne_bytes(block);
    // Sets the 0x80 bit of any lane whose byte is below LOW…
    let below = word.wrapping_sub(BLOCK_LOW) & !word;
    // …and of any lane whose byte is above HIGH.
    let above = word.wrapping_add(BLOCK_HIGH) | word;
    offsetnz((below | above) & MSB_MASK)
}

/// Returns the index (0..8) of the first non-zero byte of `block`
/// (in native byte order), or 8 when the whole word is zero.
// NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit
#[inline]
fn offsetnz(block: u64) -> usize {
    // fast path optimistic case (common for long valid sequences)
    if block == 0 {
        return 8;
    }

    // perf: rust will unroll this search over the 8 bytes
    block
        .to_ne_bytes()
        .iter()
        .position(|&b| b != 0)
        .expect("non-zero word must contain a non-zero byte")
}

/// An error in parsing.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum Error {
Expand Down Expand Up @@ -966,28 +876,14 @@ fn parse_token<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> {
#[allow(missing_docs)]
// WARNING: Exported for internal benchmarks, not fit for public consumption
pub fn parse_uri<'a>(bytes: &mut Bytes<'a>) -> Result<&'a str> {
let b = next!(bytes);
if !is_uri_token(b) {
// First char must be a URI char, it can't be a space which would indicate an empty path.
let start = bytes.pos();
simd::match_uri_vectored(bytes);
// URI must have at least one char
if bytes.pos() == start {
return Err(Error::Token);
}

simd::match_uri_vectored(bytes);

let mut b;
loop {
if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) {
let n = validate_uri_block(bytes8);
unsafe { bytes.advance(n); }
if n == 8 { continue; }
}
b = next!(bytes);
if !is_uri_token(b) {
break;
}
}

if b == b' ' {
if next!(bytes) == b' ' {
return Ok(Status::Complete(unsafe {
// all bytes up till `i` must have been `is_token`.
str::from_utf8_unchecked(bytes.slice_skip(1))
Expand Down Expand Up @@ -1099,7 +995,8 @@ fn parse_headers_iter_uninit<'a, 'b>(
headers,
num_headers: 0,
};
let mut count: usize = 0;
// Track starting pointer to calculate the number of bytes parsed.
let start = bytes.as_ref().as_ptr() as usize;
let mut result = Err(Error::TooManyHeaders);

let mut iter = autoshrink.headers.iter_mut();
Expand Down Expand Up @@ -1155,7 +1052,6 @@ fn parse_headers_iter_uninit<'a, 'b>(
b = next!($bytes);
}

count += $bytes.pos();
$bytes.slice();

continue 'headers;
Expand All @@ -1166,50 +1062,24 @@ fn parse_headers_iter_uninit<'a, 'b>(
let b = next!(bytes);
if b == b'\r' {
expect!(bytes.next() == b'\n' => Err(Error::NewLine));
result = Ok(Status::Complete(count + bytes.pos()));
let end = bytes.as_ref().as_ptr() as usize;
result = Ok(Status::Complete(end - start));
break;
}
if b == b'\n' {
result = Ok(Status::Complete(count + bytes.pos()));
let end = bytes.as_ref().as_ptr() as usize;
result = Ok(Status::Complete(end - start));
break;
}
if !is_header_name_token(b) {
handle_invalid_char!(bytes, b, HeaderName);
}

// parse header name until colon
let mut b;
let header_name: &str = 'name: loop {
'name_inner: loop {
if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) {
macro_rules! check {
($bytes:ident, $i:literal) => ({
b = $bytes[$i];
if !is_header_name_token(b) {
unsafe { bytes.advance($i + 1); }
break 'name_inner;
}
});
}

check!(bytes8, 0);
check!(bytes8, 1);
check!(bytes8, 2);
check!(bytes8, 3);
check!(bytes8, 4);
check!(bytes8, 5);
check!(bytes8, 6);
check!(bytes8, 7);
unsafe { bytes.advance(8); }
} else {
b = next!(bytes);
if !is_header_name_token(b) {
break 'name_inner;
}
}
}

count += bytes.pos();
simd::match_header_name_vectored(bytes);
let mut b = next!(bytes);

let name = unsafe {
str::from_utf8_unchecked(bytes.slice_skip(1))
};
Expand All @@ -1223,7 +1093,6 @@ fn parse_headers_iter_uninit<'a, 'b>(
b = next!(bytes);

if b == b':' {
count += bytes.pos();
bytes.slice();
break 'name name;
}
Expand All @@ -1240,7 +1109,6 @@ fn parse_headers_iter_uninit<'a, 'b>(
'whitespace_after_colon: loop {
b = next!(bytes);
if b == b' ' || b == b'\t' {
count += bytes.pos();
bytes.slice();
continue 'whitespace_after_colon;
}
Expand All @@ -1256,7 +1124,6 @@ fn parse_headers_iter_uninit<'a, 'b>(

maybe_continue_after_obsolete_line_folding!(bytes, 'whitespace_after_colon);

count += bytes.pos();
let whitespace_slice = bytes.slice();

// This produces an empty slice that points to the beginning
Expand All @@ -1268,18 +1135,7 @@ fn parse_headers_iter_uninit<'a, 'b>(
// parse value till EOL

simd::match_header_value_vectored(bytes);

'value_line: loop {
if let Some(bytes8) = bytes.peek_n::<[u8; 8]>(8) {
let n = validate_header_value_block(bytes8);
unsafe { bytes.advance(n); }
if n == 8 { continue 'value_line; }
}
b = next!(bytes);
if !is_header_value_token(b) {
break 'value_line;
}
}
let b = next!(bytes);

//found_ctl
let skip = if b == b'\r' {
Expand All @@ -1293,7 +1149,6 @@ fn parse_headers_iter_uninit<'a, 'b>(

maybe_continue_after_obsolete_line_folding!(bytes, 'value_lines);

count += bytes.pos();
// having just checked that a newline exists, it's safe to skip it.
unsafe {
break 'value bytes.slice_skip(skip);
Expand Down Expand Up @@ -1409,7 +1264,6 @@ pub fn parse_chunk_size(buf: &[u8])
#[cfg(test)]
mod tests {
use super::{Request, Response, Status, EMPTY_HEADER, parse_chunk_size};
use super::{offsetnz, validate_header_value_block, validate_uri_block};

const NUM_OF_HEADERS: usize = 4;

Expand Down Expand Up @@ -2376,58 +2230,6 @@ mod tests {
assert_eq!(response.headers[0].value, &b"baguette"[..]);
}

#[test]
fn test_is_header_value_block() {
    // A block "passes" when the validator consumes all 8 bytes.
    let is_header_value_block = |b| validate_header_value_block(b) == 8;

    // Exhaustive sweep over uniform blocks: only 32..=126 are valid.
    for b in 0..=255_u8 {
        let expected = b >= 32 && b < 127;
        assert_eq!(is_header_value_block([b; 8]), expected, "b={}", b);
    }

    // A few sanity checks on non-uniform bytes for safe-measure
    assert!(!is_header_value_block(*b"foo.com\n"));
    assert!(!is_header_value_block(*b"o.com\r\nU"));
}

#[test]
fn test_is_uri_block() {
    // A block "passes" when the validator consumes all 8 bytes.
    let is_uri_block = |b| validate_uri_block(b) == 8;

    // Exhaustive sweep over uniform blocks:
    // valid URI bytes are 33..=126 minus { '<', '?', '>' }.
    for b in 0..=255_u8 {
        let expected = b >= 33 && b < 127 && !b"<?>".contains(&b);
        assert_eq!(is_uri_block([b; 8]), expected, "b={}", b);
    }
}

#[test]
fn test_offsetnz() {
    // Placing a single non-zero byte at position i must yield offset i.
    for i in 0..8 {
        let mut bytes = [0_u8; 8];
        bytes[i] = 1;
        assert_eq!(offsetnz(u64::from_ne_bytes(bytes)), i);
    }
}

#[test]
fn test_method_within_buffer() {
const REQUEST: &[u8] = b"GET / HTTP/1.1\r\n\r\n";
Expand Down
8 changes: 0 additions & 8 deletions src/simd/fallback.rs

This file was deleted.

0 comments on commit 58a6293

Please sign in to comment.