Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Makes Fairseq compatible with WandB when running in SageMaker so that experiments can be tracked #5316

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
11 changes: 9 additions & 2 deletions fairseq/logging/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,16 +502,23 @@ def log(self, stats, tag=None, step=None):
self._log_to_wandb(stats, tag, step)
self.wrapped_bar.log(stats, tag=tag, step=step)

def is_running_in_sagemaker(self):
return "SM_TRAINING_ENV" in os.environ

def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
self._log_to_wandb(stats, tag, step)
self.wrapped_bar.print(stats, tag=tag, step=step)

def update_config(self, config):
"""Log latest configuration."""
if wandb is not None:
if wandb is not None and not self.is_running_in_sagemaker():
wandb.config.update(config)
self.wrapped_bar.update_config(config)
self.wrapped_bar.update_config(config)
else:
print(
"Running in AWS SageMaker , Config updated from environment variables"
)

def _log_to_wandb(self, stats, tag=None, step=None):
if wandb is None:
Expand Down