Skip to content

Commit

Permalink
feat: re-useable implementation of "tune train-eval" (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
holtgrewe committed Sep 15, 2023
1 parent 84c54e8 commit c80c4bf
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/model-out
/local_data

*~
.*.sw?
Expand Down
56 changes: 45 additions & 11 deletions cada_prio/cli.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""Console script for CADA"""

import logging
import os
import sys
import typing

import click
import logzero

from cada_prio import _version, inspection, param_opt, predict, train_model

# Raise tqdm's minimum update interval to 5 seconds (i.e. update less often) if stdout is not a TTY.
if not sys.stdout.isatty():
os.environ["TQDM_MININTERVAL"] = "5"


@click.group()
@click.version_option(_version.__version__)
Expand All @@ -23,7 +29,7 @@ def cli(ctx: click.Context, verbose: bool):
logzero.loglevel(logging.INFO)


@cli.command("train-model")
@cli.command("train")
@click.argument("path_out", type=str)
@click.option("--path-hgnc-json", type=str, help="path to HGNC JSON", required=True)
@click.option(
Expand All @@ -41,7 +47,7 @@ def cli(ctx: click.Context, verbose: bool):
)
@click.option("--cpus", type=int, help="number of CPUs to use", default=1)
@click.pass_context
def cli_train_model(
def cli_train(
ctx: click.Context,
path_out: str,
path_hgnc_json: str,
Expand Down Expand Up @@ -105,7 +111,12 @@ def cli_predict(
ctx.exit(1)


@cli.command("dump-graph")
@cli.group("utils")
def cli_utils():
"""utilities"""


@cli_utils.command("dump-graph")
@click.argument("path_graph", type=str)
@click.argument("path_hgnc_info", type=str)
@click.pass_context
Expand All @@ -119,7 +130,13 @@ def cli_dump_graph(
inspection.dump_graph(path_graph, path_hgnc_info)


@cli.command("param-opt")
@cli.group("tune")
def cli_tune():
"""hyperparameter tuning"""


@cli_tune.command("train-eval")
@click.argument("path_out", type=str)
@click.option("--path-hgnc-json", type=str, help="path to HGNC JSON", required=True)
@click.option(
"--path-hpo-genes-to-phenotype",
Expand All @@ -137,35 +154,52 @@ def cli_dump_graph(
@click.option(
"--fraction-links",
type=float,
help="fraction of links to add to the graph",
required=True,
help="fraction of links to add to the graph (conflicts with --path-validation-links)",
)
@click.option(
"--path-validation-links",
type=str,
help="path to validation links JSONL (conflicts with --fraction-links)",
)
@click.option(
"--path-embedding-params", type=str, help="path to JSON file with embedding params; optional"
)
@click.option(
"--seed",
type=int,
help="seed for random number generator",
default=1,
)
@click.option("--cpus", type=int, help="number of CPUs to use", default=1)
@click.pass_context
def cli_param_opt(
ctx: click.Context,
path_out: str,
path_hgnc_json: str,
path_hpo_genes_to_phenotype: str,
path_hpo_obo: str,
path_clinvar_phenotype_links: str,
fraction_links: float,
seed: int,
fraction_links: typing.Optional[float],
path_validation_links: typing.Optional[str],
path_embedding_params: typing.Optional[str],
seed: typing.Optional[int],
cpus: int,
):
"""dump graph edges for debugging"""
"""train and evaluate model for one set of parameters"""
if bool(fraction_links) == bool(path_validation_links):
raise click.UsageError(
"exactly one of --fraction-links and --path-validation-links must be given"
)

ctx.ensure_object(dict)
param_opt.perform_parameter_optimization(
param_opt.train_and_validate(
path_out=path_out,
path_hgnc_json=path_hgnc_json,
path_hpo_genes_to_phenotype=path_hpo_genes_to_phenotype,
path_hpo_obo=path_hpo_obo,
path_clinvar_phenotype_links=path_clinvar_phenotype_links,
fraction_links=fraction_links,
path_validation_links=path_validation_links,
path_embedding_params=path_embedding_params,
seed=seed,
cpus=cpus,
)
160 changes: 119 additions & 41 deletions cada_prio/param_opt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Code for parameter optimization."""

from datetime import timedelta
import gzip
import hashlib
import json
import os
import random
import time
import typing

import cattrs
Expand All @@ -13,6 +15,9 @@

from cada_prio import predict, train_model

#: Default seed value.
DEFAULT_SEED = 1


def load_clinvar_phenotype_links_jsonl(path) -> typing.Iterable[train_model.GenePhenotypeRecord]:
if path.endswith(".gz"):
Expand All @@ -24,57 +29,105 @@ def load_clinvar_phenotype_links_jsonl(path) -> typing.Iterable[train_model.Gene
yield cattrs.structure(json.loads(line), train_model.GenePhenotypeRecord)


def load_links(path_clinvar_phenotype_links: str, fraction_links: float, seed: int):
rng = random.Random(seed)
def make_links_unique_by_submitter(
links: typing.Iterable[train_model.GenePhenotypeRecord],
) -> typing.List[train_model.GenePhenotypeRecord]:
links_dict = {
f"{link.submitter}/{','.join(link.hgnc_ids)}/{','.join(link.hpo_terms)}": link
for link in links
if len(link.hgnc_ids) == 1
}
return list(links_dict.values())


def load_links(
path_clinvar_phenotype_links: str,
fraction_links: typing.Optional[float],
path_validation_links: typing.Optional[str],
seed: typing.Optional[int],
):
rng = random.Random(seed or DEFAULT_SEED)

logger.info("Loading phenotype links...")
logger.info("- load JSONL")
orig_phenotype_links = list(load_clinvar_phenotype_links_jsonl(path_clinvar_phenotype_links))
logger.info("- make unique by submitter")
phenotype_links_dict = {
f"{link.submitter}/{','.join(link.hgnc_ids)}/{','.join(link.hpo_terms)}": link
for link in orig_phenotype_links
if len(link.hgnc_ids) == 1
}
phenotype_links = list(phenotype_links_dict.values())
phenotype_links = make_links_unique_by_submitter(orig_phenotype_links)
logger.info("- randomizing")
rng.shuffle(phenotype_links)
logger.info("... done loading %s links", len(phenotype_links))

logger.info("Computing counts...")
n_links_total = len(phenotype_links)
n_links_used = int(fraction_links * n_links_total)
n_links_training = int(n_links_used * 0.6)
n_links_validation = (n_links_used - n_links_training) // 2
n_links_test = n_links_validation
links_training = phenotype_links[:n_links_training]
links_validation = phenotype_links[n_links_training : n_links_training + n_links_validation]
links_test = phenotype_links[
n_links_training + n_links_validation : n_links_training + n_links_validation + n_links_test
]
logger.info("- total: % 6d", len(phenotype_links))
logger.info("- used: % 6d", n_links_used)
logger.info("- training: % 6d", len(links_training))
logger.info("- validation: % 6d", len(links_validation))
logger.info("- test: % 6d", len(links_test))
logger.info("... done computing counts")
if fraction_links is None:
assert (
path_validation_links is not None
), "must give fraction_links or path_validation_links"
links_training = phenotype_links
orig_links_validation = list(load_clinvar_phenotype_links_jsonl(path_validation_links))
links_validation = make_links_unique_by_submitter(orig_links_validation)
links_test: typing.List[train_model.GenePhenotypeRecord] = []
logger.info("Counts in explicit validation set...")
logger.info("- training: % 6d", len(links_training))
logger.info(
"- validation: % 6d (non-unique: %d)", len(links_validation), len(orig_links_validation)
)
logger.info("- test: % 6d", len(links_test))
logger.info("... that's all")
else:
logger.info("Computing counts...")
n_links_total = len(phenotype_links)
n_links_used = int(fraction_links * n_links_total)
n_links_training = int(n_links_used * 0.6)
n_links_validation = (n_links_used - n_links_training) // 2
n_links_test = n_links_validation
links_training = phenotype_links[:n_links_training]
links_validation = phenotype_links[n_links_training : n_links_training + n_links_validation]
links_test = phenotype_links[
n_links_training
+ n_links_validation : n_links_training
+ n_links_validation
+ n_links_test
]
logger.info("- total: % 6d", len(phenotype_links))
logger.info("- used: % 6d", n_links_used)
logger.info("- training: % 6d", len(links_training))
logger.info("- validation: % 6d", len(links_validation))
logger.info("- test: % 6d", len(links_test))
logger.info("... done computing counts")

return links_training, links_validation, links_test


def prepare_training(
path_out: str,
links_training: typing.List[train_model.GenePhenotypeRecord],
links_validation: typing.List[train_model.GenePhenotypeRecord],
links_test: typing.List[train_model.GenePhenotypeRecord],
fraction_links: float,
seed: int,
fraction_links: typing.Optional[float],
path_validation_links: typing.Optional[str],
path_embedding_params: typing.Optional[str],
seed: typing.Optional[int],
):
if seed is None:
seed = DEFAULT_SEED

logger.info("Preparing training...")
embedding_params = train_model.EmbeddingParams(seed_embedding=seed + 23, seed_fit=seed + 42)
if path_embedding_params:
logger.info("- loading embedding params from %s", path_embedding_params)
with open(path_embedding_params, "rt") as inputf:
embedding_params = cattrs.structure(json.load(inputf), train_model.EmbeddingParams)
else:
logger.info("- using default embedding params params")
embedding_params = train_model.EmbeddingParams(seed_embedding=seed + 23, seed_fit=seed + 42)
params_name = hashlib.md5(
json.dumps(cattrs.unstructure(embedding_params)).encode("utf-8")
).hexdigest()
path_out = f"param_opt.d/{int(fraction_links * 100)}.{seed}.{params_name}"

if fraction_links is None:
path_out = f"{path_out}/explicit-validation.{seed}.{params_name}"
_ = path_validation_links
else:
path_out = f"{path_out}/{int(fraction_links * 100)}.{seed}.{params_name}"

os.makedirs(path_out, exist_ok=True)
logger.info("- output path: %s", path_out)
logger.info(
Expand Down Expand Up @@ -115,9 +168,10 @@ def run_training(
path_links_training: str,
path_hpo_genes_to_phenotype: str,
path_hpo_obo: str,
path_embedding_params: str,
path_embedding_params: typing.Optional[str],
cpus: int,
):
start = time.time()
logger.info("Running training...")
train_model.run(
f"{path_out}/model",
Expand All @@ -128,13 +182,19 @@ def run_training(
path_embedding_params,
cpus=cpus,
)
logger.info("... done running training")
elapsed = time.time() - start
logger.info("... done running training in %s", timedelta(seconds=elapsed))


def run_validation(
path_out: str,
path_links_validation: str,
):
) -> typing.Dict[int, float]:
"""Run validation on the links from the path.
:return: dictionary with top N -> percentage of links in top N
"""
start = time.time()
logger.info("Running validation...")
path_model = f"{path_out}/model"
logger.info("- model: %s", path_model)
Expand Down Expand Up @@ -187,22 +247,31 @@ def run_validation(
print(json.dumps(result, indent=2))
logger.info("</result>")

logger.info("... done running validation")
elapsed = time.time() - start
logger.info("... done running validation in %s", timedelta(seconds=elapsed))
return result


def perform_parameter_optimization(
def train_and_validate(
*,
path_out: str,
path_hgnc_json: str,
path_hpo_genes_to_phenotype: str,
path_hpo_obo: str,
path_clinvar_phenotype_links: str,
fraction_links: float,
seed: int,
fraction_links: typing.Optional[float],
path_validation_links: typing.Optional[str],
path_embedding_params: typing.Optional[str],
seed: typing.Optional[int],
cpus: int,
):
"""Simulate cases based on the dataset file."""
) -> typing.Dict[int, float]:
"""Train model and run validation on the links from the path.
:return: dictionary with top N -> percentage of links in top N
"""

links_training, links_validation, links_test = load_links(
path_clinvar_phenotype_links, fraction_links, seed
path_clinvar_phenotype_links, fraction_links, path_validation_links, seed
)

(
Expand All @@ -211,7 +280,16 @@ def perform_parameter_optimization(
path_links_validation,
path_links_test,
path_embedding_params,
) = prepare_training(links_training, links_validation, links_test, fraction_links, seed)
) = prepare_training(
path_out,
links_training,
links_validation,
links_test,
fraction_links,
path_validation_links,
path_embedding_params,
seed,
)
_ = path_links_test

run_training(
Expand All @@ -224,7 +302,7 @@ def perform_parameter_optimization(
cpus=cpus,
)

run_validation(
return run_validation(
path_out,
path_links_validation,
)

0 comments on commit c80c4bf

Please sign in to comment.