From c7c478b92fe135838a2b9ec8341495c732a92401 Mon Sep 17 00:00:00 2001 From: Junteng Jia Date: Mon, 9 Oct 2023 14:13:06 -0700 Subject: [PATCH] fix iterator when loading from checkpoint (#5344) Co-authored-by: Junteng Jia --- fairseq/data/iterators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index a488265137..6a5a42a9cf 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -524,7 +524,7 @@ def _get_iterator_for_epoch( # TODO: Below is a lazy implementation which discard the final batch regardless # of whether it is a full batch or not. - total_num_itrs = len(self.epoch_batch_sampler) - 1 + total_num_itrs = len(itr) - 1 itr.take(total_num_itrs) logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}")