sayakpaul/maxim-tf

MAXIM in TensorFlow

Hugging Face badge · Open In Colab · TensorFlow 2.10 · Models on TF-Hub · Hugging Face badge

Implementation of MAXIM [1] in TensorFlow. This project received the #TFCommunitySpotlight Award.

MAXIM introduces a backbone that can tackle image denoising, dehazing, deblurring, deraining, and enhancement.

(Figure taken from the MAXIM paper [1].)

The original weights of the different MAXIM variants are in JAX and are available in [2].

All the TensorFlow MAXIM models are available on TensorFlow Hub as well as on Hugging Face Hub.

You can also try out the models on Hugging Face Spaces.

If you prefer Colab notebooks, you can check them out here.

Model conversion to TensorFlow from JAX

Blocks and layers related to MAXIM are implemented in the maxim directory.

The convert_to_tf.py script initializes a particular MAXIM model variant, loads a pre-trained checkpoint into it, and runs the conversion to TensorFlow. Refer to the usage section of the script for more details.

This script serializes the model weights in .h5 format and also pushes the SavedModel to Hugging Face Hub. For the latter, you need to authenticate yourself (huggingface-cli login) if you have not already done so.

This TensorFlow implementation is closely aligned with [2]. The author of this repository has reused some code blocks from [2] (with credits).

Results and model variants

A comprehensive table is available here. The author of this repository qualitatively validated the results of the converted models.

Inference with the provided sample images

You can run the run_eval.py script for this purpose.

Image Denoising

python3 maxim/run_eval.py --task Denoising --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-3_denoising_sidd/1/uncompressed \
  --input_dir images/Denoising --output_dir images/Results --has_target=False --dynamic_resize=True

Image Deblurring

python3 maxim/run_eval.py --task Deblurring --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-3_deblurring_gopro/1/uncompressed \
  --input_dir images/Deblurring --output_dir images/Results --has_target=False --dynamic_resize=True

Image Deraining

Rain streak:

python3 maxim/run_eval.py --task Deraining --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_deraining_rain13k/1/uncompressed \
  --input_dir images/Deraining --output_dir images/Results --has_target=False --dynamic_resize=True

Rain drop:

python3 maxim/run_eval.py --task Deraining --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_deraining_raindrop/1/uncompressed \
  --input_dir images/Deraining --output_dir images/Results --has_target=False --dynamic_resize=True

Image Dehazing

Indoor:

python3 maxim/run_eval.py --task Dehazing --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_dehazing_sots-indoor/1/uncompressed \
  --input_dir images/Dehazing --output_dir images/Results --has_target=False --dynamic_resize=True

Outdoor:

python3 maxim/run_eval.py --task Dehazing --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_dehazing_sots-outdoor/1/uncompressed \
  --input_dir images/Dehazing --output_dir images/Results --has_target=False --dynamic_resize=True

Image Enhancement

Low-light enhancement:

python3 maxim/run_eval.py --task Enhancement --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_enhancement_lol/1/uncompressed \
  --input_dir images/Enhancement --output_dir images/Results --has_target=False --dynamic_resize=True

Retouching:

python3 maxim/run_eval.py --task Enhancement --ckpt_path gs://tfhub-modules/sayakpaul/maxim_s-2_enhancement_fivek/1/uncompressed \
  --input_dir images/Enhancement --output_dir images/Results --has_target=False --dynamic_resize=True

Notes:

  • The run_eval.py script is heavily inspired by the original one in [2].
  • You can set --dynamic_resize=False to get lower latency at the cost of prediction quality.
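The commands above differ only in the task name and the TF-Hub checkpoint, so they can also be generated from a small table. This is handy for processing every sample folder in one go. The helper below is a hypothetical convenience wrapper and is not part of the repository. It reuses only the flags shown in the commands above and builds argument lists rather than running anything itself.

```python
import shlex

CKPT_ROOT = "gs://tfhub-modules/sayakpaul"

# (task, checkpoint name) pairs taken from the commands listed above.
RUNS = [
    ("Denoising", "maxim_s-3_denoising_sidd"),
    ("Deblurring", "maxim_s-3_deblurring_gopro"),
    ("Deraining", "maxim_s-2_deraining_rain13k"),
    ("Deraining", "maxim_s-2_deraining_raindrop"),
    ("Dehazing", "maxim_s-2_dehazing_sots-indoor"),
    ("Dehazing", "maxim_s-2_dehazing_sots-outdoor"),
    ("Enhancement", "maxim_s-2_enhancement_lol"),
    ("Enhancement", "maxim_s-2_enhancement_fivek"),
]


def build_command(task, ckpt_name, output_dir="images/Results", dynamic_resize=True):
    """Return the run_eval.py invocation for one task/checkpoint as an argv list."""
    return [
        "python3", "maxim/run_eval.py",
        "--task", task,
        "--ckpt_path", f"{CKPT_ROOT}/{ckpt_name}/1/uncompressed",
        "--input_dir", f"images/{task}",  # sample images live in images/<Task>
        "--output_dir", output_dir,
        "--has_target=False",
        f"--dynamic_resize={dynamic_resize}",
    ]


if __name__ == "__main__":
    # Only print the commands; nothing is executed here.
    for task, ckpt_name in RUNS:
        print(shlex.join(build_command(task, ckpt_name)))
```

To run a command, pass its list to subprocess.run(cmd, check=True). Passing dynamic_resize=False trades prediction quality for speed, as noted above.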

XLA support

The models support XLA compilation, which can dramatically reduce latency. Refer to the benchmark_xla.py script for more details.
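One way to see the effect of XLA is to time the same model with and without compilation. The sketch below is not benchmark_xla.py. It is a framework-agnostic timing helper; the name benchmark is hypothetical. It skips the warm-up calls so that one-off costs such as graph tracing and XLA compilation do not inflate the reported latency.

```python
import statistics
import time


def benchmark(fn, *args, warmup=3, iters=20):
    """Return the median wall-clock latency of fn(*args), in seconds.

    The first `warmup` calls are not timed, so one-off costs such as graph
    tracing or XLA compilation don't skew the measurement.
    """
    for _ in range(warmup):
        fn(*args)
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


if __name__ == "__main__":
    # Stand-in workload; replace with an eager vs. XLA-compiled model call.
    latency = benchmark(lambda n: sum(i * i for i in range(n)), 100_000)
    print(f"median latency: {latency * 1e3:.2f} ms")
```

With TensorFlow, you could time benchmark(lambda x: model(x, training=False).numpy(), images). Compare it with the same call on xla_fn = tf.function(lambda x: model(x, training=False), jit_compile=True), timed as benchmark(lambda x: xla_fn(x).numpy(), images). Here model and images are placeholders for a loaded MAXIM model and an input batch. Calling .numpy() forces the result to be materialized, so asynchronous GPU work is included in the timing.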

Known limitations

These are the known limitations of the current implementation, all of which are open for contributions.

Supporting arbitrary image resolutions

MAXIM supports arbitrary image resolutions. However, the available TensorFlow models were exported with a fixed (256, 256, 3) input resolution, so input images are crudely resized before running inference with these models. This impacts the results quite a bit. The issue is discussed in more detail here. Some work has started to fix this behaviour (no ETA). I am thankful to Amy Roberts from Hugging Face for guiding me in the right direction.
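The fixed-resolution workaround has three steps: resize the input to the export resolution, run the model, and resize the prediction back to the original size. The NumPy sketch below illustrates this round trip under a few assumptions:

  • It uses nearest-neighbour interpolation for simplicity. The repository's scripts may use a different resizing method.
  • model_fn is a hypothetical stand-in for a loaded model that takes a batch of (256, 256, 3) images.
  • Both helper names are illustrative.

Detail lost when downscaling a larger image cannot be recovered when upscaling the output, which helps explain why the results suffer.

```python
import numpy as np


def resize_nearest(image, height, width):
    """Nearest-neighbour resize of an (H, W, C) array to (height, width, C)."""
    in_h, in_w = image.shape[:2]
    rows = np.arange(height) * in_h // height  # source row for each output row
    cols = np.arange(width) * in_w // width    # source column for each output column
    return image[rows[:, None], cols[None, :]]


def predict_at_fixed_resolution(model_fn, image, size=(256, 256)):
    """Resize to the export resolution, run the model, resize back."""
    orig_h, orig_w = image.shape[:2]
    resized = resize_nearest(image, *size)
    prediction = model_fn(resized[None, ...])[0]  # add, then drop, the batch axis
    return resize_nearest(prediction, orig_h, orig_w)


if __name__ == "__main__":
    image = np.random.rand(480, 640, 3).astype(np.float32)
    identity_model = lambda batch: batch  # placeholder for a real MAXIM model
    restored = predict_at_fixed_resolution(identity_model, image)
    print(restored.shape)  # (480, 640, 3)
    # Even with an identity "model", the round trip loses detail.
    print("mean abs error:", float(np.abs(restored - image).mean()))
```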

That said, the models can be extended to support arbitrary resolutions. Refer to this notebook for more details. Specifically, for a given task and image, a new instance of the model is created and the weights of the available model are copied into it. This is time-consuming and not very efficient.
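Copying weights from the exported model into a freshly instantiated one works as long as every variable keeps its shape. In Keras, the simplest version is new_model.set_weights(old_model.get_weights()), which matches variables by order. The NumPy sketch below matches by name instead and reports anything it could not copy, which makes mismatches easier to spot. It assumes the weights of each model are available as dictionaries of name-to-array pairs. The copy_weights helper is hypothetical and is not the notebook's code.

```python
import numpy as np


def copy_weights(source, target):
    """Copy arrays from `source` into `target` (both dicts of name -> array).

    Arrays are matched by name and copied only when the shapes agree.
    Returns the target names that were left untouched.
    """
    skipped = []
    for name, current in list(target.items()):
        candidate = source.get(name)
        if candidate is None or candidate.shape != current.shape:
            skipped.append(name)
            continue
        target[name] = np.array(candidate, copy=True)
    return skipped


if __name__ == "__main__":
    exported = {"stem/kernel": np.ones((3, 3, 3, 32)), "stem/bias": np.ones(32)}
    fresh = {
        "stem/kernel": np.zeros((3, 3, 3, 32)),
        "stem/bias": np.zeros(32),
        "head/kernel": np.zeros((1, 1, 32, 3)),  # no counterpart in `exported`
    }
    print(copy_weights(exported, fresh))  # ['head/kernel']
```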

Output mismatches

The outputs of the TF and JAX models differ slightly. This is due to differences in how some layers are implemented, mainly the resizing layer. Even though the differences between the outputs of individual TF and JAX blocks are small, they accumulate, so the final outputs end up differing more than one might expect.

That said, the qualitative performance does not seem to be affected at all.
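One way to find where the TF and JAX outputs drift apart is to compare them block by block. The sketch below assumes the intermediate outputs of both models have already been collected as lists of NumPy arrays; how to collect them is left out. The helper names are hypothetical.

The toy chain at the end only illustrates the accumulation effect. A tiny, constant discrepancy in every block produces a difference that keeps growing with depth.

```python
import numpy as np


def mismatch(reference, candidate):
    """Max and mean absolute difference between two arrays."""
    diff = np.abs(np.asarray(reference, np.float64) - np.asarray(candidate, np.float64))
    return {"max_abs": float(diff.max()), "mean_abs": float(diff.mean())}


def compare_blocks(reference_outputs, candidate_outputs):
    """Pairwise mismatch for per-block outputs, e.g. JAX (reference) vs TF."""
    return [mismatch(r, c) for r, c in zip(reference_outputs, candidate_outputs)]


def toy_chain(x, num_blocks, eps):
    """Toy stand-in for a deep network: each block scales its input and adds eps.

    eps mimics a small per-block implementation difference (e.g. in resizing).
    """
    outputs = []
    for _ in range(num_blocks):
        x = 1.1 * x + eps
        outputs.append(x)
    return outputs


if __name__ == "__main__":
    x = np.random.default_rng(0).random((4, 4))
    report = compare_blocks(toy_chain(x, 8, 0.0), toy_chain(x, 8, 1e-6))
    for i, stats in enumerate(report):
        print(f"block {i}: max_abs={stats['max_abs']:.2e}")
```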

Call for contributions

  • Add a minimal training notebook.
  • Fix any of the known limitations stated above.

Acknowledgements

  • ML Developer Programs' team at Google for providing Google Cloud credits.
  • Gustavo Martins from Google for initial discussions and reviews of the codebase.
  • Amy Roberts from Hugging Face for guiding me in the right direction for handling arbitrary input shapes.

References

[1] MAXIM paper: https://arxiv.org/abs/2201.02973

[2] MAXIM official GitHub: https://github.com/google-research/maxim