Skip to content

Commit

Permalink
Merge pull request #3 from RexWzh/rex/wrap-chatter
Browse files Browse the repository at this point in the history
allow process files
  • Loading branch information
RexWzh committed Dec 20, 2023
2 parents 0ac1539 + 8c26af5 commit 1cda773
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 9 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -104,3 +104,5 @@ ENV/
# IDE settings
.vscode/
.idea/
*.json
*.jsonl
8 changes: 8 additions & 0 deletions tests/test_process.py
@@ -0,0 +1,8 @@

from webchatter import WebChat, Node, process_messages
from random import randint

def test_process_msgs():
msgs = [f"find the result of {randint(3, 100)} + {randint(4, 100)}" for _ in range(4)]
process_messages(msgs[:2], "test.jsonl")
process_messages(msgs, "test.jsonl")
6 changes: 4 additions & 2 deletions tests/test_webchatter.py
Expand Up @@ -20,8 +20,10 @@ def test_askchat():
# test chatlog
chat_log = chat.chat_log
assert len(chat_log) == 4
assert chat_log[0] == "It is nice today, ahhh"
assert chat_log[2] == "Tell me a joke about the weather."
assert chat_log[0] == {"role":"user", "content":"It is nice today, ahhh"}
assert chat_log[1]['role'] == "assistant"
assert chat_log[2] == {"role":"user", "content":"Tell me a joke about the weather."}
assert chat_log[3]['role'] == "assistant"
newchat = WebChat(chat_id=chat.chat_id, node_id=chat.node_id)
assert newchat.chat_log == chat_log
# test store
Expand Down
1 change: 1 addition & 0 deletions webchatter/__init__.py
Expand Up @@ -8,6 +8,7 @@
from typing import Union
from . import request
from .webchatter import WebChat, Node
from .checkpoint import process_messages
from pprint import pprint

def load_envs(env:Union[None, str, dict]=None):
Expand Down
44 changes: 44 additions & 0 deletions webchatter/checkpoint.py
@@ -0,0 +1,44 @@
import os, json
from webchatter import WebChat
import tqdm, tqdm.notebook
# from chattool import load_chats

def process_messages(msgs, checkpoint:str, mode:str="delete", isjupyter:bool=False):
"""Process the messages.
Args:
msgs (list): The messages.
checkpoint (str): Store the checkpoint.
mode (str, optional): One of the three mode: delete, repeat, newchat. Defaults to "delete".
Returns:
list: The processed messages.
"""
offset = 0
if os.path.exists(checkpoint):
with open(checkpoint, 'r', encoding='utf-8') as f:
processed = f.read().strip().split('\n')
if len(processed) >= 1 and processed[0] != '':
offset = len(processed)
tq = tqdm.tqdm if not isjupyter else tqdm.notebook.tqdm
with open(checkpoint, 'a', encoding='utf-8') as f:
for ind in tq(range(offset, len(msgs))):
msg = msgs[ind]
chat = WebChat()
ans = chat.ask(msg, keep=False)
data = {"index":ind + offset, "chat_log":{"user":msg, "assistant":ans}}
f.write(json.dumps(data) + '\n')
return True

def process_chats(chats, checkpoint:str):
"""Process the chats.
Args:
chats (list): The chats.
checkpoint (str): Store the checkpoint.
Returns:
list: The processed chats.
"""
# TODO
return chats
20 changes: 13 additions & 7 deletions webchatter/webchatter.py
Expand Up @@ -120,7 +120,14 @@ def chat_log(self):
chat_log.append(mapping[node_id].message)
node_id = mapping[node_id].parent
# remove the root node and tree node
return chat_log[-3::-1]
chat_log = chat_log[-3::-1]
chat_log_with_role = []
for ind, log in enumerate(chat_log):
if ind % 2 == 0:
chat_log_with_role.append({"role": "user", "content":log})
else:
chat_log_with_role.append({"role": "assistant", "content":log})
return chat_log_with_role

def account_status(self):
"""Get the account status."""
Expand Down Expand Up @@ -171,7 +178,9 @@ def ask( self, message:str
if self.chat_id is None:
# create four nodes
tree_id, que_id = str(uuid.uuid4()), str(uuid.uuid4())
root_resp, ans_resp = chat_completion(url, token, message, que_id, tree_id)
root_resp, ans_resp = chat_completion(url, token, message, que_id, tree_id
, history_and_training_disabled=not keep)
if not keep: return Node(ans_resp).message
# update parent and children for these nodes
root_resp['children'], root_resp['parent'] = [que_id], tree_id
ans_resp['children'], ans_resp['parent'] = [], que_id
Expand Down Expand Up @@ -300,11 +309,8 @@ def load(self, path:str, check_mapping:bool=False):
def print_log(self):
"""Print the chat log."""
sep = '\n' + '-'*15 + '\n'
for i, msg in enumerate(self.chat_log):
if i % 2 == 0:
print(f"{sep}user{sep}{msg}\n")
else:
print(f"{sep}assistant{sep}{msg}\n")
for item in self.chat_log:
print(f"{sep}{item['role']}{sep}{item['content']}\n")
return True

def __repr__(self):
Expand Down

0 comments on commit 1cda773

Please sign in to comment.