/
topic_augmentor.py
71 lines (54 loc) · 2.23 KB
/
topic_augmentor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import traceback
from copy import copy
import argparse
from app.settings import settings
import json
import os
from app.llm.generators.topic import generate_topic, generate_specific_topic
from app.course.embeddings import dedup_list
import asyncio
from tqdm import tqdm
from app.util import debug_print_trace
def load_processed_titles(file_name: str):
    """Load and return the JSON document stored under ``settings.DATA_DIR``.

    Args:
        file_name: Name (or relative path) of the JSON file, joined onto
            ``settings.DATA_DIR``.

    Returns:
        Whatever ``json.load`` parses from the file. The CLI help in this
        script describes the expected input as a flat JSON list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    path = os.path.join(settings.DATA_DIR, file_name)
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which can mis-decode non-ASCII titles (e.g. cp1252 on Windows).
    with open(path, encoding="utf-8") as f:
        return json.load(f)
async def generate_topics(title):
    """Await ``generate_topic`` for *title* and return its result unchanged."""
    return await generate_topic(title)
async def generate_specific_topics(title, domain=None):
    """Await ``generate_specific_topic`` and return its result unchanged.

    Both *title* and *domain* (``None`` when omitted) are forwarded
    positionally, in that order.
    """
    return await generate_specific_topic(title, domain)
async def _expand_topics(seeds, generate, error_label):
    """Return a copy of *seeds* extended with the results of *generate* per seed.

    Args:
        seeds: List of items to expand; it is copied, not mutated.
        generate: Callable taking one seed and returning an awaitable that
            resolves to an iterable of new items.
        error_label: Prefix printed in front of the exception message when
            generation for a seed fails.

    Returns:
        A new list: the original seeds followed by every generated item, in
        generation order. Deduplication is left to the caller.
    """
    expanded = copy(seeds)
    # Seeds are processed one at a time on purpose (same pacing as a plain
    # sequential loop); only the event loop is shared between them.
    for seed in tqdm(seeds):
        try:
            expanded.extend(await generate(seed))
        except Exception as e:
            # Best effort: a failure for one seed must not abort the whole run.
            debug_print_trace()
            print(f"{error_label}: {e}")
    return expanded


def main():
    """Command-line entry point: augment titles with generated topics and dedup.

    Reads a JSON file from ``settings.DATA_DIR``, expands every title into
    topics, deduplicates, expands every resulting topic into more specific
    topics, deduplicates again, and writes the final list as JSON to
    ``settings.DATA_DIR``. The list size is printed after each stage.
    """
    parser = argparse.ArgumentParser(description="Augment book titles with synthetic topics, and dedup.")
    parser.add_argument("in_file", help="Input filename (flat json list)")
    parser.add_argument("out_file", help="Output filename (flat json list)")
    parser.add_argument("--domain", help="Specific domain for the topics", default=None, type=str)
    parser.add_argument("--max", type=int, default=None, help="Maximum number of courses to generate")
    args = parser.parse_args()

    titles = load_processed_titles(args.in_file)

    # Compare against None so that an explicit "--max 0" really means zero
    # titles instead of being treated as "no limit".
    if args.max is not None:
        titles = titles[:args.max]

    # One event loop per stage rather than one per item: asyncio.run() creates
    # and closes a loop on every call, which is wasteful and can break async
    # resources that are reused between calls. dedup_list is deliberately kept
    # outside the loop because it is invoked here as a plain synchronous call.
    topics_from_titles = asyncio.run(
        _expand_topics(titles, generate_topics, "Error generating topic")
    )
    topics_from_titles = dedup_list(topics_from_titles)
    print(len(topics_from_titles))

    async def _specific(topic):
        # Bind the CLI-selected domain for the second expansion stage.
        return await generate_specific_topics(topic, domain=args.domain)

    all_topics = asyncio.run(
        _expand_topics(topics_from_titles, _specific, "Error generating specific topic")
    )
    all_topics = dedup_list(all_topics)

    # Write-only access is sufficient; the explicit encoding keeps the output
    # independent of the platform's locale.
    with open(os.path.join(settings.DATA_DIR, args.out_file), "w", encoding="utf-8") as f:
        json.dump(all_topics, f, indent=2)
    print(len(all_topics))


if __name__ == "__main__":
    main()