diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index a488265137..6a5a42a9cf 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -524,7 +524,7 @@ def _get_iterator_for_epoch( # TODO: Below is a lazy implementation which discard the final batch regardless # of whether it is a full batch or not. - total_num_itrs = len(self.epoch_batch_sampler) - 1 + total_num_itrs = len(itr) - 1 itr.take(total_num_itrs) logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}")