Skip to content

Commit

Permalink
feat: adding dump-graph to cli (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
holtgrewe committed Sep 14, 2023
1 parent 9d3cc7c commit 3aace31
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 9 deletions.
16 changes: 15 additions & 1 deletion cada_prio/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import click

from cada_prio import _version, predict, train_model
from cada_prio import _version, inspection, predict, train_model


@click.group()
Expand Down Expand Up @@ -92,3 +92,17 @@ def cli_predict(

if predict.run(path_model, hpo_term_list, gene_list) != 0:
ctx.exit(1)


@cli.command("dump-graph")
@click.argument("path_graph", type=str)
@click.argument("path_hgnc_info", type=str)
@click.pass_context
def cli_dump_graph(ctx: click.Context, path_graph: str, path_hgnc_info: str):
    """dump graph edges for debugging"""
    # NOTE: the docstring above is left verbatim on purpose -- click renders
    # it as the help text of the ``dump-graph`` sub command.
    #
    # Ensure that ``ctx.obj`` is a dict before delegating.
    ctx.ensure_object(dict)
    # All actual work (loading the graph and printing its edges) is done by
    # the inspection module; this function is only the CLI entry point.
    inspection.dump_graph(path_graph=path_graph, path_hgnc_info=path_hgnc_info)
20 changes: 20 additions & 0 deletions cada_prio/inspection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Helpers for inspection models"""

import pickle

import networkx as nx

from cada_prio.predict import load_hgnc_info


def _node_label(node: str, hgnc_info_by_id) -> str:
    """Return the label to print for ``node``.

    Nodes starting with ``HGNC:`` are rewritten to ``Entrez:<ncbi_gene_id>``
    using ``hgnc_info_by_id``.  All other nodes, and HGNC nodes that have no
    entry in ``hgnc_info_by_id``, are returned unchanged.
    """
    if not node.startswith("HGNC:"):
        return node
    try:
        gene_info = hgnc_info_by_id[node]
    except KeyError:
        # The graph and the HGNC info file may come from different releases;
        # keep the raw identifier rather than aborting the whole dump.
        return node
    return "Entrez:%s" % gene_info.ncbi_gene_id


def dump_graph(path_graph: str, path_hgnc_info: str) -> None:
    """Print all edges of a pickled graph to stdout for debugging.

    One edge is written per line as ``<lhs>\\t<rhs>``.  Edges are sorted by
    their original node identifiers; nodes starting with ``HGNC:`` are then
    printed as ``Entrez:<ncbi_gene_id>`` when the identifier is present in the
    HGNC info, and printed unchanged otherwise.

    :param path_graph: path to the pickle file containing the graph
    :param path_hgnc_info: path passed on to ``load_hgnc_info``
    """
    _, hgnc_info_by_id = load_hgnc_info(path_hgnc_info)
    # SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
    # file.  Only use this command on graph files from a trusted source.
    with open(path_graph, "rb") as inputf:
        graph: nx.Graph = pickle.load(inputf)
    for lhs, rhs in sorted(graph.edges):
        print(f"{_node_label(lhs, hgnc_info_by_id)}\t{_node_label(rhs, hgnc_info_by_id)}")
5 changes: 2 additions & 3 deletions cada_prio/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
from cada_prio import train_model


def load_hgnc_info(path_model):
def load_hgnc_info(path_hgnc_jsonl):
logger.info("Loading HGNC info...")
logger.info("- parsing")
hgnc_infos = []
path_hgnc_jsonl = os.path.join(path_model, "hgnc_info.jsonl")
with open(path_hgnc_jsonl, "rt") as f:
for line in f:
hgnc_infos.append(cattrs.structure(json.loads(line), train_model.GeneIds))
Expand Down Expand Up @@ -119,7 +118,7 @@ def run(
orig_genes: typing.Optional[typing.List[str]] = None,
) -> int:
# Load and prepare data
all_to_hgnc, hgnc_info_by_id = load_hgnc_info(path_model)
all_to_hgnc, hgnc_info_by_id = load_hgnc_info(os.path.join(path_model, "hgnc_info.jsonl"))
graph, model = load_graph_model(path_model)
try:
hpo_terms, sorted_scores = run_prediction(
Expand Down
4 changes: 3 additions & 1 deletion cada_prio/rest_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
@asynccontextmanager
async def lifespan(app: FastAPI):
# Load the models
all_to_hgnc, hgnc_info_by_id = predict.load_hgnc_info(PATH_DATA)
all_to_hgnc, hgnc_info_by_id = predict.load_hgnc_info(
os.path.join(PATH_DATA, "hgnc_info.jsonl")
)
GLOBAL_STATIC["all_to_hgnc"] = all_to_hgnc
GLOBAL_STATIC["hgnc_info_by_id"] = hgnc_info_by_id
graph, model = predict.load_graph_model(PATH_DATA, PATH_LEGACY)
Expand Down
18 changes: 14 additions & 4 deletions cada_prio/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ def build_and_fit_model(*, clinvar_gen2phen, hpo_gen2phen, hpo_ontology, cpus: i
yield_gene2phen_edges(clinvar_gen2phen),
)
)
with open("__trainign_edges.json", "wt") as outputf:
print(json.dumps(cattrs.unstructure(training_edges), indent=2), file=outputf)
logger.info("- graph construction")
training_graph = nx.Graph()
training_graph.add_edges_from(training_edges)
Expand All @@ -257,11 +259,15 @@ def build_and_fit_model(*, clinvar_gen2phen, hpo_gen2phen, hpo_ontology, cpus: i
batch_words=embedding_params.batch_words,
)
logger.info("... done computing the embedding")
return training_graph, model
return training_graph, model, embedding_params


def write_graph_and_model(
path_out, hgnc_info: typing.List[GeneIds], training_graph: nx.Graph, model: Word2Vec
path_out,
hgnc_info: typing.List[GeneIds],
training_graph: nx.Graph,
embedding_params: EmbeddingParams,
model: Word2Vec,
):
os.makedirs(path_out, exist_ok=True)

Expand All @@ -286,6 +292,10 @@ def write_graph_and_model(
path_model = os.path.join(path_out, "model")
logger.info("- %s", path_model)
model.save(path_model)
path_params = os.path.join(path_out, "embedding_params.json")
logger.info("- %s", path_params)
with open(path_params, "wt") as outputf:
print(json.dumps(cattrs.unstructure(embedding_params), indent=2), file=outputf)
logger.info("... done saving embedding to")


Expand All @@ -305,11 +315,11 @@ def run(
_, _ = hpo_id_from_alt, hpo_id_to_name

# build and fit model
training_graph, model = build_and_fit_model(
training_graph, model, embedding_params = build_and_fit_model(
clinvar_gen2phen=clinvar_gen2phen,
hpo_gen2phen=hpo_gen2phen,
hpo_ontology=hpo_ontology,
cpus=cpus,
)
# write out graph and model
write_graph_and_model(path_out, hgnc_info, training_graph, model)
write_graph_and_model(path_out, hgnc_info, training_graph, embedding_params, model)

0 comments on commit 3aace31

Please sign in to comment.