
Run validation every N batches #19562

Closed
mpetteno opened this issue Apr 19, 2024 · 8 comments
Assignees
Labels
type:support User is asking for help / asking an implementation question. Stackoverflow would be better suited.

Comments

@mpetteno

Hi everyone.

This is not really an issue but more of a help request; I'm sorry if this is not the right place to ask.
I'm trying to train a model over an entire dataset loaded with the TensorFlow tf.data API. Since the dataset contains a lot of samples, my idea is to run the training for only one epoch but launch validation every N batches, so that I can save the best model according to the validation loss. Is this possible in Keras?

Looking at the code in fit, it seems to me that evaluation is run only at the end of each epoch, and even if I set steps_per_epoch, the next epoch the elements will be taken from the beginning again, so I'm not using the whole dataset. Is that correct, or am I missing something?

I tried to build a custom callback to run the validation, and it works, but then I have to manually handle the history and the other callbacks that run only at the end of each epoch.

Thanks for your help.

@fchollet
Member

Looking at the code in fit, it seems to me that evaluation is run only at the end of each epoch, and even if I set steps_per_epoch, the next epoch the elements will be taken from the beginning again, so I'm not using the whole dataset. Is that correct, or am I missing something?

Can you check that in practice with an example? I expect that you will continue drawing from the same dataset. It needs to be able to generate steps_per_epoch * epochs batches; you may need to call .repeat() on it to achieve that.
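This behavior is easy to model without TensorFlow: fit keeps pulling batches from the same dataset iterator across epochs, so a finite dataset must supply at least steps_per_epoch * epochs batches unless it is repeated. A pure-Python sketch (a plain iterator standing in for a tf.data dataset, itertools.cycle standing in for .repeat()):

```python
from itertools import cycle, islice

# Stand-in for a finite tf.data dataset: an iterable of 100 batches.
total_batches = 100
ds = iter(range(total_batches))

steps_per_epoch = 25
epochs = 4

# fit() keeps drawing from the SAME iterator across epochs, so
# 4 epochs x 25 steps consume exactly the 100 available batches.
consumed = [list(islice(ds, steps_per_epoch)) for _ in range(epochs)]
assert sum(len(chunk) for chunk in consumed) == total_batches

# Epoch 2 starts at batch 25, not at batch 0 again:
assert consumed[1][0] == 25

# A 5th epoch would find the iterator exhausted; .repeat() (modelled here
# with itertools.cycle) makes steps_per_epoch * epochs > total_batches safe.
repeated = cycle(range(total_batches))
extra = list(islice(repeated, steps_per_epoch * 5))
assert len(extra) == 125
```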

@mpetteno
Author

mpetteno commented Apr 19, 2024

Yes, the point is that I don't want to use .repeat() because my dataset is already big enough. I want to run a single epoch (using the whole training dataset) but launch validation every N training steps and save the model if the validation loss improves. For the last part I know I can use the ModelCheckpoint callback, but I'm not sure whether the built-in fit method can achieve the first requirement.

This is the code for the custom callback that handles this. The main issue is that there are native callbacks that run only at the end of each epoch and not at the batch step (like the History callback).

import keras
# Note: io_utils is a private Keras utility; the import path may differ
# across Keras versions (e.g. keras.src.utils.io_utils in Keras 3).
from keras.src.utils import io_utils


class ValidationEveryNBatches(keras.callbacks.Callback):
    """Runs model.evaluate() on the validation data every N training batches."""

    def __init__(self,
                 validation_data,
                 validation_batch_size,
                 validation_freq=1,
                 validation_steps=None,
                 validation_callbacks=None,
                 verbose='auto'):
        super().__init__()
        self._validation_freq = validation_freq
        self._validation_batch_size = validation_batch_size
        self._validation_steps = validation_steps
        self._validation_callbacks = validation_callbacks
        self._verbose = verbose
        self._val_x, self._val_y, self._val_sample_weight = (
            keras.utils.unpack_x_y_sample_weight(validation_data))
        self._validations_count = 0

    def on_batch_end(self, batch, logs=None):  # alias of on_train_batch_end
        if (batch + 1) % self._validation_freq == 0:
            self._validations_count += 1
            io_utils.print_msg('\n' + '-' * 81)
            io_utils.print_msg(f'Running validation after processing batch {batch + 1}. '
                               f'Total validation runs: {self._validations_count}')
            val_logs = self.model.evaluate(
                x=self._val_x,
                y=self._val_y,
                sample_weight=self._val_sample_weight,
                batch_size=self._validation_batch_size,
                steps=self._validation_steps,
                callbacks=self._validation_callbacks,
                return_dict=True,
                verbose=self._verbose
            )
            io_utils.print_msg('-' * 81)
            # Prefix the metrics so they show up as val_* in the training logs.
            val_logs = {"val_" + name: val for name, val in val_logs.items()}
            if logs:
                logs.update(val_logs)
            # Reset so training metrics are not polluted by the evaluation run.
            self.model.reset_metrics()

@fchollet
Member

Fair enough, doing this in your own callback is a good solution. It lets you customize it to do whatever you want.

@mpetteno
Author

mpetteno commented Apr 19, 2024

But how can I update the history so that it also keeps track of the training loss? At the moment that history is saved only at the end of each epoch, and in my case there is just one epoch.

@fchollet
Member

Just make your callback create & update its own metrics/loss dict?
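The suggestion above can be sketched like this: keep a history dict inside the callback and append per-batch logs to it, mirroring what keras.callbacks.History does per epoch. The class below is a plain-Python sketch of the pattern (in Keras it would subclass keras.callbacks.Callback, and fit would invoke the hook; here a loop simulates that):

```python
class BatchHistory:
    """Records training logs every N batches into its own history dict,
    the way keras.callbacks.History does at the end of each epoch."""

    def __init__(self, record_freq=1):
        self.record_freq = record_freq
        self.history = {}

    def on_train_batch_end(self, batch, logs=None):
        # `batch` is 0-indexed, so batch + 1 is the number of batches seen.
        if (batch + 1) % self.record_freq != 0:
            return
        for name, value in (logs or {}).items():
            self.history.setdefault(name, []).append(value)


# Simulated training loop: pretend fit() calls the hook with per-batch logs.
cb = BatchHistory(record_freq=2)
for step in range(6):
    cb.on_train_batch_end(step, logs={"loss": 1.0 / (step + 1)})

# Records were taken after batches 2, 4 and 6 (1-indexed).
assert len(cb.history["loss"]) == 3
```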

@mpetteno
Author

Ok, I'm not sure I got it, but I'll try that, thanks. I was hoping there was a way to do this using the standard loop in the fit method. In the end, what I want is effectively to run N epochs but read the training dataset from a certain index instead of from the beginning at each epoch.

@SuryanarayanaY SuryanarayanaY added type:support User is asking for help / asking an implementation question. Stackoverflow would be better suited. stat:awaiting response from contributor labels Apr 22, 2024
@mpetteno
Author

mpetteno commented May 3, 2024

The issue can be closed because it is not really an issue. Setting the seed and deterministic=True while loading the dataset with the tf.data API helped me understand how Keras works and achieve the desired result.
Keras supports validation every N batches natively through steps_per_epoch; just be sure to call .repeat() on the dataset if necessary, as stated by @fchollet above.

The custom callback is not necessary.
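The resolution can be sketched numerically: recast "validate every N batches" as a series of N-batch epochs, since Keras runs validation (and ModelCheckpoint) at the end of each epoch. The batch counts and the fit call below are illustrative, not from the thread:

```python
import math

# Hypothetical numbers for illustration.
total_train_batches = 10_000   # batches in one full pass over the dataset
validate_every = 250           # run validation every N training batches

# Recast "validate every N batches" as "N-batch epochs": Keras then runs
# validation and ModelCheckpoint at the end of each short epoch.
steps_per_epoch = validate_every
epochs = math.ceil(total_train_batches / validate_every)
assert epochs == 40

# fit() will draw steps_per_epoch * epochs batches in total, so the dataset
# must be able to supply at least that many; hence the .repeat():
#   train_ds = train_ds.repeat()
#   model.fit(train_ds, epochs=epochs, steps_per_epoch=steps_per_epoch,
#             validation_data=val_ds,
#             callbacks=[keras.callbacks.ModelCheckpoint(
#                 "best.keras", save_best_only=True)])
assert steps_per_epoch * epochs >= total_train_batches
```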

Thanks for your help in clarifying this.

@mpetteno mpetteno closed this as completed May 3, 2024
