Skip to content

Commit f9d8622

Browse files
authored
Merge pull request #4 from ivanleomk/fix-path
fix: correct issue with file path
2 parents af13a9a + b2d540d commit f9d8622

File tree

4 files changed

+27
-131
lines changed

4 files changed

+27
-131
lines changed

kura/cli/cli.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@
22
import uvicorn
33
from kura.cli.server import api
44
from rich import print
5+
import os
56

67
app = typer.Typer()
78

89

910
@app.command()
10-
def start_app():
11+
def start_app(
12+
dir: str = typer.Option(
13+
"./checkpoints",
14+
help="Directory to use for checkpoints, relative to the current directory",
15+
),
16+
):
1117
"""Start the FastAPI server"""
12-
18+
os.environ["KURA_CHECKPOINT_DIR"] = dir
1319
uvicorn.run(api, host="0.0.0.0", port=8000)
1420
print(
1521
"\n[bold green]🚀 Access website at[/bold green] [bold blue][http://localhost:8000](http://localhost:8000)[/bold blue]\n"

kura/cli/server.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
generate_new_chats_per_week_data,
1212
)
1313
import json
14+
import os
1415

1516
api = FastAPI()
1617

@@ -57,16 +58,21 @@ async def analyse_conversations(conversation_data: ConversationData):
5758
for conversation in conversation_data.data
5859
]
5960

60-
# Load clusters from checkpoint file if it exists
6161
clusters_file = (
62-
Path(__file__).parent.parent.parent
63-
/ "checkpoints/dimensionality_checkpoints.json"
62+
Path(os.path.abspath(os.environ["KURA_CHECKPOINT_DIR"]))
63+
/ "dimensionality_checkpoints.json"
6464
)
6565
clusters = []
6666

67+
print(clusters_file)
68+
69+
# Load clusters from checkpoint file if it exists
70+
6771
if not clusters_file.exists():
68-
kura = Kura()
69-
kura.conversations = conversations
72+
kura = Kura(
73+
checkpoint_dir=Path(os.path.abspath(os.environ["KURA_CHECKPOINT_DIR"])),
74+
conversations=conversations[:100],
75+
)
7076
await kura.cluster_conversations()
7177

7278
with open(clusters_file) as f:

kura/kura.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,17 @@ def __init__(
2525
meta_cluster_model: BaseMetaClusterModel = MetaClusterModel(),
2626
dimensionality_reduction: BaseDimensionalityReduction = HDBUMAP(),
2727
max_clusters: int = 10,
28-
checkpoint_dir: str = "checkpoints",
28+
checkpoint_dir: str = "./checkpoints",
2929
cluster_checkpoint_name: str = "clusters.json",
3030
meta_cluster_checkpoint_name: str = "meta_clusters.json",
3131
):
32+
# Override checkpoint dirs so that they're the same for the models
33+
summarisation_model.checkpoint_dir = checkpoint_dir
34+
cluster_model.checkpoint_dir = checkpoint_dir
35+
meta_cluster_model.checkpoint_dir = checkpoint_dir
36+
dimensionality_reduction.checkpoint_dir = checkpoint_dir
37+
38+
self.embedding_model = embedding_model
3239
self.embedding_model = embedding_model
3340
self.summarisation_model = summarisation_model
3441
self.conversations = conversations

main.py

Lines changed: 0 additions & 123 deletions
This file was deleted.

0 commit comments

Comments
 (0)