Hands-on: Keras Custom Training with Custom Callbacks

You want to use low-level training and evaluation loops but don’t want to give up the convenience of Keras callbacks? No problem! In this tutorial, we will use a Keras callback in a low-level training and evaluation loop. I assume that you are familiar with the basic terminology and principles of Machine Learning and have built some toy examples with Keras.

HINT: When I first got in touch with TensorFlow and Keras, I was overwhelmed by the terminology of Machine Learning and was happy to find the Machine Learning Glossary by Google.

With a callback, you can perform actions at various stages while training your deep neural network. Keras provides convenient built-in callbacks such as ReduceLROnPlateau (Reduce Learning Rate on Plateau) and EarlyStopping for the training and evaluation loop. They are used through the .fit() method of a model, which accepts them as a list of callbacks.
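Here is a minimal sketch of this usage. It assumes TensorFlow 2.x with tf.keras. The random data, the one-layer model and the argument values are placeholders for illustration, not the tutorial's MNIST setup:

```python
import numpy as np
import tensorflow as tf

# Placeholder data: 64 random samples with 4 features, only so that fit() can run.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01), loss="mse")

# Built-in callbacks are handed to fit() as a plain list.
callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5),
]

history = model.fit(
    x, y,
    validation_split=0.25,  # hold out 25% of the data so that "val_loss" exists
    epochs=3,
    callbacks=callbacks,
    verbose=0,
)
```

The val_loss both callbacks monitor only exists because validation_split holds out part of the data.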

At each stage of training (e.g. at the start or end of an epoch), the corresponding callback methods are called automatically. You can find more detailed information about the callback methods in the Keras documentation. If you want to write your own callbacks, give the article Building Custom Callbacks with Keras and TensorFlow 2 by B. Chen a try.
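The following sketch shows these automatic calls in action. It again assumes TensorFlow 2.x, and EventLogger is a hypothetical name. The callback does nothing except record which of its methods fit() invokes, and in which order:

```python
import numpy as np
import tensorflow as tf


class EventLogger(tf.keras.callbacks.Callback):
    """Records which callback methods fit() calls, and in which order."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin_{epoch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end_{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


# Placeholder data and model, only so that fit() has something to train.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

logger = EventLogger()
model.fit(x, y, epochs=2, callbacks=[logger], verbose=0)
print(logger.events)
# ['train_begin', 'epoch_begin_0', 'epoch_end_0', 'epoch_begin_1', 'epoch_end_1', 'train_end']
```

In a custom training loop, nobody makes these calls for us. That is exactly the gap this tutorial fills.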

I won’t go into detail about how to implement a custom training loop; if you want more information about that, I recommend the Keras Classification Tutorial. For this tutorial, we will slightly modify that classification tutorial for the MNIST dataset. The MNIST dataset contains handwritten digits and is commonly used as a toy example in Machine Learning. Furthermore, we will implement the widely used ReduceLROnPlateau callback and extend its exponential reduction of the learning rate with a linear one.

The code for this tutorial is available on GitHub. Clone the repository to run it locally in a Jupyter Notebook:

git clone https://github.com/Pelk89/TF_Custom_Training_Callbacks.git

Reduce Learning Rate On Plateau

Reducing the learning rate is a common strategy when a metric has stopped improving: once learning stagnates, models often benefit from reducing the learning rate exponentially, i.e. multiplying it by a constant factor. But what if we want to reduce the learning rate linearly when the metric has stopped improving? No problem at all!
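This small framework-free sketch compares the two schedules. The exponential variant multiplies by a factor, as Keras does. For the linear variant, we assume (for illustration only) that a fixed step is subtracted and the result is clipped at a minimum so it never turns negative. Both function names are hypothetical:

```python
def exponential_schedule(lr, factor, reductions):
    """Learning rate after each reduction when multiplying by `factor`."""
    rates = [lr]
    for _ in range(reductions):
        lr = lr * factor
        rates.append(lr)
    return rates


def linear_schedule(lr, step, reductions, min_lr=0.0):
    """Learning rate after each reduction when subtracting a fixed `step`.

    Assumption for illustration: the result is clipped at `min_lr`,
    so the learning rate never becomes negative.
    """
    rates = [lr]
    for _ in range(reductions):
        lr = max(lr - step, min_lr)
        rates.append(lr)
    return rates


exp_rates = exponential_schedule(0.1, factor=0.5, reductions=4)
lin_rates = linear_schedule(0.1, step=0.02, reductions=4)
print([round(r, 5) for r in exp_rates])  # [0.1, 0.05, 0.025, 0.0125, 0.00625]
print([round(r, 5) for r in lin_rates])  # [0.1, 0.08, 0.06, 0.04, 0.02]
```

The exponential steps shrink as the learning rate shrinks. The linear steps stay the same size until the clip kicks in.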

To understand what’s going on under the hood, we need to go deeper into the Keras library. Let’s dig into the heart of Keras callbacks: the tf.keras.callbacks module, where every callback builds on the base class tf.keras.callbacks.Callback. Thanks to François Chollet and the other authors, we will find incredibly clean and comprehensible code for the class ReduceLROnPlateau. With this as a basis, we can easily reuse the existing code for our purposes. How great is that?

Analyzing the class, we find, for example, the documentation of the arguments that can be passed to its __init__ method. Some of them are listed below; a more detailed overview can be found in the Keras ReduceLROnPlateau documentation.


  • monitor: quantity (e.g. validation loss) to be monitored.
  • factor: factor by which the learning rate will be reduced.
    new_lr = lr * factor.
  • patience: number of epochs with no improvement after which the learning rate will be reduced.
  • cooldown: number of epochs to wait before resuming normal operation after the learning rate has been reduced.

More importantly, we can identify the underlying algorithm and when it is called while training our neural network. The class relies on two callback methods:

  • on_train_begin(): called once when training begins; resets the wait and cooldown counters.
  • on_epoch_end(): called at the end of every epoch; if the monitored quantity has not improved for patience epochs and no cooldown is active, it multiplies the learning rate by factor (an exponential reduction) and starts the cooldown.
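This framework-free sketch mirrors that logic for a quantity that should decrease, such as a loss. It is a simplification, not the Keras source: it tracks the learning rate as a plain float and only supports minimization. PlateauScheduler is a hypothetical name. Besides the arguments listed above, it takes min_lr and min_delta, which the Keras class also accepts:

```python
class PlateauScheduler:
    """Simplified sketch of ReduceLROnPlateau's logic for a quantity to minimize."""

    def __init__(self, lr, factor=0.1, patience=10, cooldown=0, min_lr=0.0, min_delta=1e-4):
        if factor >= 1.0:
            raise ValueError("factor must be < 1.0")
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.min_delta = min_delta
        self.on_train_begin()

    def on_train_begin(self):
        # Reset the best value seen so far and the wait and cooldown counters.
        self.best = float("inf")
        self.wait = 0
        self.cooldown_counter = 0

    def on_epoch_end(self, epoch, current):
        # `epoch` mirrors the Keras signature; this sketch does not use it.
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            self.wait = 0
        if current < self.best - self.min_delta:
            # Improvement: remember it and restart the patience count.
            self.best = current
            self.wait = 0
        elif self.cooldown_counter == 0:
            self.wait += 1
            if self.wait >= self.patience and self.lr > self.min_lr:
                # Plateau: exponential reduction, clipped at min_lr.
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.cooldown_counter = self.cooldown
                self.wait = 0
        return self.lr


scheduler = PlateauScheduler(lr=0.1, factor=0.5, patience=2)
losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.8]
lrs = [scheduler.on_epoch_end(epoch, loss) for epoch, loss in enumerate(losses)]
print(lrs)  # [0.1, 0.1, 0.1, 0.05, 0.05, 0.05]
```

The loss stalls at 0.9 for two epochs, so the learning rate is halved once. The next stalled epoch only starts a new wait, and the improvement to 0.8 resets it.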

Modify Reduce Learning Rate On Plateau

Now that we have a detailed overview of the Reduce Learning Rate On Plateau algorithm, we can modify it for our needs and call it at the right places in our custom training and evaluation loop. In this tutorial, we will monitor the validation loss of our model. First of all, note that we cannot use the attribute self.model. In a regular callback, Keras sets it to the keras.models.Model instance being trained, but in a custom loop nothing sets it. So everything that uses self.model needs to be replaced by values passed into the __init__ or on_epoch_end() method. I marked the changes in the code as follows:

## Custom modification: "Reason for Modification"

To reduce the learning rate linearly, we need to add arguments to the __init__ method and modify the on_epoch_end() method.

Modify __init__ and _reset

Because we always monitor the validation loss of our neural network, we can remove the monitor argument. Next, we add a boolean reduce_lin that controls whether the learning rate is reduced linearly. Likewise, we need to pass the optimizer to the method, since the callback can no longer reach it through self.model.

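Next, we need to remove the reference to self.monitor from the _reset() method. Keras uses it to decide whether the monitored quantity should be minimized or maximized; since a loss is always minimized, the check is no longer needed.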
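Here is a sketch of the modified __init__ and _reset. It assumes the optimizer object exposes a plain learning_rate attribute. The class name CustomReduceLRoP and the stand-in SimpleOptimizer are hypothetical, and the Keras arguments verbose and mode are dropped for brevity. The comments reuse the modification marker from above:

```python
class SimpleOptimizer:
    """Hypothetical stand-in for an optimizer: it only holds a learning rate."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate


class CustomReduceLRoP:
    """Sketch of ReduceLROnPlateau's constructor and reset for a custom loop."""

    def __init__(self, optimizer, factor=0.1, patience=10, cooldown=0,
                 min_lr=0.0, min_delta=1e-4, reduce_lin=False):
        ## Custom modification: no `monitor` argument, we always watch the validation loss
        ## Custom modification: the optimizer is passed in because self.model is unavailable
        self.optimizer = optimizer
        ## Custom modification: flag that switches to a linear reduction
        self.reduce_lin = reduce_lin
        self.factor = factor
        self.patience = patience
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.min_delta = min_delta
        self._reset()

    def _reset(self):
        ## Custom modification: no self.monitor check, a loss is always minimized
        self.monitor_op = lambda current, best: current < best - self.min_delta
        self.best = float("inf")
        self.cooldown_counter = 0
        self.wait = 0


callback = CustomReduceLRoP(SimpleOptimizer(learning_rate=0.1), patience=3, reduce_lin=True)
print(callback.best, callback.wait, callback.monitor_op(0.5, 1.0))  # inf 0 True
```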

Modify on_epoch_end()

Normally, Keras calls this method at the end of every training epoch. In our case, we call on_epoch_end() ourselves at the end of every epoch and pass it the epoch and the loss on our validation dataset. To prevent the learning rate from being reduced to a negative value, we also add error handling to the method.
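The core of the modified reduction step might look like the sketch below. Two assumptions here are ours, not taken from the tutorial: in linear mode, factor is subtracted from the learning rate, and the negative-value guard falls back to min_lr. The helper name next_learning_rate is hypothetical. Inside on_epoch_end(), it would replace the exponential update old_lr * factor:

```python
def next_learning_rate(old_lr, factor, min_lr=0.0, reduce_lin=False):
    """Return the reduced learning rate (hypothetical helper).

    Exponential mode (as in Keras): old_lr * factor.
    Linear mode (assumption for this sketch): old_lr - factor.
    """
    if min_lr < 0:
        raise ValueError("min_lr must be non-negative")
    new_lr = old_lr - factor if reduce_lin else old_lr * factor
    if new_lr < min_lr:
        # Guard: a linear step can overshoot below zero; fall back to min_lr.
        print(f"New learning rate {new_lr:.4g} is below min_lr, using {min_lr} instead")
        new_lr = min_lr
    return new_lr


print(next_learning_rate(0.1, 0.5))                                   # exponential: 0.05
print(next_learning_rate(0.1, 0.03, reduce_lin=True))                 # linear: about 0.07
print(next_learning_rate(0.02, 0.03, min_lr=0.001, reduce_lin=True))  # clipped: 0.001
```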

Putting it all together

Let’s put everything together and use it in our custom training loop! First, we import the class and instantiate it.

We set reduce_lin to True since we want to reduce the learning rate linearly. From our analysis of the callback, we know that it is called once when training starts and at the end of every epoch. So we simply call on_train_begin() before the training loop starts, which resets the cooldown and wait counters. Then we call on_epoch_end() at the end of every epoch of the training loop.
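Here is a self-contained sketch of the whole flow. The class, the SimpleOptimizer stand-in and the hard-coded validation losses are hypothetical, and linear mode is assumed to subtract factor. In a real TensorFlow loop, the optimizer would be your tf.keras optimizer and each validation loss would come from your evaluation step:

```python
class SimpleOptimizer:
    """Hypothetical stand-in for an optimizer: it only holds a learning rate."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate


class CustomReduceLRoP:
    """Sketch of ReduceLROnPlateau for a custom loop (hypothetical class name)."""

    def __init__(self, optimizer, factor=0.1, patience=10, cooldown=0,
                 min_lr=0.0, min_delta=1e-4, reduce_lin=False):
        self.optimizer = optimizer
        self.factor = factor
        self.patience = patience
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.min_delta = min_delta
        self.reduce_lin = reduce_lin
        self.on_train_begin()

    def on_train_begin(self):
        # Reset the best loss and the wait and cooldown counters.
        self.best = float("inf")
        self.wait = 0
        self.cooldown_counter = 0

    def on_epoch_end(self, epoch, val_loss):
        # `epoch` mirrors the Keras signature; this sketch does not use it.
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            self.wait = 0
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        elif self.cooldown_counter == 0:
            self.wait += 1
            old_lr = self.optimizer.learning_rate
            if self.wait >= self.patience and old_lr > self.min_lr:
                # Assumption: linear mode subtracts factor, otherwise multiply (Keras).
                new_lr = old_lr - self.factor if self.reduce_lin else old_lr * self.factor
                # Guard: never go below min_lr, and therefore never negative.
                self.optimizer.learning_rate = max(new_lr, self.min_lr)
                self.cooldown_counter = self.cooldown
                self.wait = 0


optimizer = SimpleOptimizer(learning_rate=0.1)
reduce_lr = CustomReduceLRoP(optimizer, factor=0.03, patience=2, min_lr=0.001, reduce_lin=True)

# Hypothetical validation losses; in a real loop they come from the validation step.
val_losses = [1.0, 0.8] + [0.8] * 8

reduce_lr.on_train_begin()  # once, before the training loop starts
lr_history = []
for epoch, val_loss in enumerate(val_losses):
    # ... the training and validation steps for this epoch would run here ...
    reduce_lr.on_epoch_end(epoch, val_loss)  # at the end of every epoch
    lr_history.append(round(optimizer.learning_rate, 6))
print(lr_history)  # [0.1, 0.1, 0.1, 0.07, 0.07, 0.04, 0.04, 0.01, 0.01, 0.001]
```

With patience=2, the learning rate drops by 0.03 every time the loss stalls for two epochs. The last step would go negative, so it is clipped to min_lr.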

Training your neural network

You made it! Finally, we can start training our neural network: simply clone the repository and run the cells in the Jupyter Notebook.

Whenever the validation loss has not improved for 10 epochs, the learning rate is reduced linearly!


Keras provides convenient built-in callbacks for the training and evaluation loop of a model’s .fit() method. But when you write your own training loop to get low-level control over training and evaluation, you cannot simply use the built-in callbacks.

In this tutorial, we used the popular ReduceLROnPlateau callback in a custom training loop, building on its native Keras implementation. Moreover, we modified the callback to reduce the learning rate linearly. You can now use it in your own custom training and evaluation loop!

Thanks for reading this article. I continuously want to improve my skills in machine learning, so if you have anything to add or any ideas on this topic, feel free to leave a comment.

