Skip to content

Commit

Permalink
feat: prioritization prediction with model (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
holtgrewe committed Sep 7, 2023
1 parent c5c2925 commit 48d504c
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .gitattributes
Original file line number Diff line number Diff line change
@@ -1 +1 @@
tests/data/train_smoke/* filter=lfs diff=lfs merge=lfs -text
tests/data/** filter=lfs diff=lfs merge=lfs -text
46 changes: 45 additions & 1 deletion cada_prio/cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Console script for CADA"""

import typing

import click

from cada_prio import _version, train_model
from cada_prio import _version, predict, train_model


@click.group()
Expand Down Expand Up @@ -37,7 +39,49 @@ def cli_train_model(
path_hpo_genes_to_phenotype: str,
path_hpo_obo: str,
):
"""train model"""
ctx.ensure_object(dict)
train_model.run(
path_out, path_hgnc_json, path_gene_hpo_links, path_hpo_genes_to_phenotype, path_hpo_obo
)


def _parse_list_arg(value: str) -> typing.List[str]:
    """Parse a list-valued command line argument into its non-empty items.

    A value of the form ``@path`` is read from the file at ``path`` and split
    on whitespace; any other value is split on commas.  Surrounding whitespace
    is stripped from every item and empty items (e.g., from a trailing comma)
    are dropped.
    """
    if value.startswith("@"):
        with open(value[1:]) as f:
            tokens = f.read().split()
    else:
        tokens = value.split(",")
    stripped = (token.strip() for token in tokens)
    return [token for token in stripped if token]


@cli.command("predict")
@click.argument("path_model", type=str)
@click.option(
    "--hpo-terms",
    type=str,
    help="comma-separated HPO terms or @file with space-separated ones",
    required=True,
)
@click.option(
    "--genes",
    type=str,
    help="comma-separated genes to restrict prediction to or @file with space-separated ones",
)
@click.pass_context
def cli_predict(
    ctx: click.Context,
    path_model: str,
    hpo_terms: str,
    genes: typing.Optional[str],
):
    """perform prediction/prioritization"""
    ctx.ensure_object(dict)

    hpo_term_list = _parse_list_arg(hpo_terms)
    # ``None`` means "do not restrict the prediction to particular genes".
    gene_list = _parse_list_arg(genes) if genes else None

    # ``predict.run`` signals failure through a non-zero return value; turn
    # that into a non-zero process exit code.
    if predict.run(path_model, hpo_term_list, gene_list) != 0:
        ctx.exit(1)
105 changes: 105 additions & 0 deletions cada_prio/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Model-based prediction"""

import json
import os
import typing

import cattrs
from gensim.models import Word2Vec
from logzero import logger
import networkx as nx
import numpy as np

from cada_prio import train_model


def load_hgnc_info(path_hgnc_json: str) -> typing.List[train_model.GeneIds]:
    """Read gene identifier records from a JSON-lines file.

    Each line of the file at ``path_hgnc_json`` is decoded as one JSON
    document and structured into a ``train_model.GeneIds`` via ``cattrs``.
    The records are returned in file order.
    """
    with open(path_hgnc_json, "rt") as handle:
        return [
            cattrs.structure(json.loads(raw_line), train_model.GeneIds)
            for raw_line in handle
        ]


def run(
    path_model: str,
    orig_hpo_terms: typing.List[str],
    genes: typing.Optional[typing.List[str]] = None,
) -> int:
    """Score genes against a set of HPO terms and print a ranking to stdout.

    :param path_model: directory containing ``hgnc_info.jsonl``,
        ``graph.gpickle`` and ``model``.
    :param orig_hpo_terms: HPO term identifiers of the query; terms that are
        not known to the model are skipped with a warning.  A single term
        given as a bare string is treated as a one-element list.
    :param genes: optional list of gene identifiers (symbol, NCBI gene ID,
        HGNC ID, or Ensembl gene ID) to restrict the output to; identifiers
        that cannot be resolved are reported with a warning.
    :return: ``0`` on success, ``1`` if none of the HPO terms is in the model.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(orig_hpo_terms, str):
        orig_hpo_terms = [orig_hpo_terms]

    # Load and prepare data
    logger.info("Loading HGNC info...")
    logger.info("- parsing")
    hgnc_infos = load_hgnc_info(os.path.join(path_model, "hgnc_info.jsonl"))
    logger.info("- create mapping")
    # Map every known identifier of a gene to its record.
    all_to_hgnc = {}
    for record in hgnc_infos:
        all_to_hgnc[record.symbol] = record
        all_to_hgnc[record.ncbi_gene_id] = record
        all_to_hgnc[record.hgnc_id] = record
        if record.ensembl_gene_id:
            all_to_hgnc[record.ensembl_gene_id] = record
    hgnc_info_by_id = {record.hgnc_id: record for record in hgnc_infos}
    # Resolve the user-provided gene identifiers to HGNC IDs, the identifier
    # type used for gene nodes in the graph.
    restrict_hgnc_ids = set()
    for gene in genes or []:
        if gene in all_to_hgnc:
            restrict_hgnc_ids.add(all_to_hgnc[gene].hgnc_id)
        else:
            logger.warning("could not resolve HGNC ID for gene %s", gene)
    logger.info("... done loading HGNC info")

    logger.info("Loading graph...")
    # NOTE: the graph is stored as a pickle; only load model directories that
    # come from a trusted source.
    graph = nx.read_gpickle(os.path.join(path_model, "graph.gpickle"))
    logger.info("... done loading graph")
    logger.info("Loading model...")
    model = Word2Vec.load(os.path.join(path_model, "model"))
    logger.info("... done loading model")

    # Lookup HPO term embeddings.
    hpo_terms = {}
    for hpo_term in orig_hpo_terms:
        if hpo_term not in model.wv:
            logger.warning("skipping HPO term %s as it is not in the model", hpo_term)
        else:
            hpo_terms[hpo_term] = model.wv[hpo_term]
    if not hpo_terms:
        logger.error("no valid HPO terms in model")
        return 1

    # Generate a score for each gene in the knowledge graph
    logger.info("Generating scores...")
    gene_scores = {}
    for node_id in graph.nodes():
        if not node_id.startswith("HGNC:"):
            continue  # skip, not a gene
        hgnc_id = node_id
        # Filter on the *resolved* HGNC IDs so that restricting by symbol,
        # NCBI gene ID, or Ensembl gene ID works as well.
        if genes and hgnc_id not in restrict_hgnc_ids:
            continue  # skip, not asked for

        # The gene's score is the mean dot product of its embedding with the
        # embeddings of the query HPO terms.
        hgnc_id_emb = model.wv[hgnc_id]
        this_gene_scores = [
            np.dot(hpo_term_emb, hgnc_id_emb) for hpo_term_emb in hpo_terms.values()
        ]
        gene_scores[hgnc_id] = sum(this_gene_scores) / len(hpo_terms)

    # Write out results to stdout, largest score first
    sorted_scores = sorted(gene_scores.items(), key=lambda x: x[1], reverse=True)
    print("# query (len=%d): %s" % (len(hpo_terms), ",".join(hpo_terms)))
    print("\t".join(["rank", "score", "gene_symbol", "ncbi_gene_id", "hgnc_id"]))
    for rank, (hgnc_id, score) in enumerate(sorted_scores, start=1):
        hgnc_info = hgnc_info_by_id[hgnc_id]
        print(
            "\t".join(
                map(
                    str,
                    [
                        rank,
                        score,
                        hgnc_info.symbol,
                        hgnc_info.ncbi_gene_id,
                        hgnc_info.hgnc_id,
                    ],
                )
            )
        )

    return 0
17 changes: 14 additions & 3 deletions cada_prio/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import attrs
import cattrs
from gensim.models import Word2Vec
from logzero import logger
import networkx as nx
import node2vec
Expand Down Expand Up @@ -250,9 +251,18 @@ def build_and_fit_model(clinvar_gen2phen, hpo_ontology):
return training_graph, model


def write_graph_and_model(path_out, training_graph, model):
def write_graph_and_model(
path_out, hgnc_info: typing.List[GeneIds], training_graph: nx.Graph, model: Word2Vec
):
os.makedirs(path_out, exist_ok=True)

path_hgnc_info = os.path.join(path_out, "hgnc_info.jsonl")
logger.info("Saving HGNC info to %s...", path_hgnc_info)
with open(path_hgnc_info, "w") as f:
for record in hgnc_info:
json.dump(cattrs.unstructure(record), f)
logger.info("... done saving HGNC info")

path_graph = os.path.join(path_out, "graph.gpickle")
logger.info("Saving graph to %s...", path_graph)
nx.write_gpickle(training_graph, path_graph)
Expand All @@ -261,6 +271,7 @@ def write_graph_and_model(path_out, training_graph, model):
logger.info("Saving embedding to %s...", path_out)
path_embeddings = os.path.join(path_out, "embedding")
logger.info("- %s", path_embeddings)
print(type(model.wv), model.wv)
model.wv.save_word2vec_format(path_embeddings)
path_model = os.path.join(path_out, "model")
logger.info("- %s", path_model)
Expand All @@ -280,9 +291,9 @@ def run(
clinvar_gen2phen = load_clinvar_gen2phen(path_gene_hpo_links)
hpo_gen2phen = load_hpo_gen2phen(path_hpo_genes_to_phenotype, ncbi_to_hgnc)
hpo_ontology, hpo_id_from_alt, hpo_id_to_name = load_hpo_ontology(path_hpo_obo)
_, _, _, _ = hgnc_info, hpo_gen2phen, hpo_id_from_alt, hpo_id_to_name
_, _, _ = hpo_gen2phen, hpo_id_from_alt, hpo_id_to_name

# build and fit model
training_graph, model = build_and_fit_model(clinvar_gen2phen, hpo_ontology)
# write out graph and model
write_graph_and_model(path_out, training_graph, model)
write_graph_and_model(path_out, hgnc_info, training_graph, model)
2 changes: 2 additions & 0 deletions stubs/gensim/models.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ class KeyedVectors:
class Word2Vec:
wv: KeyedVectors

@classmethod
def load(cls, fname: str): ...
def save(self, fname: str): ...
1 change: 1 addition & 0 deletions stubs/node2vec.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ class Node2Vec:
def fit(
self, window: int, min_count: int, batch_words: int, **kwargs
) -> gensim.models.Word2Vec: ...
def save(self, path: str) -> None: ...
3 changes: 3 additions & 0 deletions tests/data/model_smoke/embedding
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/model_smoke/graph.gpickle
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/model_smoke/hgnc_info.jsonl
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/data/model_smoke/model
Git LFS file not shown
5 changes: 5 additions & 0 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from cada_prio import predict


def test_predict_smoke_test(tmpdir):
    """Smoke test: prediction on the bundled smoke model runs without raising."""
    # ``predict.run`` expects a *list* of HPO terms; a bare string would be
    # iterated character by character.  The return code is deliberately not
    # asserted here -- TODO confirm that the smoke model contains this term
    # and then assert that the result is 0.
    predict.run("tests/data/model_smoke", ["HP:0008551"])

0 comments on commit 48d504c

Please sign in to comment.