Skip to content

Commit

Permalink
Try to reuse mask vectors when unmasking websocket frames
Browse files Browse the repository at this point in the history
  • Loading branch information
lpereira committed Apr 27, 2024
1 parent 1d376b5 commit dc350c7
Showing 1 changed file with 58 additions and 31 deletions.
89 changes: 58 additions & 31 deletions src/lib/lwan-websocket.c
Expand Up @@ -150,63 +150,90 @@ static size_t get_frame_length(struct lwan_request *request, uint16_t header)

/*
 * Unmask (or mask; the operation is its own inverse) a websocket payload
 * in place.
 *
 * Every byte msg[i] is XORed with mask[i % 4].  The bulk of the buffer is
 * processed in the widest chunks available (32, 16, 8, then 4 bytes), and
 * the 0..3 leftover bytes are handled one at a time.  All chunk sizes are
 * multiples of 4, so the key stays phase-aligned with `msg` throughout.
 *
 * msg:     buffer to transform in place; no alignment is required.
 * msg_len: number of bytes in `msg`; may be 0.
 * mask:    the 4-byte masking key; only these 4 bytes are ever read.
 */
static void unmask(char *msg, size_t msg_len, char mask[static 4])
{
    uint32_t mask32;

    /* memcpy() rather than a pointer cast: `mask` has no alignment
     * guarantee and type-punning through a cast violates strict aliasing.
     * Because both key and data are moved with memcpy(), the XOR is
     * byte-wise correct regardless of host endianness. */
    memcpy(&mask32, mask, sizeof(mask32));

    /* TODO: handle alignment of `msg` to use (at least) NT loads
     * as we're rewriting msg anyway.  (NT writes aren't that
     * useful as the unmasked value will be used right after.) */

#if defined(__AVX2__)
    if (msg_len >= 32) {
        /* Broadcast the key from the integer copy; never load more than
         * the 4 bytes that `mask` is guaranteed to provide. */
        const __m256i mask256 = _mm256_set1_epi32((int)mask32);

        do {
            __m256i v = _mm256_lddqu_si256((const __m256i *)msg);

            _mm256_storeu_si256((__m256i *)msg, _mm256_xor_si256(v, mask256));
            msg += 32;
            msg_len -= 32;
        } while (msg_len >= 32);
    }
#endif

#if defined(__SSE2__)
    if (msg_len >= 16) {
        const __m128i mask128 = _mm_set1_epi32((int)mask32);

        do {
#if defined(__SSE3__)
            __m128i v = _mm_lddqu_si128((const __m128i *)msg);
#else
            __m128i v = _mm_loadu_si128((const __m128i *)msg);
#endif

            _mm_storeu_si128((__m128i *)msg, _mm_xor_si128(v, mask128));
            msg += 16;
            msg_len -= 16;
        } while (msg_len >= 16);
    }
#endif

    /* Only worth doing on targets with native 64-bit registers. */
    if (sizeof(void *) == 8 && msg_len >= 8) {
        /* The same 4 key bytes repeated twice. */
        const uint64_t mask64 = (uint64_t)mask32 << 32 | (uint64_t)mask32;

        do {
            uint64_t v;

            memcpy(&v, msg, sizeof(v));
            v ^= mask64;
            memcpy(msg, &v, sizeof(v));
            msg += sizeof(v);
            msg_len -= sizeof(v);
        } while (msg_len >= 8);
    }

    while (msg_len >= 4) {
        uint32_t v;

        memcpy(&v, msg, sizeof(v));
        v ^= mask32;
        memcpy(msg, &v, sizeof(v));
        msg += sizeof(v);
        msg_len -= sizeof(v);
    }

    /* Tail: at most 3 bytes remain.  A remainder of 0 is the common case
     * (any payload length that is a multiple of 4), so it must be a valid
     * path and not be marked unreachable. */
    switch (msg_len) {
    case 3:
        msg[2] ^= mask[2]; /* fallthrough */
    case 2:
        msg[1] ^= mask[1]; /* fallthrough */
    case 1:
        msg[0] ^= mask[0];
        break;
    default:
        break;
    }
}

static void send_websocket_pong(struct lwan_request *request, uint16_t header)
Expand Down

0 comments on commit dc350c7

Please sign in to comment.