
Commit e6b3e77

refactor: single file => modules

1 parent dd4c729 commit e6b3e77

7 files changed: +547 -415 lines changed

src/GenerateReply.ts

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
import { FastifyReply, FastifyRequest } from "fastify";
import { CreatePayload, OpenAI_NS } from "./tyeps";
import { client, system_fingerprint } from "./constants";
import { claude_stop_to_openai_stop } from "./utils";

export class GenerateReply {
  constructor(
    public payload: CreatePayload,
    public request: FastifyRequest<{
      Body: OpenAI_NS.CompletionRequest;
    }>,
    public reply: FastifyReply
  ) {}

  async request_stream_api() {
    const { request, reply, payload } = this;
    const { model, created } = payload;
    delete (payload as any).created;

    const stream_id = Math.random().toString(36).substring(2, 15);
    async function* build_streaming() {
      const message_stream = await client.messages.create({
        ...payload,
        stream: true,
      });
      let input_tokens = 0;
      let current_id = stream_id;
      let current_model = model;
      for await (const data of message_stream) {
        if (request.socket.closed) {
          message_stream.controller.abort();
          break;
        }
        switch (data.type) {
          case "message_start": {
            const { id, model } = data.message;
            current_id = id;
            current_model = model;
            yield {
              id,
              object: "chat.completion.chunk",
              created,
              model,
              system_fingerprint,
              choices: [
                {
                  index: 0,
                  delta: {
                    role: "assistant",
                    content: "",
                  },
                  logprobs: null,
                  finish_reason: null,
                },
              ],
            };
            input_tokens = data.message.usage.input_tokens;
            break;
          }
          case "message_stop": {
            // fires last, but nothing needs to happen here
            break;
          }
          case "message_delta": {
            // fires right before message_stop and carries the stop_reason
            yield {
              id: current_id,
              object: "chat.completion.chunk",
              created,
              model: current_model,
              system_fingerprint,
              choices: [
                {
                  index: 0,
                  delta: {
                    content: "",
                    role: "assistant",
                  },
                  logprobs: null,
                  finish_reason: data.delta.stop_reason
                    ? claude_stop_to_openai_stop(data.delta.stop_reason)
                    : null,
                },
              ],
              usage: {
                prompt_tokens: input_tokens,
                completion_tokens: data.usage.output_tokens,
                total_tokens: input_tokens + data.usage.output_tokens,
              },
            };
            break;
          }
          case "content_block_delta": {
            // this is the actual text delta
            const block = data.delta;
            if (block.type !== "text_delta") {
              // ignore anything that is not plain text
              break;
            }
            yield {
              id: current_id,
              object: "chat.completion.chunk",
              created,
              model: current_model,
              system_fingerprint,
              choices: [
                {
                  index: 0,
                  delta: {
                    role: "assistant",
                    content: block.text,
                  },
                  logprobs: null,
                  finish_reason: null,
                },
              ],
            };
            break;
          }
        }
      }
    }

    reply.raw.writeHead(200, {
      "Content-Type": "text/event-stream",
      "Cache-Control": "no-cache",
      Connection: "keep-alive",
      "Access-Control-Allow-Origin": "*",
      "Transfer-Encoding": "chunked",
    });
    try {
      for await (const data of build_streaming()) {
        reply.raw.write("data: " + JSON.stringify(data));
        reply.raw.write("\n\n");
      }
    } catch (error) {
      console.error(error);
      console.error(JSON.stringify(error));
      throw error;
    }

    reply.raw.write("data: [DONE]");
    reply.raw.write("\n\n");

    reply.raw.end();
  }

  async request_api() {
    const { payload } = this;
    const { created } = payload;
    delete (payload as any).created;

    const result = await client.messages.create({
      ...payload,
      stream: false,
    });
    const openai_response = {
      id: result.id,
      object: "chat.completion",
      created: created,
      model: result.model,
      system_fingerprint: system_fingerprint,
      choices: [
        ...result.content.map((x: any, index) => ({
          index,
          message: {
            role: "assistant",
            // NOTE: x could be a tool_use block, but tools are not supported
            // here, so it is always text
            content: x.text ?? "",
          },
          logprobs: null,
          finish_reason: claude_stop_to_openai_stop(
            result.stop_reason ?? "stop"
          ),
        })),
      ],
      usage: {
        prompt_tokens: result.usage.input_tokens,
        completion_tokens: result.usage.output_tokens,
        total_tokens: result.usage.input_tokens + result.usage.output_tokens,
      },
    };

    return openai_response;
  }

  async run() {
    const { request, reply } = this;
    try {
      if (request.body.stream) {
        return await this.request_stream_api();
      }
      return await this.request_api();
    } catch (error) {
      await reply.code(500).send({
        message:
          error instanceof Error ? error.message : "Internal Server Error",
        code: "internal_error",
      });
      return;
    }
  }
}
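
The finish_reason values above come from claude_stop_to_openai_stop, imported from ./utils (one of the seven changed files, but not shown in this excerpt). A minimal sketch of what that mapping plausibly looks like, assuming the documented Anthropic stop reasons (end_turn, max_tokens, stop_sequence) and OpenAI's stop/length finish reasons:

// Hypothetical sketch only; the real implementation lives in src/utils.ts,
// which is not included in this excerpt.
export function claude_stop_to_openai_stop(stop_reason: string): string {
  switch (stop_reason) {
    case "end_turn":
      return "stop"; // the model finished its turn naturally
    case "max_tokens":
      return "length"; // generation was cut off by the max_tokens limit
    case "stop_sequence":
      return "stop"; // a configured stop sequence matched
    default:
      return "stop"; // fall back to a plain stop for unknown reasons
  }
}

With a mapping like this, max_tokens truncation surfaces to OpenAI clients as finish_reason "length", which is what most SDKs check for.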

src/RequestChecker.ts

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
import { FastifyReply, FastifyRequest } from "fastify";
import { CreatePayload, OpenAI_NS } from "./tyeps";
import {
  private_key,
  claude_model_names_map,
  support_models,
  ensure_first_mode,
} from "./constants";

export class RequestPreProcess {
  constructor(
    public payload: OpenAI_NS.CompletionRequest,
    public request: FastifyRequest<{
      Body: OpenAI_NS.CompletionRequest;
    }>,
    public reply: FastifyReply
  ) {}

  async check_headers() {
    const { request, reply } = this;
    if (private_key) {
      const { headers } = request;
      const { authorization } = headers;
      if (!authorization) {
        await reply
          .code(401)
          .header("WWW-Authenticate", "Bearer realm='openai'")
          .send({
            message: "Authorization header is required",
            code: "no_auth",
          });
        return false;
      }
      const [scheme, token] = authorization.split(" ");
      if (scheme !== "Bearer") {
        await reply
          .code(401)
          .header("WWW-Authenticate", "Bearer realm='openai'")
          .send({
            message: "Authorization scheme must be Bearer",
            code: "invalid_scheme",
          });
        return false;
      }
      if (token !== private_key) {
        await reply
          .code(403)
          .send({ message: "Invalid API key", code: "invalid_api_key" });
        return false;
      }
    }
    return true;
  }

  async check(): Promise<Required<CreatePayload> | null> {
    const { payload, reply } = this;
    const {
      messages,
      model: _model,
      max_tokens = 512,
      temperature = 0.75,
      top_p = 1,
      stop,
      stream = false,
      // none of the following are supported
      n = 1,
      presence_penalty = 0,
      frequency_penalty = 0,
    } = payload;

    const model = claude_model_names_map[_model] ?? _model;
    if (!support_models.includes(model)) {
      await reply.code(400).send({
        message: `model ${model} is not supported`,
        code: "model_not_supported",
      });
      return null;
    }

    if (!Array.isArray(messages)) {
      throw new Error("messages should be an array");
    }

    const ensure_result = await this.ensure_first_message_is_user();
    if (!ensure_result) {
      return null;
    }
    const { system0, no_sys_messages } = ensure_result;

    return {
      created: Math.ceil(new Date().getTime() / 1000),
      system: system0?.content,
      messages: no_sys_messages,
      model,
      max_tokens,
      temperature,
      top_p,
      stop_sequences: !stop ? undefined : Array.isArray(stop) ? stop : [stop],
    };
  }

  async ensure_first_message_is_user() {
    const { reply } = this;
    const { messages } = this.payload;

    // Claude takes the system prompt as a separate field, so it is extracted
    // from the messages here
    // NOTE: this also means only a single system message is supported
    const system0 = messages.find((m) => m.role === "system");
    const no_sys_messages = messages.filter((m) => m.role !== "system") as {
      role: "user" | "assistant";
      content: string;
    }[];
    if (no_sys_messages.length > 0 && no_sys_messages[0].role !== "user") {
      // ensure the first message has the 'user' role
      switch (ensure_first_mode) {
        case "remove": {
          // NOTE: if the first message is not from the user, drop messages
          // until the first user message (guarding against an empty list)
          while (
            no_sys_messages.length > 0 &&
            no_sys_messages[0].role !== ("user" as any)
          ) {
            no_sys_messages.shift();
          }
          break;
        }
        case "continue": {
          // NOTE: if the first message is not from the user, prepend a user
          // message saying `continue`
          no_sys_messages.unshift({
            role: "user",
            content: "continue",
          });
          break;
        }
        default: {
          console.warn(
            `ensure_first_mode ${ensure_first_mode} is not supported, use 'remove' instead`
          );
        }
      }
    }
    if (no_sys_messages.length === 0) {
      await reply.code(400).send({
        message: "messages should contain at least one user message",
        code: "messages_empty",
      });
      return null;
    }

    return {
      system0,
      no_sys_messages,
    };
  }
}
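
Together, the two new modules split the old single-file flow in two: RequestPreProcess validates the auth headers and normalizes the OpenAI-style body into a CreatePayload, and GenerateReply calls Claude and shapes the response. The actual route registration lives in another file of this commit; a hypothetical sketch of how the pieces could compose in a Fastify handler (the route path and server setup here are assumptions, not part of the diff):

import Fastify from "fastify";
import { RequestPreProcess } from "./RequestChecker";
import { GenerateReply } from "./GenerateReply";
import { OpenAI_NS } from "./tyeps";

const app = Fastify();

// Assumed route path; the real registration is in a file not shown here.
app.post<{ Body: OpenAI_NS.CompletionRequest }>(
  "/v1/chat/completions",
  async (request, reply) => {
    const pre = new RequestPreProcess(request.body, request, reply);
    if (!(await pre.check_headers())) return; // 401/403 already sent
    const payload = await pre.check();
    if (!payload) return; // 400 already sent
    return new GenerateReply(payload, request, reply).run();
  }
);

app.listen({ port: 3000 });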
