Skip to content

Commit

Permalink
feat: adding param-opt command with single parameter evaluation (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
holtgrewe committed Sep 14, 2023
1 parent bbd5d86 commit 83141c6
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
tests/data/** filter=lfs diff=lfs merge=lfs -text
data/param_opt/* filter=lfs diff=lfs merge=lfs -text
60 changes: 59 additions & 1 deletion cada_prio/cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Console script for CADA"""

import logging
import typing

import click
import logzero

from cada_prio import _version, inspection, predict, train_model
from cada_prio import _version, inspection, param_opt, predict, train_model


@click.group()
Expand All @@ -15,6 +17,10 @@ def cli(ctx: click.Context, verbose: bool):
"""Main entry point for CLI via click."""
ctx.ensure_object(dict)
ctx.obj["verbose"] = verbose
if verbose:
logzero.loglevel(logging.DEBUG)
else:
logzero.loglevel(logging.INFO)


@cli.command("train-model")
Expand Down Expand Up @@ -111,3 +117,55 @@ def cli_dump_graph(
"""dump graph edges for debugging"""
ctx.ensure_object(dict)
inspection.dump_graph(path_graph, path_hgnc_info)


@cli.command("param-opt")
@click.option("--path-hgnc-json", type=str, help="path to HGNC JSON", required=True)
@click.option(
    "--path-hpo-genes-to-phenotype",
    type=str,
    help="path to genes_to_phenotype.txt file",
    required=True,
)
@click.option("--path-hpo-obo", type=str, help="path HPO OBO file", required=True)
@click.option(
    "--path-clinvar-phenotype-links",
    type=str,
    help="path to ClinVar phenotype links JSONL",
    required=True,
)
@click.option(
    "--fraction-links",
    type=float,
    help="fraction of links to add to the graph",
    required=True,
)
@click.option(
    "--seed",
    type=int,
    help="seed for random number generator",
    default=1,
)
@click.option("--cpus", type=int, help="number of CPUs to use", default=1)
@click.pass_context
def cli_param_opt(
    ctx: click.Context,
    path_hgnc_json: str,
    path_hpo_genes_to_phenotype: str,
    path_hpo_obo: str,
    path_clinvar_phenotype_links: str,
    fraction_links: float,
    seed: int,
    cpus: int,
):
    """run a single parameter optimization evaluation (train on a link subset, validate)"""
    ctx.ensure_object(dict)
    # Delegate all work to the param_opt module; the CLI layer only parses options.
    param_opt.perform_parameter_optimization(
        path_hgnc_json=path_hgnc_json,
        path_hpo_genes_to_phenotype=path_hpo_genes_to_phenotype,
        path_hpo_obo=path_hpo_obo,
        path_clinvar_phenotype_links=path_clinvar_phenotype_links,
        fraction_links=fraction_links,
        seed=seed,
        cpus=cpus,
    )
230 changes: 230 additions & 0 deletions cada_prio/param_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
"""Code for parameter optimization."""

import gzip
import hashlib
import json
import os
import random
import typing

import cattrs
from logzero import logger
import tqdm

from cada_prio import predict, train_model


def load_clinvar_phenotype_links_jsonl(path) -> typing.Iterable[train_model.GenePhenotypeRecord]:
    """Yield ``GenePhenotypeRecord`` objects from a JSONL file, transparently
    handling gzip compression based on the ``.gz`` suffix."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt") as handle:
        for raw_line in handle:
            yield cattrs.structure(json.loads(raw_line), train_model.GenePhenotypeRecord)


def load_links(path_clinvar_phenotype_links: str, fraction_links: float, seed: int):
    """Load ClinVar phenotype links and split them into training/validation/test.

    Single-gene links are deduplicated by submitter/gene/term combination,
    shuffled deterministically with *seed*, and the first ``fraction_links``
    share is partitioned 60%/20%/20%.
    """
    shuffler = random.Random(seed)

    logger.info("Loading phenotype links...")
    logger.info("- load JSONL")
    raw_links = list(load_clinvar_phenotype_links_jsonl(path_clinvar_phenotype_links))
    logger.info("- make unique by submitter")
    unique_by_key = {}
    for link in raw_links:
        if len(link.hgnc_ids) == 1:
            key = f"{link.submitter}/{','.join(link.hgnc_ids)}/{','.join(link.hpo_terms)}"
            unique_by_key[key] = link
    all_links = list(unique_by_key.values())
    logger.info("- randomizing")
    shuffler.shuffle(all_links)
    logger.info("... done loading %s links", len(all_links))

    logger.info("Computing counts...")
    n_total = len(all_links)
    n_used = int(fraction_links * n_total)
    n_training = int(n_used * 0.6)
    n_validation = (n_used - n_training) // 2
    n_test = n_validation
    end_training = n_training
    end_validation = end_training + n_validation
    end_test = end_validation + n_test
    links_training = all_links[:end_training]
    links_validation = all_links[end_training:end_validation]
    links_test = all_links[end_validation:end_test]
    logger.info("- total: % 6d", n_total)
    logger.info("- used: % 6d", n_used)
    logger.info("- training: % 6d", len(links_training))
    logger.info("- validation: % 6d", len(links_validation))
    logger.info("- test: % 6d", len(links_test))
    logger.info("... done computing counts")

    return links_training, links_validation, links_test


def prepare_training(
    links_training: typing.List[train_model.GenePhenotypeRecord],
    links_validation: typing.List[train_model.GenePhenotypeRecord],
    links_test: typing.List[train_model.GenePhenotypeRecord],
    fraction_links: float,
    seed: int,
):
    """Write embedding parameters and link splits below a run-specific directory.

    The directory name encodes the link fraction (as a percentage), the seed,
    and an MD5 digest of the embedding parameters so different configurations
    do not collide on disk.

    Returns ``(path_out, path_links_training, path_links_validation,
    path_links_test, path_embedding_params)``.
    """
    logger.info("Preparing training...")
    # Derive distinct sub-seeds for embedding construction and fitting.
    embedding_params = train_model.EmbeddingParams(seed_embedding=seed + 23, seed_fit=seed + 42)
    # Serialize once; compact form feeds the MD5 digest (must stay un-indented
    # to keep directory names stable), pretty form is logged and written out.
    params_dict = cattrs.unstructure(embedding_params)
    params_name = hashlib.md5(json.dumps(params_dict).encode("utf-8")).hexdigest()
    params_pretty = json.dumps(params_dict, indent=2)
    path_out = f"param_opt.d/{int(fraction_links * 100)}.{seed}.{params_name}"
    os.makedirs(path_out, exist_ok=True)
    logger.info("- output path: %s", path_out)
    logger.info("- embedding params:\n%s", params_pretty)
    path_embedding_params = f"{path_out}/embedding_params.json"
    with open(path_embedding_params, "wt") as outputf:
        print(params_pretty, file=outputf)

    def write_links_jsonl(path: str, links) -> None:
        # One JSON object per line (JSONL), mirroring the input format.
        with open(path, "wt") as outputf:
            for link in links:
                print(json.dumps(cattrs.unstructure(link)), file=outputf)

    path_links_training = f"{path_out}/links_training.jsonl"
    logger.info("- training links: %s", path_links_training)
    write_links_jsonl(path_links_training, links_training)
    path_links_validation = f"{path_out}/links_validation.jsonl"
    logger.info("- validation links: %s", path_links_validation)
    write_links_jsonl(path_links_validation, links_validation)
    path_links_test = f"{path_out}/links_test.jsonl"
    logger.info("- test links: %s", path_links_test)
    write_links_jsonl(path_links_test, links_test)
    logger.info("... done preparing training")

    return (
        path_out,
        path_links_training,
        path_links_validation,
        path_links_test,
        path_embedding_params,
    )


def run_training(
    path_out: str,
    path_hgnc_json: str,
    path_links_training: str,
    path_hpo_genes_to_phenotype: str,
    path_hpo_obo: str,
    path_embedding_params: str,
    cpus: int,
):
    """Train a model into ``{path_out}/model`` using the prepared training links."""
    logger.info("Running training...")
    path_model = f"{path_out}/model"
    # Thin wrapper around the regular training entry point.
    train_model.run(
        path_model,
        path_hgnc_json,
        path_links_training,
        path_hpo_genes_to_phenotype,
        path_hpo_obo,
        path_embedding_params,
        cpus=cpus,
    )
    logger.info("... done running training")


def run_validation(
    path_out: str,
    path_links_validation: str,
):
    """Evaluate the trained model on the validation links.

    For every single-gene validation link, run prediction on its HPO terms and
    record whether the causative gene ranks within the top 1/5/10/50/100.
    The resulting percentages are printed as JSON to stdout (between
    ``<result>``/``</result>`` log markers).
    """
    logger.info("Running validation...")
    path_model = f"{path_out}/model"
    logger.info("- model: %s", path_model)
    all_to_hgnc, _ = predict.load_hgnc_info(os.path.join(path_model, "hgnc_info.jsonl"))
    graph, model = predict.load_graph_model(path_model)
    _, hpo_id_from_alt, _ = train_model.load_hpo_ontology(os.path.join(path_model, "hp.obo"))

    logger.info("... validation steps ...")
    links_total = 0
    links_top = {
        1: 0,
        5: 0,
        10: 0,
        50: 0,
        100: 0,
    }
    pb = tqdm.tqdm(list(load_clinvar_phenotype_links_jsonl(path_links_validation)), unit=" records")
    for link_validation in pb:
        if len(link_validation.hgnc_ids) != 1:
            # logger.warn is a deprecated alias of logger.warning
            logger.warning(
                "skipping submission %s with %d genes",
                link_validation.scv,
                len(link_validation.hgnc_ids),
            )
            continue
        hgnc_id = link_validation.hgnc_ids[0]

        try:
            # Map alternative/obsolete HPO IDs to their primary ID where known.
            hpo_terms = list(
                sorted(set(hpo_id_from_alt.get(x, x) for x in link_validation.hpo_terms))
            )
            _, sorted_scores = predict.run_prediction(hpo_terms, None, all_to_hgnc, graph, model)
        except predict.NoValidHpoTerms:
            logger.warning("no valid HPO terms in %s (skipped)", link_validation.hpo_terms)
            continue

        links_total += 1
        ranked_genes = list(sorted_scores.keys())
        if hgnc_id in ranked_genes:
            rank = ranked_genes.index(hgnc_id) + 1
        else:
            # Gene absent from the ranking: treat as worst possible rank.
            rank = len(ranked_genes)
        for no in links_top:
            if rank <= no:
                links_top[no] += 1

    if links_total:
        result = {no: 100.0 * count / links_total for no, count in links_top.items()}
    else:
        # All links were skipped; avoid division by zero and report zeros.
        result = {no: 0.0 for no in links_top}
    logger.info("<result>")
    print(json.dumps(result, indent=2))
    logger.info("</result>")

    logger.info("... done running validation")


def perform_parameter_optimization(
    path_hgnc_json: str,
    path_hpo_genes_to_phenotype: str,
    path_hpo_obo: str,
    path_clinvar_phenotype_links: str,
    fraction_links: float,
    seed: int,
    cpus: int,
):
    """Run a single parameter-optimization evaluation.

    Loads and splits the ClinVar phenotype links, writes the splits and
    embedding parameters to a run directory, trains a model on the training
    split, and evaluates it on the validation split.
    """
    links_training, links_validation, links_test = load_links(
        path_clinvar_phenotype_links, fraction_links, seed
    )

    (
        path_out,
        path_links_training,
        path_links_validation,
        path_links_test,
        path_embedding_params,
    ) = prepare_training(links_training, links_validation, links_test, fraction_links, seed)
    # The test split is written to disk for later final evaluation but is
    # intentionally not used during optimization.
    _ = path_links_test

    run_training(
        path_out,
        path_hgnc_json,
        path_links_training,
        path_hpo_genes_to_phenotype,
        path_hpo_obo,
        path_embedding_params,
        cpus=cpus,
    )

    run_validation(
        path_out,
        path_links_validation,
    )
2 changes: 1 addition & 1 deletion cada_prio/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def run_prediction(
genes.append(orig_gene)

# Generate a score for each gene in the knowledge graph
logger.info("Generating scores...")
logger.debug("Generating scores...")
gene_scores = {}
for node_id in graph.nodes():
if translate_legacy_entrez_ids and node_id.startswith("Entrez:"):
Expand Down
6 changes: 5 additions & 1 deletion cada_prio/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import os
import pickle
import shutil
import typing
import warnings

Expand Down Expand Up @@ -232,7 +233,7 @@ def build_and_fit_model(
hpo_gen2phen,
hpo_ontology,
embedding_params: EmbeddingParams,
cpus: int = 1
cpus: int = 1,
):
# create graph edges combining HPO hierarchy and training edges from ClinVar
logger.info("Constructing training graph ...")
Expand Down Expand Up @@ -270,6 +271,7 @@ def build_and_fit_model(
min_count=embedding_params.min_count,
batch_words=embedding_params.batch_words,
seed=embedding_params.seed_fit,
workers=cpus,
)
logger.info("... done computing the embedding")
return training_graph, model, embedding_params
Expand Down Expand Up @@ -352,3 +354,5 @@ def run(
)
# write out graph and model
write_graph_and_model(path_out, hgnc_info, training_graph, embedding_params, model)

shutil.copyfile(path_hpo_obo, f"{path_out}/hp.obo")
2 changes: 2 additions & 0 deletions stubs/logzero.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ class LogFormatter(logging.Formatter):
def format(self, record): ...

logger: logging.Logger

def loglevel(level: int): ...

0 comments on commit 83141c6

Please sign in to comment.