Skip to content

Commit

Permalink
refactor: native zlib support (#10243)
Browse files Browse the repository at this point in the history
* refactor: remove zlib-sync

* fix: bad length check

* refactor: support both options

BREAKING CHANGE: renamed compression related options

* chore: fix doc comment

* chore: update debug messages

* chore: better wording

Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com>

* chore: suggested changes

* chore: better naming

* refactor: lazy node:zlib import and lib detection

* chore: zlib capitalization

* fix: use proper var

* refactor: better inflate check

Co-authored-by: Aura <kyradiscord@gmail.com>

* chore: debug label

Co-authored-by: Superchupu <53496941+SuperchupuDev@users.noreply.github.com>

---------

Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com>
Co-authored-by: Aura <kyradiscord@gmail.com>
Co-authored-by: Superchupu <53496941+SuperchupuDev@users.noreply.github.com>
Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>
  • Loading branch information
5 people committed May 11, 2024
1 parent 7816ec2 commit 20258f9
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 61 deletions.
5 changes: 4 additions & 1 deletion packages/ws/README.md
Expand Up @@ -50,7 +50,10 @@ const manager = new WebSocketManager({
intents: 0, // for no intents
rest,
// uncomment if you have zlib-sync installed and want to use compression
// compression: CompressionMethod.ZlibStream,
// compression: CompressionMethod.ZlibSync,

// alternatively, we support compression using node's native `node:zlib` module:
// compression: CompressionMethod.ZlibNative,
});

manager.on(WebSocketShardEvents.Dispatch, (event) => {
Expand Down
9 changes: 8 additions & 1 deletion packages/ws/src/utils/constants.ts
Expand Up @@ -18,13 +18,19 @@ export enum Encoding {
* Valid compression methods
*/
export enum CompressionMethod {
ZlibStream = 'zlib-stream',
ZlibNative,
ZlibSync,
}

export const DefaultDeviceProperty = `@discordjs/ws [VI]{{inject}}[/VI]` as `@discordjs/ws ${string}`;

const getDefaultSessionStore = lazy(() => new Collection<number, SessionInfo | null>());

export const CompressionParameterMap = {
[CompressionMethod.ZlibNative]: 'zlib-stream',
[CompressionMethod.ZlibSync]: 'zlib-stream',
} as const satisfies Record<CompressionMethod, string>;

/**
* Default options used by the manager
*/
Expand All @@ -46,6 +52,7 @@ export const DefaultWebSocketManagerOptions = {
version: APIVersion,
encoding: Encoding.JSON,
compression: null,
useIdentifyCompression: false,
retrieveSessionInfo(shardId) {
const store = getDefaultSessionStore();
return store.get(shardId) ?? null;
Expand Down
10 changes: 8 additions & 2 deletions packages/ws/src/ws/WebSocketManager.ts
Expand Up @@ -96,9 +96,9 @@ export interface OptionalWebSocketManagerOptions {
*/
buildStrategy(manager: WebSocketManager): IShardingStrategy;
/**
* The compression method to use
* The transport compression method to use - mutually exclusive with `useIdentifyCompression`
*
* @defaultValue `null` (no compression)
* @defaultValue `null` (no transport compression)
*/
compression: CompressionMethod | null;
/**
Expand Down Expand Up @@ -176,6 +176,12 @@ export interface OptionalWebSocketManagerOptions {
* Function used to store session information for a given shard
*/
updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>;
/**
* Whether to use the `compress` option when identifying
*
* @defaultValue `false`
*/
useIdentifyCompression: boolean;
/**
* The gateway version to use
*
Expand Down
195 changes: 138 additions & 57 deletions packages/ws/src/ws/WebSocketShard.ts
@@ -1,11 +1,10 @@
/* eslint-disable id-length */
import { Buffer } from 'node:buffer';
import { once } from 'node:events';
import { clearInterval, clearTimeout, setInterval, setTimeout } from 'node:timers';
import { setTimeout as sleep } from 'node:timers/promises';
import { URLSearchParams } from 'node:url';
import { TextDecoder } from 'node:util';
import { inflate } from 'node:zlib';
import type * as nativeZlib from 'node:zlib';
import { Collection } from '@discordjs/collection';
import { lazy, shouldUseGlobalFetchAndWebSocket } from '@discordjs/util';
import { AsyncQueue } from '@sapphire/async-queue';
Expand All @@ -21,13 +20,20 @@ import {
type GatewaySendPayload,
} from 'discord-api-types/v10';
import { WebSocket, type Data } from 'ws';
import type { Inflate } from 'zlib-sync';
import type { IContextFetchingStrategy } from '../strategies/context/IContextFetchingStrategy.js';
import { ImportantGatewayOpcodes, getInitialSendRateLimitState } from '../utils/constants.js';
import type * as ZlibSync from 'zlib-sync';
import type { IContextFetchingStrategy } from '../strategies/context/IContextFetchingStrategy';
import {
CompressionMethod,
CompressionParameterMap,
ImportantGatewayOpcodes,
getInitialSendRateLimitState,
} from '../utils/constants.js';
import type { SessionInfo } from './WebSocketManager.js';

// eslint-disable-next-line promise/prefer-await-to-then
/* eslint-disable promise/prefer-await-to-then */
const getZlibSync = lazy(async () => import('zlib-sync').then((mod) => mod.default).catch(() => null));
const getNativeZlib = lazy(async () => import('node:zlib').then((mod) => mod).catch(() => null));
/* eslint-enable promise/prefer-await-to-then */

export enum WebSocketShardEvents {
Closed = 'closed',
Expand Down Expand Up @@ -86,9 +92,9 @@ const WebSocketConstructor: typeof WebSocket = shouldUseGlobalFetchAndWebSocket(
export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
private connection: WebSocket | null = null;

private useIdentifyCompress = false;
private nativeInflate: nativeZlib.Inflate | null = null;

private inflate: Inflate | null = null;
private zLibSyncInflate: ZlibSync.Inflate | null = null;

private readonly textDecoder = new TextDecoder();

Expand Down Expand Up @@ -120,6 +126,18 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {

#status: WebSocketShardStatus = WebSocketShardStatus.Idle;

private identifyCompressionEnabled = false;

/**
* @privateRemarks
*
* This is needed because `this.strategy.options.compression` is not an actual reflection of the compression method
* used, but rather the compression method that the user wants to use. This is because the libraries could just be missing.
*/
private get transportCompressionEnabled() {
return this.strategy.options.compression !== null && (this.nativeInflate ?? this.zLibSyncInflate) !== null;
}

public get status(): WebSocketShardStatus {
return this.#status;
}
Expand Down Expand Up @@ -161,21 +179,63 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
throw new Error("Tried to connect a shard that wasn't idle");
}

const { version, encoding, compression } = this.strategy.options;
const { version, encoding, compression, useIdentifyCompression } = this.strategy.options;
this.identifyCompressionEnabled = useIdentifyCompression;

// eslint-disable-next-line id-length
const params = new URLSearchParams({ v: version, encoding });
if (compression) {
const zlib = await getZlibSync();
if (zlib) {
params.append('compress', compression);
this.inflate = new zlib.Inflate({
chunkSize: 65_535,
to: 'string',
});
} else if (!this.useIdentifyCompress) {
this.useIdentifyCompress = true;
console.warn(
'WebSocketShard: Compression is enabled but zlib-sync is not installed, falling back to identify compress',
);
if (compression !== null) {
if (useIdentifyCompression) {
console.warn('WebSocketShard: transport compression is enabled, disabling identify compression');
this.identifyCompressionEnabled = false;
}

params.append('compress', CompressionParameterMap[compression]);

switch (compression) {
case CompressionMethod.ZlibNative: {
const zlib = await getNativeZlib();
if (zlib) {
const inflate = zlib.createInflate({
chunkSize: 65_535,
flush: zlib.constants.Z_SYNC_FLUSH,
});

inflate.on('error', (error) => {
this.emit(WebSocketShardEvents.Error, { error });
});

this.nativeInflate = inflate;
} else {
console.warn('WebSocketShard: Compression is set to native but node:zlib is not available.');
params.delete('compress');
}

break;
}

case CompressionMethod.ZlibSync: {
const zlib = await getZlibSync();
if (zlib) {
this.zLibSyncInflate = new zlib.Inflate({
chunkSize: 65_535,
to: 'string',
});
} else {
console.warn('WebSocketShard: Compression is set to zlib-sync, but it is not installed.');
params.delete('compress');
}

break;
}
}
}

if (this.identifyCompressionEnabled) {
const zlib = await getNativeZlib();
if (!zlib) {
console.warn('WebSocketShard: Identify compression is enabled, but node:zlib is not available.');
this.identifyCompressionEnabled = false;
}
}

Expand Down Expand Up @@ -451,28 +511,29 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
`shard id: ${this.id.toString()}`,
`shard count: ${this.strategy.options.shardCount}`,
`intents: ${this.strategy.options.intents}`,
`compression: ${this.inflate ? 'zlib-stream' : this.useIdentifyCompress ? 'identify' : 'none'}`,
`compression: ${this.transportCompressionEnabled ? CompressionParameterMap[this.strategy.options.compression!] : this.identifyCompressionEnabled ? 'identify' : 'none'}`,
]);

const d: GatewayIdentifyData = {
const data: GatewayIdentifyData = {
token: this.strategy.options.token,
properties: this.strategy.options.identifyProperties,
intents: this.strategy.options.intents,
compress: this.useIdentifyCompress,
compress: this.identifyCompressionEnabled,
shard: [this.id, this.strategy.options.shardCount],
};

if (this.strategy.options.largeThreshold) {
d.large_threshold = this.strategy.options.largeThreshold;
data.large_threshold = this.strategy.options.largeThreshold;
}

if (this.strategy.options.initialPresence) {
d.presence = this.strategy.options.initialPresence;
data.presence = this.strategy.options.initialPresence;
}

await this.send({
op: GatewayOpcodes.Identify,
d,
// eslint-disable-next-line id-length
d: data,
});

await this.waitForEvent(WebSocketShardEvents.Ready, this.strategy.options.readyTimeout);
Expand All @@ -490,6 +551,7 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
this.replayedEvents = 0;
return this.send({
op: GatewayOpcodes.Resume,
// eslint-disable-next-line id-length
d: {
token: this.strategy.options.token,
seq: session.sequence,
Expand All @@ -507,13 +569,22 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {

await this.send({
op: GatewayOpcodes.Heartbeat,
// eslint-disable-next-line id-length
d: session?.sequence ?? null,
});

this.lastHeartbeatAt = Date.now();
this.isAck = false;
}

private parseInflateResult(result: any): GatewayReceivePayload | null {
if (!result) {
return null;
}

return JSON.parse(typeof result === 'string' ? result : this.textDecoder.decode(result)) as GatewayReceivePayload;
}

private async unpackMessage(data: Data, isBinary: boolean): Promise<GatewayReceivePayload | null> {
// Deal with no compression
if (!isBinary) {
Expand All @@ -528,10 +599,12 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
const decompressable = new Uint8Array(data as ArrayBuffer);

// Deal with identify compress
if (this.useIdentifyCompress) {
return new Promise((resolve, reject) => {
if (this.identifyCompressionEnabled) {
// eslint-disable-next-line no-async-promise-executor
return new Promise(async (resolve, reject) => {
const zlib = (await getNativeZlib())!;
// eslint-disable-next-line promise/prefer-await-to-callbacks
inflate(decompressable, { chunkSize: 65_535 }, (err, result) => {
zlib.inflate(decompressable, { chunkSize: 65_535 }, (err, result) => {
if (err) {
reject(err);
return;
Expand All @@ -542,42 +615,50 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
});
}

// Deal with gw wide zlib-stream compression
if (this.inflate) {
const l = decompressable.length;
// Deal with transport compression
if (this.transportCompressionEnabled) {
const flush =
l >= 4 &&
decompressable[l - 4] === 0x00 &&
decompressable[l - 3] === 0x00 &&
decompressable[l - 2] === 0xff &&
decompressable[l - 1] === 0xff;
decompressable.length >= 4 &&
decompressable.at(-4) === 0x00 &&
decompressable.at(-3) === 0x00 &&
decompressable.at(-2) === 0xff &&
decompressable.at(-1) === 0xff;

const zlib = (await getZlibSync())!;
this.inflate.push(Buffer.from(decompressable), flush ? zlib.Z_SYNC_FLUSH : zlib.Z_NO_FLUSH);
if (this.nativeInflate) {
this.nativeInflate.write(decompressable, 'binary');

if (this.inflate.err) {
this.emit(WebSocketShardEvents.Error, {
error: new Error(`${this.inflate.err}${this.inflate.msg ? `: ${this.inflate.msg}` : ''}`),
});
}
if (!flush) {
return null;
}

if (!flush) {
return null;
}
const [result] = await once(this.nativeInflate, 'data');
return this.parseInflateResult(result);
} else if (this.zLibSyncInflate) {
const zLibSync = (await getZlibSync())!;
this.zLibSyncInflate.push(Buffer.from(decompressable), flush ? zLibSync.Z_SYNC_FLUSH : zLibSync.Z_NO_FLUSH);

if (this.zLibSyncInflate.err) {
this.emit(WebSocketShardEvents.Error, {
error: new Error(
`${this.zLibSyncInflate.err}${this.zLibSyncInflate.msg ? `: ${this.zLibSyncInflate.msg}` : ''}`,
),
});
}

const { result } = this.inflate;
if (!result) {
return null;
}
if (!flush) {
return null;
}

return JSON.parse(typeof result === 'string' ? result : this.textDecoder.decode(result)) as GatewayReceivePayload;
const { result } = this.zLibSyncInflate;
return this.parseInflateResult(result);
}
}

this.debug([
'Received a message we were unable to decompress',
`isBinary: ${isBinary.toString()}`,
`useIdentifyCompress: ${this.useIdentifyCompress.toString()}`,
`inflate: ${Boolean(this.inflate).toString()}`,
`identifyCompressionEnabled: ${this.identifyCompressionEnabled.toString()}`,
`inflate: ${this.transportCompressionEnabled ? CompressionMethod[this.strategy.options.compression!] : 'none'}`,
]);

return null;
Expand Down Expand Up @@ -838,7 +919,7 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
messages.length > 1
? `\n${messages
.slice(1)
.map((m) => ` ${m}`)
.map((message) => ` ${message}`)
.join('\n')}`
: ''
}`;
Expand Down

0 comments on commit 20258f9

Please sign in to comment.