Skip to content

Commit

Permalink
update process method
Browse files Browse the repository at this point in the history
  • Loading branch information
RexWzh committed Dec 20, 2023
1 parent 45d3507 commit 54e392a
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 37 deletions.
2 changes: 1 addition & 1 deletion README-EN.md
Expand Up @@ -35,7 +35,7 @@ from random import randint

msgs = [f"find the result of {randint(3, 100)} + {randint(4, 100)}" for _ in range(4)]
# Annotate some data and get interrupted
process_messages(msgs[:2], "test.jsonl", time_interval=5, max_tries=3)
process_messages(msgs[:2], "test.jsonl", interval=5, max_tries=3)
# Continue annotation
process_messages(msgs, "test.jsonl")
```
Expand Down
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -36,7 +36,7 @@ from random import randint

msgs = [f"find the result of {randint(3, 100)} + {randint(4, 100)}" for _ in range(4)]
# 标注一部分后被中断
process_messages(msgs[:2], "test.jsonl", time_interval=5, max_tries=3)
process_messages(msgs[:2], "test.jsonl", interval=5, max_tries=3)
# 继续标注
process_messages(msgs, "test.jsonl")
```
Expand Down
5 changes: 2 additions & 3 deletions setup.py
Expand Up @@ -3,21 +3,20 @@
"""The setup script."""

from setuptools import setup, find_packages
VERSION = '0.1.1'
VERSION = '0.2.0'

with open('README.md') as readme_file:
readme = readme_file.read()

requirements = [ ]

test_requirements = ['pytest>=3', ]
test_requirements = ['pytest>=3', 'tqdm>=4.60', 'chattool>=3.0.0']

setup(
author="Rex Wang",
author_email='1073853456@qq.com',
python_requires='>=3.6',
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Natural Language :: English',
Expand Down
4 changes: 2 additions & 2 deletions tests/test_process.py
Expand Up @@ -4,5 +4,5 @@

def test_process_msgs():
msgs = [f"find the result of {randint(3, 100)} + {randint(4, 100)}" for _ in range(4)]
process_messages(msgs[:2], "test.jsonl")
process_messages(msgs, "test.jsonl")
process_messages(msgs[:2], "test.jsonl", max_tries=3)
process_messages(msgs, "test.jsonl", max_tries=3, interval=3)
2 changes: 1 addition & 1 deletion webchatter/__init__.py
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '1073853456@qq.com'
__version__ = '0.1.1'
__version__ = '0.2.0'

import os, dotenv, requests
from typing import Union
Expand Down
66 changes: 38 additions & 28 deletions webchatter/checkpoint.py
@@ -1,49 +1,59 @@
import os, json
import os, json, warnings, time
from webchatter import WebChat
import tqdm, tqdm.notebook
import time
from chattool import load_chats, Chat
from typing import List, Callable
# from chattool import load_chats

def process_messages( msgs
def try_sth(func:Callable, max_tries:int, interval:float, *args, **kwargs):
"""Try something.
Args:
func (Callable): The function to try.
max_tries (int): The maximum number of tries.
interval (float): The interval between tries.
"""
while max_tries:
try:
return func(*args, **kwargs)
except Exception as e:
print(e)
max_tries -= 1
time.sleep(interval)
return None

def process_messages( msgs:List[str]
, checkpoint:str
, time_interval:int=5
, interval:int=5
, max_tries:int=-1
, isjupyter:bool=False
, interval_rate:float=1
):
"""Process the messages.
Args:
msgs (list): The messages.
checkpoint (str): Store the checkpoint.
Returns:
list: The processed messages.
"""
offset = 0
if os.path.exists(checkpoint):
with open(checkpoint, 'r', encoding='utf-8') as f:
processed = f.read().strip().split('\n')
if len(processed) >= 1 and processed[0] != '':
offset = len(processed)
chats = load_chats(checkpoint)
if len(chats) > len(msgs):
warnings.warn(f"checkpoint file {checkpoint} has more chats than the data to be processed")
return chats[:len(msgs)]
chats.extend([None] * (len(msgs) - len(chats)))
tq = tqdm.tqdm if not isjupyter else tqdm.notebook.tqdm
chat = WebChat()
with open(checkpoint, 'a', encoding='utf-8') as f:
for ind in tq(range(offset, len(msgs))):
wait_time = time_interval
while max_tries:
try:
msg = msgs[ind]
ans = chat.ask(msg, keep=False)
data = {"index":ind + offset, "chat_log":{"user":msg, "assistant":ans}}
f.write(json.dumps(data) + '\n')
break
except Exception as e:
print(ind, e)
max_tries -= 1
time.sleep(wait_time)
wait_time = wait_time * interval_rate
return True
# process chats
webchat, chat = WebChat(), Chat()
for ind in tq(range(len(chats))):
if chats[ind] is not None: continue
ans = try_sth(webchat.ask, max_tries, interval, msgs[ind])
chat = Chat(msgs[ind])
chat.assistant(ans)
chat.save(checkpoint, mode='a', index=ind)
chats[ind] = chat
return chats

def process_chats(chats, checkpoint:str):
"""Process the chats.
Expand Down
1 change: 0 additions & 1 deletion webchatter/webchatter.py
Expand Up @@ -283,7 +283,6 @@ def save( self, file:str
# make path if not exists
pathname = os.path.dirname(file).strip()
if pathname != '': os.makedirs(pathname, exist_ok=True)

if chat_log_only:
data = {
"index": index,
Expand Down

0 comments on commit 54e392a

Please sign in to comment.