Skip to content

Commit

Permalink
feat: embedding parameters can be provided via CLI and contains seeds (
Browse files Browse the repository at this point in the history
  • Loading branch information
holtgrewe committed Sep 14, 2023
1 parent 3aace31 commit bbd5d86
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
5 changes: 5 additions & 0 deletions cada_prio/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def cli(ctx: click.Context, verbose: bool):
required=True,
)
@click.option("--path-hpo-obo", type=str, help="path HPO OBO file", required=True)
@click.option(
"--path-embedding-params", type=str, help="optional path to JSON file with embedding parameters"
)
@click.option("--cpus", type=int, help="number of CPUs to use", default=1)
@click.pass_context
def cli_train_model(
Expand All @@ -39,6 +42,7 @@ def cli_train_model(
path_gene_hpo_links: str,
path_hpo_genes_to_phenotype: str,
path_hpo_obo: str,
path_embedding_params: typing.Optional[str],
cpus: int,
):
"""train model"""
Expand All @@ -49,6 +53,7 @@ def cli_train_model(
path_gene_hpo_links,
path_hpo_genes_to_phenotype,
path_hpo_obo,
path_embedding_params,
cpus,
)

Expand Down
37 changes: 33 additions & 4 deletions cada_prio/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,20 @@ class EmbeddingParams:
min_count: int = 1
#: Set the batch_words in the fitting
batch_words: int = 4
#: RNG seed for embedding
seed_embedding: int = 1
#: RNG seed for fitting
seed_fit: int = 1


def build_and_fit_model(*, clinvar_gen2phen, hpo_gen2phen, hpo_ontology, cpus: int = 1):
def build_and_fit_model(
*,
clinvar_gen2phen,
hpo_gen2phen,
hpo_ontology,
embedding_params: EmbeddingParams,
cpus: int = 1
):
# create graph edges combining HPO hierarchy and training edges from ClinVar
logger.info("Constructing training graph ...")
logger.info("- building edges ...")
Expand All @@ -233,16 +244,16 @@ def build_and_fit_model(*, clinvar_gen2phen, hpo_gen2phen, hpo_ontology, cpus: i
yield_gene2phen_edges(clinvar_gen2phen),
)
)
with open("__trainign_edges.json", "wt") as outputf:
print(json.dumps(cattrs.unstructure(training_edges), indent=2), file=outputf)
logger.info("- graph construction")
training_graph = nx.Graph()
training_graph.add_edges_from(training_edges)
logger.info("... done constructing training graph with %s edges", len(training_edges))

logger.info("Computing the embedding / model fit...")
logger.info("- embedding graph")
embedding_params = EmbeddingParams()
logger.info(
"- using parameters:\n%s", json.dumps(cattrs.unstructure(embedding_params), indent=2)
)
embedding = node2vec.Node2Vec(
training_graph,
dimensions=embedding_params.dimensions,
Expand All @@ -251,12 +262,14 @@ def build_and_fit_model(*, clinvar_gen2phen, hpo_gen2phen, hpo_ontology, cpus: i
p=embedding_params.p,
q=embedding_params.q,
workers=cpus,
seed=embedding_params.seed_embedding,
)
logger.info("- fitting model")
model = embedding.fit(
window=embedding_params.window,
min_count=embedding_params.min_count,
batch_words=embedding_params.batch_words,
seed=embedding_params.seed_fit,
)
logger.info("... done computing the embedding")
return training_graph, model, embedding_params
Expand Down Expand Up @@ -299,15 +312,30 @@ def write_graph_and_model(
logger.info("... done saving embedding to")


def load_embedding_params(path_embedding_params: typing.Optional[str]):
if path_embedding_params:
logger.info("Loading embedding parameters from %s...", path_embedding_params)
with open(path_embedding_params, "rt") as inputf:
embedding_params_dict = json.load(inputf)
logger.info("... done loading embedding parameters")
else:
logger.info("Using default embedding parameters")
embedding_params_dict = {}

return cattrs.structure(embedding_params_dict, EmbeddingParams)


def run(
path_out: str,
path_hgnc_json: str,
path_gene_hpo_links: str,
path_hpo_genes_to_phenotype: str,
path_hpo_obo: str,
path_embedding_params: typing.Optional[str] = None,
cpus: int = 1,
):
# load all data
embedding_params = load_embedding_params(path_embedding_params)
ncbi_to_hgnc, hgnc_info = load_hgnc_info(path_hgnc_json)
clinvar_gen2phen = load_clinvar_gen2phen(path_gene_hpo_links)
hpo_gen2phen = load_hpo_gen2phen(path_hpo_genes_to_phenotype, ncbi_to_hgnc)
Expand All @@ -319,6 +347,7 @@ def run(
clinvar_gen2phen=clinvar_gen2phen,
hpo_gen2phen=hpo_gen2phen,
hpo_ontology=hpo_ontology,
embedding_params=embedding_params,
cpus=cpus,
)
# write out graph and model
Expand Down

0 comments on commit bbd5d86

Please sign in to comment.