Skip to content

Commit

Permalink
feat: allow running with legacy model/graph data (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
holtgrewe committed Sep 14, 2023
1 parent 1222a8c commit 9d3cc7c
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 15 deletions.
50 changes: 36 additions & 14 deletions cada_prio/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,21 @@ def load_hgnc_info(path_model):
return all_to_hgnc, hgnc_info_by_id


def load_graph_model(path_model):
def load_graph_model(path_model: str, legacy_path: typing.Optional[str] = None):
path_graph = os.path.join(path_model, "graph.gpickle")
path_embedding = os.path.join(path_model, "model")
if legacy_path:
logger.info("(using legacy model paths from %s)", legacy_path)
path_graph = os.path.join(
legacy_path, "data", "processed", "knowledge_graph", "unweighted", "train100.gpickle"
)
path_embedding = os.path.join(legacy_path, "models", "unweighted", "node2vec.model")
logger.info("Loading graph...")
with open(os.path.join(path_model, "graph.gpickle"), "rb") as inputf:
with open(path_graph, "rb") as inputf:
graph = pickle.load(inputf)
logger.info("... done loading graph")
logger.info("Loading model...")
model = Word2Vec.load(os.path.join(path_model, "model"))
model = Word2Vec.load(path_embedding)
logger.info("... done loading model")
return graph, model

Expand All @@ -50,7 +58,7 @@ class NoValidHpoTerms(ValueError):


def run_prediction(
orig_hpo_terms, orig_genes, all_to_hgnc, graph, model
orig_hpo_terms, orig_genes, all_to_hgnc, graph, model, translate_legacy_entrez_ids: bool = False
) -> typing.Tuple[typing.List[str], typing.Dict[str, float]]:
# Lookup HPO term embeddings.
hpo_terms = {}
Expand All @@ -75,17 +83,31 @@ def run_prediction(
logger.info("Generating scores...")
gene_scores = {}
for node_id in graph.nodes():
if node_id.startswith("HGNC:"): # is gene
if translate_legacy_entrez_ids and node_id.startswith("Entrez:"):
graph_gene_id = node_id
stripped_node_id = node_id[len("Entrez:") :]
if stripped_node_id not in all_to_hgnc:
logger.warning(
"skipping legacy gene id %s as cannot translate to HGNC", graph_gene_id
)
continue
hgnc_id = all_to_hgnc[stripped_node_id].hgnc_id
elif node_id.startswith("HGNC:"): # is gene
graph_gene_id = node_id
hgnc_id = node_id
if genes and hgnc_id not in genes:
continue # skip, not asked for

this_gene_scores = []
hgnc_id_emb = model.wv[hgnc_id]
for hpo_term, hpo_term_emb in hpo_terms.items():
score = np.dot(hpo_term_emb, hgnc_id_emb)
this_gene_scores.append(score)
gene_scores[hgnc_id] = sum(this_gene_scores) / len(hpo_terms)
else:
continue # skip, no gene

# if we reach here then we have a gene
if genes and hgnc_id not in genes:
continue # skip, not asked for

this_gene_scores = []
graph_gene_emb = model.wv[graph_gene_id]
for hpo_term, hpo_term_emb in hpo_terms.items():
score = np.dot(hpo_term_emb, graph_gene_emb)
this_gene_scores.append(score)
gene_scores[hgnc_id] = sum(this_gene_scores) / len(hpo_terms)

sorted_scores = dict(sorted(gene_scores.items(), key=lambda x: x[1], reverse=True))
return list(hpo_terms.keys()), sorted_scores
Expand Down
5 changes: 4 additions & 1 deletion cada_prio/rest_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
DEBUG = env.get("CADA_DEBUG", "false").lower() in ("true", "1")
#: Path to data / model
PATH_DATA = env.get("CADA_PATH_DATA", "/data/cada")
#: Optional path to legacy model from CADA
PATH_LEGACY = env.get("CADA_PATH_LEGACY", None)

#: The CADA models, to be loaded on startup.
GLOBAL_STATIC = {}
Expand All @@ -30,7 +32,7 @@ async def lifespan(app: FastAPI):
all_to_hgnc, hgnc_info_by_id = predict.load_hgnc_info(PATH_DATA)
GLOBAL_STATIC["all_to_hgnc"] = all_to_hgnc
GLOBAL_STATIC["hgnc_info_by_id"] = hgnc_info_by_id
graph, model = predict.load_graph_model(PATH_DATA)
graph, model = predict.load_graph_model(PATH_DATA, PATH_LEGACY)
GLOBAL_STATIC["graph"] = graph
GLOBAL_STATIC["model"] = model

Expand Down Expand Up @@ -70,6 +72,7 @@ async def handle_predict(
GLOBAL_STATIC["all_to_hgnc"],
GLOBAL_STATIC["graph"],
GLOBAL_STATIC["model"],
PATH_LEGACY is not None,
)
hgnc_info_by_id = GLOBAL_STATIC["hgnc_info_by_id"]
return [
Expand Down

0 comments on commit 9d3cc7c

Please sign in to comment.