Differentiable JPEG: The Devil is in the Details

Christoph Reich , Biplob Debnath , Deep Patel & Srimat Chakradhar

This repository includes the official and maintained implementation of the differentiable JPEG approach proposed in the paper Differentiable JPEG: The Devil is in the Details.

Abstract

JPEG remains one of the most widespread lossy image coding methods. However, the non-differentiable nature of JPEG restricts the application in deep learning pipelines. Several differentiable approximations of JPEG have recently been proposed to address this issue. This paper conducts a comprehensive review of existing differentiable JPEG approaches and identifies critical details that have been missed by previous methods. To this end, we propose a novel differentiable JPEG approach, overcoming previous limitations. Our approach is differentiable w.r.t. the input image, the JPEG quality, the quantization tables, and the color conversion parameters. We evaluate the forward and backward performance of our differentiable JPEG approach against existing methods. Additionally, extensive ablations are performed to evaluate crucial design choices. Our proposed differentiable JPEG resembles the (non-differentiable) reference implementation best, significantly surpassing the recent-best differentiable approach by 3.47dB (PSNR) on average. For strong compression rates, we can even improve PSNR by 9.51dB. Strong adversarial attack results are yielded by our differentiable JPEG, demonstrating the effective gradient approximation.

If you use our differentiable JPEG or find this research useful in your work, please cite our paper:

@inproceedings{Reich2024,
    author={Reich, Christoph and Debnath, Biplob and Patel, Deep and Chakradhar, Srimat},
    title={{Differentiable JPEG: The Devil is in the Details}},
    booktitle={{WACV}},
    year={2024}
}

Installation

Our differentiable JPEG implementation can be installed as a Python package by running:

pip install git+https://github.com/necla-ml/Diff-JPEG

All dependencies are listed in requirements.txt.

Usage

We offer both a functional and a class (nn.Module) implementation of our differentiable JPEG approach. Beyond the examples provided here, the repository also includes an example.py file.

The following example showcases the use of the functional implementation.

import torch
import torchvision
from torch import Tensor

from diff_jpeg import diff_jpeg_coding

# Load test image and reshape to [B, 3, H, W]
image: Tensor = torchvision.io.read_image("test_images/test_image.png").float()[None]
# Init JPEG quality
jpeg_quality: Tensor = torch.tensor([2.0])
# Perform differentiable JPEG coding
image_coded: Tensor = diff_jpeg_coding(image_rgb=image, jpeg_quality=jpeg_quality)

In the following code example, the class (nn.Module) implementation is used.

import torch
import torch.nn as nn
import torchvision
from torch import Tensor

from diff_jpeg import DiffJPEGCoding

# Init module
diff_jpeg_coding_module: nn.Module = DiffJPEGCoding()
# Load test image and reshape to [B, 3, H, W]
image: Tensor = torchvision.io.read_image("test_images/test_image.png").float()[None]
# Init JPEG quality
jpeg_quality: Tensor = torch.tensor([19.04])
# Perform differentiable JPEG coding
image_coded: Tensor = diff_jpeg_coding_module(image_rgb=image, jpeg_quality=jpeg_quality)

STE Variant

To use the proposed straight-through estimator (STE) variant, set the parameter ste=True (ste: bool) in either the function call or the module constructor.

# Perform differentiable JPEG coding
image_coded: Tensor = diff_jpeg_coding(image_rgb=image, jpeg_quality=jpeg_quality, ste=True)
# Init module
diff_jpeg_coding_module: nn.Module = DiffJPEGCoding(ste=True)
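For readers unfamiliar with straight-through estimators, the following is a minimal, generic sketch of the idea applied to rounding: the forward pass returns the hard-rounded value, while the backward pass uses the gradient of a smooth surrogate. The helper names soft_round and ste_round and the cubic surrogate are assumptions made for illustration; they are not taken from this repository and may differ from what diff_jpeg does internally.

```python
import torch
from torch import Tensor


def soft_round(x: Tensor) -> Tensor:
    """Differentiable surrogate for rounding (cubic residual, illustrative choice)."""
    rounded = torch.round(x)
    return rounded + (x - rounded) ** 3


def ste_round(x: Tensor) -> Tensor:
    """Hard rounding in the forward pass, surrogate gradient in the backward pass."""
    soft = soft_round(x)
    # The detached term shifts the forward value to round(x) but carries no gradient.
    return soft + (torch.round(x) - soft).detach()


x = torch.tensor([0.4, 1.3, 2.7], requires_grad=True)
y = ste_round(x)
y.sum().backward()
print(y)       # forward values match torch.round(x)
print(x.grad)  # gradient of the surrogate: 3 * (x - round(x)) ** 2
```

The pattern keeps the forward output faithful to the non-differentiable operation while still providing a useful gradient signal.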

Custom Quantization Tables

Both the diff_jpeg_coding function and the forward method of DiffJPEGCoding accept custom quantization tables via the quantization_table_y: Optional[Tensor] and quantization_table_c: Optional[Tensor] parameters. Each table must be a torch.Tensor of shape [8, 8]. If a table is omitted (or set to None), the respective standard JPEG quantization table is used.
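As background on what a "standard JPEG quantization table" looks like, the sketch below lists the standard luminance table from Annex K of the JPEG standard and scales it by an integer quality setting using the convention of the libjpeg reference implementation. The function name scale_quantization_table is hypothetical, and this is not the code used by diff_jpeg: the library treats the JPEG quality as a continuous tensor, so its scaling necessarily differs in detail.

```python
from typing import List

# Standard luminance quantization table (JPEG standard, Annex K).
STANDARD_TABLE_Y: List[List[int]] = [
    [16, 11, 10, 16, 24, 40, 51, 61],
    [12, 12, 14, 19, 26, 58, 60, 55],
    [14, 13, 16, 24, 40, 57, 69, 56],
    [14, 17, 22, 29, 51, 87, 80, 62],
    [18, 22, 37, 56, 68, 109, 103, 77],
    [24, 35, 55, 64, 81, 104, 113, 92],
    [49, 64, 78, 87, 103, 121, 120, 101],
    [72, 92, 95, 98, 112, 100, 103, 99],
]


def scale_quantization_table(table: List[List[int]], quality: int) -> List[List[int]]:
    """Scale a base table by an integer quality in [1, 100] (libjpeg convention)."""
    quality = max(1, min(100, quality))
    # Qualities below 50 enlarge the table entries, qualities above 50 shrink them.
    scale = 5000 // quality if quality < 50 else 200 - 2 * quality
    # Round to the nearest integer and clip to the baseline JPEG range [1, 255].
    return [[max(1, min(255, (v * scale + 50) // 100)) for v in row] for row in table]


table_q10 = scale_quantization_table(STANDARD_TABLE_Y, quality=10)
print(table_q10[0])
```

A quality of 50 leaves the base table unchanged, while lower qualities produce larger entries and therefore coarser quantization.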

Here we provide two examples of using a custom quantization table.

import torch
import torchvision
from torch import Tensor

from diff_jpeg import diff_jpeg_coding

# Load test image and reshape to [B, 3, H, W]
image: Tensor = torchvision.io.read_image("test_images/test_image.png").float()[None]
# Init JPEG quality
jpeg_quality: Tensor = torch.tensor([2.0])
# Perform differentiable JPEG coding
image_coded: Tensor = diff_jpeg_coding(
    image_rgb=image,
    jpeg_quality=jpeg_quality,
    quantization_table_y=torch.randint(low=1, high=256, size=(8, 8)),
    quantization_table_c=torch.randint(low=1, high=256, size=(8, 8)),
)

The same option is available in the class (nn.Module) implementation:

import torch
import torch.nn as nn
import torchvision
from torch import Tensor

from diff_jpeg import DiffJPEGCoding

# Init module
diff_jpeg_coding_module: nn.Module = DiffJPEGCoding()
# Load test image and reshape to [B, 3, H, W]
image: Tensor = torchvision.io.read_image("test_images/test_image.png").float()[None]
# Init JPEG quality
jpeg_quality: Tensor = torch.tensor([19.04])
# Perform differentiable JPEG coding
image_coded: Tensor = diff_jpeg_coding_module(
    image_rgb=image,
    jpeg_quality=jpeg_quality,
    quantization_table_y=torch.randint(low=1, high=256, size=(8, 8)),
    quantization_table_c=torch.randint(low=1, high=256, size=(8, 8)),
)

Issues

If you encounter any issues with this implementation, please open a GitHub issue.