
SSL w/o torchaudio dependency #5537

Open · wants to merge 69 commits into base: master

Conversation

@wanchichen (Contributor) commented Nov 7, 2023

What?

This PR enables HuBERT pre-training without the torchaudio dependency, allowing more customization and the use of different ESPnet components. It also introduces some tricks to better support large-scale training.

Features:

  • Flash Attention (by @pyf98)
  • Activation checkpointing in E-Branchformer (see the first sketch after this list)
  • Convolutional feature extractor as frontend
  • Convolutional positional embeddings (see the second sketch after this list)
  • Efficient batch sampler for large-scale training
  • Efficient multi-shard iterator for large-scale training
  • WavLM-style noise augmentation
  • More efficient distributed training
  • Recommended tunable DDP args for multi-node training
  • Cross-entropy loss for HuBERT SSL
  • Intermediate losses for HuBERT SSL
  • HuBERT SSL using filterbank input features
  • More detailed GPU memory reporting

So far, only HuBERT with Transformer and E-Branchformer encoders is supported.
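
As a reference for the activation checkpointing item above, here is a minimal, self-contained sketch of the technique with torch.utils.checkpoint. The encoder class, layer sizes, and head count are illustrative placeholders, not the PR's E-Branchformer code:

```python
import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedEncoder(torch.nn.Module):
    """Illustrative encoder stack with activation checkpointing."""

    def __init__(self, num_layers: int = 12, d_model: int = 256, nhead: int = 4):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            torch.nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
            for _ in range(num_layers)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            # The layer's intermediate activations are not stored during the
            # forward pass; they are recomputed on the fly during backward,
            # trading extra compute for a large cut in peak memory.
            x = checkpoint(layer, x, use_reentrant=False)
        return x


# Peak activation memory now scales with one layer instead of all twelve.
enc = CheckpointedEncoder()
out = enc(torch.randn(2, 100, 256, requires_grad=True))
out.sum().backward()
```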

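For the convolutional positional embeddings item, a minimal sketch of the wav2vec 2.0-style approach: a grouped temporal convolution whose output is added back to the features, so position is encoded by the convolution's receptive field. Kernel size, group count, and model width below are illustrative defaults, not the PR's configuration:

```python
import torch


class ConvPositionalEmbedding(torch.nn.Module):
    """Grouped temporal convolution used as a relative positional signal."""

    def __init__(self, d_model: int = 512, kernel_size: int = 128, groups: int = 16):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            d_model, d_model, kernel_size, padding=kernel_size // 2, groups=groups
        )
        # An even kernel with symmetric padding yields one extra frame.
        self.trim = 1 if kernel_size % 2 == 0 else 0
        self.act = torch.nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, d_model)
        pos = self.conv(x.transpose(1, 2))
        if self.trim:
            pos = pos[:, :, : -self.trim]
        return x + self.act(pos).transpose(1, 2)


x = torch.randn(2, 50, 512)
assert ConvPositionalEmbedding()(x).shape == x.shape
```
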
To do:

  • Fbank training configs
  • Conformer and Branchformer encoders
  • Check in hubert.sh which implementation to use

mergify bot commented Nov 7, 2023

This pull request is now in conflict :(

@sw005320 (Contributor) commented Jan 4, 2024

@wanchichen, can you restart this PR?

mergify bot commented Feb 6, 2024

This pull request is now in conflict :(

@sw005320 (Contributor) commented

@wanchichen, let’s finish this PR.
There are a lot of conflicts now. So, please resolve them.

@simpleoier, please also review this PR.

espnet2/train/trainer.py
@@ -512,11 +520,6 @@ def train_one_epoch(
        ):
            assert isinstance(batch, dict), type(batch)

-            if distributed:
@wanchichen (Contributor, Author) replied:

I ran several experiments (ASR, SSL) and found that this block was unnecessary. But we may need to experiment with other batch samplers to make sure.

espnet2/train/trainer.py
@@ -709,7 +731,7 @@ def train_one_epoch(
                for iopt, optimizer in enumerate(optimizers):
                    if optim_idx is not None and iopt != optim_idx:
                        continue
-                    optimizer.zero_grad()
+                    optimizer.zero_grad(set_to_none=True)
@wanchichen (Contributor, Author) replied:

This is now the setting recommended by PyTorch; it is both faster and more memory-efficient. However, it may also slightly affect the learning curve.
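
For illustration, a minimal self-contained example of the new behavior (plain PyTorch, not ESPnet code):

```python
import torch

model = torch.nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(4, 8)).sum().backward()
opt.step()

# set_to_none=True frees the .grad tensors instead of overwriting them
# with zeros: it skips one memset per parameter and releases the memory
# until the next backward() lazily re-allocates the gradients.
opt.zero_grad(set_to_none=True)
assert all(p.grad is None for p in model.parameters())
```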

(Outdated review thread on espnet2/train/trainer.py, resolved)
@simpleoier (Collaborator) left a comment

Thanks! I made some comments.
One question I have is about the espnet2/ssl/mask module. What is the benefit of a new mask module? If it is only used in HuBERT models during pre-training, a single function in the HuBERT model would be enough.
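
For context, the module in question implements HuBERT-style span masking of frames. A generic sketch of the idea follows; the function name and defaults are illustrative, not the actual espnet2/ssl/mask API:

```python
import torch


def compute_span_mask(
    batch: int, frames: int, mask_prob: float = 0.08, mask_length: int = 10
) -> torch.Tensor:
    """Return a (batch, frames) boolean mask of random contiguous spans.

    Roughly mask_prob of all frames end up covered; each masked region is
    a span of mask_length frames, as in HuBERT/wav2vec 2.0 pre-training.
    """
    mask = torch.zeros(batch, frames, dtype=torch.bool)
    num_spans = max(1, int(mask_prob * frames / mask_length))
    for b in range(batch):
        # Sample distinct span starts, then expand each to a full span.
        starts = torch.randperm(frames - mask_length)[:num_spans]
        for s in starts.tolist():
            mask[b, s : s + mask_length] = True
    return mask


# During pre-training, the prediction loss is computed on masked frames only.
mask = compute_span_mask(batch=4, frames=400)
```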

Outdated review threads (resolved) on:
  • egs2/ml_superb/asr1/local/get_ssl_weights.py
  • espnet/nets/pytorch_backend/transformer/attention.py
  • espnet/nets/pytorch_backend/transformer/embedding.py
  • espnet2/asr/encoder/e_branchformer_encoder.py
  • espnet2/s2t/espnet_model.py
  • espnet2/ssl/loss/hubert_loss_ce.py
  • espnet2/tasks/abs_task.py
  • espnet2/train/preprocessor.py
  • espnet2/train/trainer.py
mergify bot added the conflicts label on May 8, 2024
mergify bot removed the conflicts label on May 8, 2024