
Commit dffb789

add example files
1 parent bbe6e12 commit dffb789


4 files changed, +221 -0 lines changed


examples/client.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
import requests


def chat_with_model(user_input, history):
    url = "http://localhost:8000/chat/"
    payload = {
        "user_input": user_input,
        "history": history
    }
    headers = {
        "Content-Type": "application/json"
    }

    response = requests.post(url, json=payload, headers=headers)

    if response.status_code == 200:
        return response.json()
    else:
        print(f"Failed to send request: {response.status_code} {response.text}")
        return None


def main():
    history = []
    print("Enter 'q' to quit, 'c' to clear chat history.")
    while True:
        user_input = input("User: ").strip()
        if user_input.lower() in ['q', 'quit']:
            print("Exiting chat.")
            break
        if user_input.lower() == 'c':
            print("Clearing chat history.")
            history.clear()
            continue

        result = chat_with_model(user_input, history)
        if result:
            # Display the response from the model.
            print(f"Assistant: {result['response']}")
            # Update the chat history from the response.
            history = result['history']


if __name__ == "__main__":
    main()
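
For non-interactive use, a single call to chat_with_model is enough; the /chat/ endpoint (see examples/vllm_fastapi.py below) returns both the reply and the updated history. A minimal sketch, assuming that server is running on localhost:8000 (the prompt text is illustrative, not part of the commit):

history = []
result = chat_with_model("Hello, who are you?", history)
if result:
    print(result["response"])
    history = result["history"]  # carry the updated history into the next turn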

examples/stream_client.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
import httpx


def chat_with_model_stream(user_input, history, url="http://localhost:8000/chat/"):
    payload = {
        "user_input": user_input,
        "history": history
    }

    headers = {
        "Content-Type": "application/json"
    }

    # Stream the POST request with httpx; chunks are consumed inside the response context.
    with httpx.Client() as client:
        with client.stream("POST", url, json=payload, headers=headers) as response:
            if response.status_code == 200:
                print("Assistant:", end=" ")
                for chunk in response.iter_text():
                    print(chunk, end="", flush=True)
                print()
            else:
                # Read the body first so .text is available on a streamed response.
                response.read()
                print(f"Failed to send request: {response.status_code} {response.text}")
                return None


def main():
    history = []
    print("Enter 'q' to quit, 'c' to clear chat history.")
    while True:
        user_input = input("User: ").strip()
        if user_input.lower() in ['q', 'quit']:
            print("Exiting chat.")
            break
        if user_input.lower() == 'c':
            print("Clearing chat history.")
            history.clear()
            continue

        chat_with_model_stream(user_input, history)
        # Future improvement: update history based on the API response if needed.


if __name__ == "__main__":
    main()
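
One possible way to address the "future improvement" comment above is to collect the streamed chunks and maintain the history on the client side, since the streaming endpoint returns only plain text. A hedged sketch; the helper name and structure are illustrative and not part of the commit:

import httpx

def chat_with_model_stream_collect(user_input, history, url="http://localhost:8000/chat/"):
    payload = {"user_input": user_input, "history": history}
    collected = []
    with httpx.Client() as client:
        with client.stream("POST", url, json=payload) as response:
            if response.status_code != 200:
                response.read()  # load the body so .text is available
                print(f"Failed to send request: {response.status_code} {response.text}")
                return None
            print("Assistant:", end=" ")
            for chunk in response.iter_text():
                print(chunk, end="", flush=True)
                collected.append(chunk)
            print()
    reply = "".join(collected)
    # Keep conversation state locally, mirroring what the non-streaming server returns.
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": reply})
    return reply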

examples/stream_vllm_fastapi.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

app = FastAPI()


class ChatRequest(BaseModel):
    user_input: str
    history: list


tokenizer = None
model = None


@app.on_event("startup")
def load_model_and_tokenizer():
    global tokenizer, model
    path = "AIDC-AI/Marco-o1"
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    model = LLM(model=path, tensor_parallel_size=4)


def generate_response_stream(model, text, max_new_tokens=4096):
    # Generate one token per vLLM call and re-feed the growing output so every
    # token can be yielded to the client as soon as it is produced.
    new_output = ''
    sampling_params = SamplingParams(
        max_tokens=1,
        temperature=0,
        top_p=0.9
    )
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            outputs = model.generate(
                [f'{text}{new_output}'],
                sampling_params=sampling_params,
                use_tqdm=False
            )
            next_token = outputs[0].outputs[0].text
            new_output += next_token
            yield next_token  # Yield each part of the response

            if new_output.endswith('</Output>'):
                break


@app.post("/chat/")
async def chat(request: ChatRequest):
    if not request.user_input:
        raise HTTPException(status_code=400, detail="Input cannot be empty.")

    if request.user_input.lower() in ['q', 'quit']:
        return {"response": "Exiting chat."}

    if request.user_input.lower() == 'c':
        request.history.clear()
        return {"response": "Clearing chat history."}

    request.history.append({"role": "user", "content": request.user_input})
    text = tokenizer.apply_chat_template(request.history, tokenize=False, add_generation_prompt=True)

    response_stream = generate_response_stream(model, text)

    # Stream the generated tokens to the client as plain text.
    return StreamingResponse(response_stream, media_type="text/plain")
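
The endpoint above streams plain text and stops once '</Output>' is generated or max_new_tokens is reached. Besides the httpx client in examples/stream_client.py, the stream can also be consumed with requests; a minimal sketch, assuming the server is running on localhost:8000 (the prompt text is illustrative, not part of the commit):

import requests

payload = {"user_input": "Hello, who are you?", "history": []}
with requests.post("http://localhost:8000/chat/", json=payload, stream=True) as response:
    # Print each decoded chunk as it arrives.
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
print()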

examples/vllm_fastapi.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

app = FastAPI()


class ChatRequest(BaseModel):
    user_input: str
    history: list


tokenizer = None
model = None


@app.on_event("startup")
def load_model_and_tokenizer():
    global tokenizer, model
    path = "AIDC-AI/Marco-o1"
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    model = LLM(model=path, tensor_parallel_size=4)


def generate_response(model, text, max_new_tokens=4096):
    # Generate one token per vLLM call and re-feed the growing output so the loop
    # can stop as soon as the closing </Output> tag appears.
    new_output = ''
    sampling_params = SamplingParams(
        max_tokens=1,
        temperature=0,
        top_p=0.9
    )
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            outputs = model.generate(
                [f'{text}{new_output}'],
                sampling_params=sampling_params,
                use_tqdm=False
            )
            new_output += outputs[0].outputs[0].text
            if new_output.endswith('</Output>'):
                break
    return new_output


@app.post("/chat/")
async def chat(request: ChatRequest):
    if not request.user_input:
        raise HTTPException(status_code=400, detail="Input cannot be empty.")

    if request.user_input.lower() in ['q', 'quit']:
        return {"response": "Exiting chat."}

    if request.user_input.lower() == 'c':
        request.history.clear()
        return {"response": "Clearing chat history."}

    request.history.append({"role": "user", "content": request.user_input})
    text = tokenizer.apply_chat_template(request.history, tokenize=False, add_generation_prompt=True)
    response = generate_response(model, text)
    request.history.append({"role": "assistant", "content": response})

    return {"response": response, "history": request.history}
