Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

formatted python code #1964

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
413 changes: 216 additions & 197 deletions ann/src/main/python/dataflow/faiss_index_bq_dataset.py

Large diffs are not rendered by default.

233 changes: 122 additions & 111 deletions pushservice/src/main/python/models/heavy_ranking/deep_norm.py
@@ -1,136 +1,147 @@
"""
Training job for the heavy ranker of the push notification service.
"""
from datetime import datetime
import json
import os
from datetime import datetime

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import logging

import twml

from ..libs.metric_fn_utils import flip_disliked_labels, get_metric_fn
from ..libs.model_utils import read_config
from ..libs.warm_start_utils import get_feature_list_for_heavy_ranking, warm_start_checkpoint
from ..libs.warm_start_utils import (
get_feature_list_for_heavy_ranking,
warm_start_checkpoint,
)
from .features import get_feature_config
from .model_pools import ALL_MODELS
from .params import load_graph_params
from .run_args import get_training_arg_parser

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import logging


def main() -> None:
args, _ = get_training_arg_parser().parse_known_args()
logging.info(f"Parsed args: {args}")

params = load_graph_params(args)
logging.info(f"Loaded graph params: {params}")

param_file = os.path.join(args.save_dir, "params.json")
logging.info(f"Saving graph params to: {param_file}")
with tf.io.gfile.GFile(param_file, mode="w") as file:
json.dump(params.json(), file, ensure_ascii=False, indent=4)

logging.info(f"Get Feature Config: {args.feature_list}")
feature_list = read_config(args.feature_list).items()
feature_config = get_feature_config(
data_spec_path=args.data_spec,
params=params,
feature_list_provided=feature_list,
)
feature_list_path = args.feature_list

warm_start_from = args.warm_start_from
if args.warm_start_base_dir:
logging.info(f"Get warm started model from: {args.warm_start_base_dir}.")
args, _ = get_training_arg_parser().parse_known_args()
logging.info(f"Parsed args: {args}")

params = load_graph_params(args)
logging.info(f"Loaded graph params: {params}")

param_file = os.path.join(args.save_dir, "params.json")
logging.info(f"Saving graph params to: {param_file}")
with tf.io.gfile.GFile(param_file, mode="w") as file:
json.dump(params.json(), file, ensure_ascii=False, indent=4)

logging.info(f"Get Feature Config: {args.feature_list}")
feature_list = read_config(args.feature_list).items()
feature_config = get_feature_config(
data_spec_path=args.data_spec,
params=params,
feature_list_provided=feature_list,
)
feature_list_path = args.feature_list

warm_start_from = args.warm_start_from
if args.warm_start_base_dir:
logging.info(f"Get warm started model from: {args.warm_start_base_dir}.")

continuous_binary_feat_list_save_path = os.path.join(
args.warm_start_base_dir, "continuous_binary_feat_list.json"
)
warm_start_folder = os.path.join(args.warm_start_base_dir, "best_checkpoint")
job_name = os.path.basename(args.save_dir)
ws_output_ckpt_folder = os.path.join(
args.warm_start_base_dir, f"warm_start_for_{job_name}"
)
if tf.io.gfile.exists(ws_output_ckpt_folder):
tf.io.gfile.rmtree(ws_output_ckpt_folder)

tf.io.gfile.mkdir(ws_output_ckpt_folder)

warm_start_from = warm_start_checkpoint(
warm_start_folder,
continuous_binary_feat_list_save_path,
feature_list_path,
args.data_spec,
ws_output_ckpt_folder,
)
logging.info(f"Created warm_start_from_ckpt {warm_start_from}.")

logging.info("Build Trainer.")
metric_fn = get_metric_fn(
"OONC_Engagement" if len(params.tasks) == 2 else "OONC", False
)

trainer = twml.trainers.DataRecordTrainer(
name="magic_recs",
params=args,
build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(
*args
),
save_dir=args.save_dir,
run_config=None,
feature_config=feature_config,
metric_fn=flip_disliked_labels(metric_fn),
warm_start_from=warm_start_from,
)

logging.info("Build train and eval input functions.")
train_input_fn = trainer.get_train_input_fn(shuffle=True)
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)

learn = trainer.learn
if args.distributed or args.num_workers is not None:
learn = trainer.train_and_evaluate

if not args.directly_export_best:
logging.info("Starting training")
start = datetime.now()
learn(
early_stop_minimize=False,
early_stop_metric="pr_auc_unweighted_OONC",
early_stop_patience=args.early_stop_patience,
early_stop_tolerance=args.early_stop_tolerance,
eval_input_fn=eval_input_fn,
train_input_fn=train_input_fn,
)
logging.info(f"Total training time: {datetime.now() - start}")
else:
logging.info("Directly exporting the model")

if not args.export_dir:
args.export_dir = os.path.join(args.save_dir, "exported_models")

logging.info(f"Exporting the model to {args.export_dir}.")
start = datetime.now()
twml.contrib.export.export_fn.export_all_models(
trainer=trainer,
export_dir=args.export_dir,
parse_fn=feature_config.get_parse_fn(),
serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
)

logging.info(f"Total model export time: {datetime.now() - start}")
logging.info(f"The MLP directory is: {args.save_dir}")

continuous_binary_feat_list_save_path = os.path.join(
args.warm_start_base_dir, "continuous_binary_feat_list.json"
args.save_dir, "continuous_binary_feat_list.json"
)
warm_start_folder = os.path.join(args.warm_start_base_dir, "best_checkpoint")
job_name = os.path.basename(args.save_dir)
ws_output_ckpt_folder = os.path.join(args.warm_start_base_dir, f"warm_start_for_{job_name}")
if tf.io.gfile.exists(ws_output_ckpt_folder):
tf.io.gfile.rmtree(ws_output_ckpt_folder)

tf.io.gfile.mkdir(ws_output_ckpt_folder)

warm_start_from = warm_start_checkpoint(
warm_start_folder,
continuous_binary_feat_list_save_path,
feature_list_path,
args.data_spec,
ws_output_ckpt_folder,
logging.info(
f"Saving the list of continuous and binary features to {continuous_binary_feat_list_save_path}."
)
logging.info(f"Created warm_start_from_ckpt {warm_start_from}.")

logging.info("Build Trainer.")
metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)

trainer = twml.trainers.DataRecordTrainer(
name="magic_recs",
params=args,
build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(*args),
save_dir=args.save_dir,
run_config=None,
feature_config=feature_config,
metric_fn=flip_disliked_labels(metric_fn),
warm_start_from=warm_start_from,
)

logging.info("Build train and eval input functions.")
train_input_fn = trainer.get_train_input_fn(shuffle=True)
eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)

learn = trainer.learn
if args.distributed or args.num_workers is not None:
learn = trainer.train_and_evaluate

if not args.directly_export_best:
logging.info("Starting training")
start = datetime.now()
learn(
early_stop_minimize=False,
early_stop_metric="pr_auc_unweighted_OONC",
early_stop_patience=args.early_stop_patience,
early_stop_tolerance=args.early_stop_tolerance,
eval_input_fn=eval_input_fn,
train_input_fn=train_input_fn,
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
feature_list_path, args.data_spec
)
twml.util.write_file(
continuous_binary_feat_list_save_path,
continuous_binary_feat_list,
encode="json",
)
logging.info(f"Total training time: {datetime.now() - start}")
else:
logging.info("Directly exporting the model")

if not args.export_dir:
args.export_dir = os.path.join(args.save_dir, "exported_models")

logging.info(f"Exporting the model to {args.export_dir}.")
start = datetime.now()
twml.contrib.export.export_fn.export_all_models(
trainer=trainer,
export_dir=args.export_dir,
parse_fn=feature_config.get_parse_fn(),
serving_input_receiver_fn=feature_config.get_serving_input_receiver_fn(),
export_output_fn=twml.export_output_fns.batch_prediction_continuous_output_fn,
)

logging.info(f"Total model export time: {datetime.now() - start}")
logging.info(f"The MLP directory is: {args.save_dir}")

continuous_binary_feat_list_save_path = os.path.join(
args.save_dir, "continuous_binary_feat_list.json"
)
logging.info(
f"Saving the list of continuous and binary features to {continuous_binary_feat_list_save_path}."
)
continuous_binary_feat_list = get_feature_list_for_heavy_ranking(
feature_list_path, args.data_spec
)
twml.util.write_file(
continuous_binary_feat_list_save_path, continuous_binary_feat_list, encode="json"
)


if __name__ == "__main__":
main()
logging.info("Done.")
main()
logging.info("Done.")
84 changes: 45 additions & 39 deletions pushservice/src/main/python/models/heavy_ranking/eval.py
Expand Up @@ -4,6 +4,8 @@

from datetime import datetime

from tensorflow.compat.v1 import logging

import twml

from ..libs.metric_fn_utils import get_metric_fn
Expand All @@ -13,47 +15,51 @@
from .params import load_graph_params
from .run_args import get_eval_arg_parser

from tensorflow.compat.v1 import logging


def main():
args, _ = get_eval_arg_parser().parse_known_args()
logging.info(f"Parsed args: {args}")

params = load_graph_params(args)
logging.info(f"Loaded graph params: {params}")

logging.info(f"Get Feature Config: {args.feature_list}")
feature_list = read_config(args.feature_list).items()
feature_config = get_feature_config(
data_spec_path=args.data_spec,
params=params,
feature_list_provided=feature_list,
)

logging.info("Build DataRecordTrainer.")
metric_fn = get_metric_fn("OONC_Engagement" if len(params.tasks) == 2 else "OONC", False)

trainer = twml.trainers.DataRecordTrainer(
name="magic_recs",
params=args,
build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(*args),
save_dir=args.save_dir,
run_config=None,
feature_config=feature_config,
metric_fn=metric_fn,
)

logging.info("Run the evaluation.")
start = datetime.now()
trainer._estimator.evaluate(
input_fn=trainer.get_eval_input_fn(repeat=False, shuffle=False),
steps=None if (args.eval_steps is not None and args.eval_steps < 0) else args.eval_steps,
checkpoint_path=args.eval_checkpoint,
)
logging.info(f"Evaluating time: {datetime.now() - start}.")
args, _ = get_eval_arg_parser().parse_known_args()
logging.info(f"Parsed args: {args}")

params = load_graph_params(args)
logging.info(f"Loaded graph params: {params}")

logging.info(f"Get Feature Config: {args.feature_list}")
feature_list = read_config(args.feature_list).items()
feature_config = get_feature_config(
data_spec_path=args.data_spec,
params=params,
feature_list_provided=feature_list,
)

logging.info("Build DataRecordTrainer.")
metric_fn = get_metric_fn(
"OONC_Engagement" if len(params.tasks) == 2 else "OONC", False
)

trainer = twml.trainers.DataRecordTrainer(
name="magic_recs",
params=args,
build_graph_fn=lambda *args: ALL_MODELS[params.model.name](params=params)(
*args
),
save_dir=args.save_dir,
run_config=None,
feature_config=feature_config,
metric_fn=metric_fn,
)

logging.info("Run the evaluation.")
start = datetime.now()
trainer._estimator.evaluate(
input_fn=trainer.get_eval_input_fn(repeat=False, shuffle=False),
steps=None
if (args.eval_steps is not None and args.eval_steps < 0)
else args.eval_steps,
checkpoint_path=args.eval_checkpoint,
)
logging.info(f"Evaluating time: {datetime.now() - start}.")


if __name__ == "__main__":
main()
logging.info("Job done.")
main()
logging.info("Job done.")