In this paper, we present a novel image inpainting technique using frequency domain information. Prior works on image inpainting predict the missing pixels by training neural networks using only the spatial domain information. However, these methods still struggle to reconstruct high-frequency details for real complex scenes, leading to a discrepancy in color, boundary artifacts, distorted patterns, and blurry textures. To alleviate these problems, we investigate if it is possible to obtain better performance by training the networks using frequency domain information (Discrete Fourier Transform) along with the spatial domain information. To this end, we propose a frequency-based deconvolution module that enables the network to learn the global context while selectively reconstructing the high-frequency components. We evaluate our proposed method on the publicly available datasets CelebA, Paris Streetview, and DTD texture dataset, and show that our method outperforms current state-of-the-art image inpainting techniques both qualitatively and quantitatively.
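To make the frequency-domain idea concrete, the following is a minimal NumPy sketch of taking the 2D DFT of a grayscale image and isolating its high-frequency content. It is an illustration only, not the deconvolution module from the paper: the function name `dft_magnitude_and_highpass`, the circular cut-off and the `radius` parameter are assumptions made for this example.

```python
import numpy as np

def dft_magnitude_and_highpass(image, radius):
    """Return the centred log-magnitude spectrum and a high-pass filtered image.

    Hypothetical helper for illustration; `radius` is the cut-off in frequency bins.
    """
    # 2D DFT, shifted so the zero-frequency term sits at the centre of the array.
    spectrum = np.fft.fftshift(np.fft.fft2(image))
    log_magnitude = np.log1p(np.abs(spectrum))

    # Keep only frequencies farther than `radius` from the centre.
    h, w = image.shape
    yy, xx = np.ogrid[:h, :w]
    dist = np.sqrt((yy - h // 2) ** 2 + (xx - w // 2) ** 2)
    high_only = np.where(dist > radius, spectrum, 0)

    # Back to the spatial domain; for a real input the imaginary part is numerical noise.
    high_pass = np.fft.ifft2(np.fft.ifftshift(high_only)).real
    return log_magnitude, high_pass

rng = np.random.default_rng(0)
image = rng.random((64, 64))  # stand-in for a grayscale image
log_mag, high_pass = dft_magnitude_and_highpass(image, radius=8)
```

Removing the low frequencies leaves edges and fine texture, which is the kind of detail the abstract says spatial-only methods struggle to reconstruct.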
(a) Input images with missing regions; (b) DFT of first-stage reconstruction by our deconvolution network; (c) image inpainting results (after the second stage) of our proposed approach; and (d) GT image. The last column shows the prediction of the missing region obtained from our method and original pixel values for the same region in the GT image.
- Python 3
- PyTorch 1.0
- NVIDIA GPU + CUDA cuDNN
- Other dependencies such as cv2 (OpenCV) and numpy
- Clone this repo:
git clone https://github.com/hiyaroy12/DFT_inpainting.git
cd DFT_inpainting
- Install PyTorch and dependencies from http://pytorch.org
- Install python requirements:
pip install -r requirements.txt
We use CelebA, Paris StreetView and DTD texture datasets. You can download the datasets from the official websites to train the model.
We train our model on the irregular mask dataset similar to Yu et al.
We test our model on the irregular mask dataset provided by Liu et al. You can download the Irregular Mask Dataset from their website.
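For readers unfamiliar with inpainting masks, here is a small sketch of how a regular (rectangular) mask could be generated and applied to an image. It assumes a single randomly placed rectangular hole and the convention that 1 marks missing pixels; the function names, hole size and mask convention are assumptions for illustration and may differ from what the training scripts in this repo do. Irregular masks come from the datasets mentioned above and are not reproduced here.

```python
import numpy as np

def random_bbox_mask(height, width, hole_h, hole_w, rng):
    """Binary mask with 1 inside a randomly placed rectangular hole, 0 elsewhere."""
    # Pick the top-left corner so that the hole fits entirely inside the image.
    top = int(rng.integers(0, height - hole_h + 1))
    left = int(rng.integers(0, width - hole_w + 1))
    mask = np.zeros((height, width), dtype=np.float32)
    mask[top:top + hole_h, left:left + hole_w] = 1.0
    return mask

def apply_mask(image, mask):
    """Zero out the hole pixels of an (H, W, C) image using an (H, W) mask."""
    return image * (1.0 - mask[..., None])

rng = np.random.default_rng(0)
image = rng.random((64, 64, 3)).astype(np.float32)  # stand-in for a real image
mask = random_bbox_mask(64, 64, hole_h=32, hole_w=32, rng=rng)
masked = apply_mask(image, mask)
hole_ratio = float(mask.mean())  # fraction of the image that is missing
```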
You can download the pre-trained models from the following links and keep them under the ./checkpoints directory.
CelebA | Paris StreetView | DTD texture dataset
Our model is trained in two stages: 1) training the deconvolution module and 2) training the refinement model.
- Train the stage-1 model (deconvolution module) for regular masks using:
CUDA_VISIBLE_DEVICES=1 python stage_1/train_color-randombbox.py --epochs 100 --dataset celeba --use_regular 1
- Train the stage-1 model (deconvolution module) for irregular masks using:
CUDA_VISIBLE_DEVICES=1 python stage_1/train_color_irregular.py --epochs 100 --dataset celeba --use_irregular 1
- Train the stage-2 model (refinement model) for regular masks using:
python stage_2/CEEC/L1_adv_fft.py --n_epochs [] --dataset [] --use_regular 1
Example:
CUDA_VISIBLE_DEVICES=1 python stage_2/CEEC/L1_adv_fft.py --dataset celeba --n_epochs 300 --use_regular 1
- Train the stage-2 model (refinement model) for irregular masks using:
python stage_2/CEEC/L1_adv_fft-irregular.py --n_epochs [] --dataset [] --use_irregular 1
Example:
CUDA_VISIBLE_DEVICES=1 python stage_2/CEEC/L1_adv-irregular.py --n_epochs 300 --dataset celeba --use_irregular 1
To test the model:
- Please download the stage-1 pre-trained models for the CelebA, Paris StreetView, and DTD datasets and put them into logs/ (please check that the model path in the code is correct). Here regular_{}_net.pth and irregular_{}_net.pth refer to regular and irregular masks.
- Please download the stage-2 pre-trained models for the CelebA, Paris StreetView, and DTD datasets and put them into L1_adv_fft_results/. Here random_bbox_{}_generator.h5f and random_bbox_{}_discriminator.h5f refer to regular masks, and irregular_{}_generator.h5f and irregular_{}_discriminator.h5f refer to irregular masks.
- Then, to test against your validation set for regular masks, run:
CUDA_VISIBLE_DEVICES=1 python CEEC/L1_adv_fft-test.py --dataset [dataset_name] --use_regular 1
Example:
CUDA_VISIBLE_DEVICES=1 python CEEC/L1_adv_fft-test.py --dataset celeba --use_regular 1
- To test against your validation set for irregular masks, run:
CUDA_VISIBLE_DEVICES=1 python CEEC/L1_adv-irregular-test.py --dataset [dataset_name] --use_irregular 1 --perc_test_mask []
Example:
CUDA_VISIBLE_DEVICES=1 python CEEC/L1_adv-irregular-test.py --dataset celeba --use_irregular 1 --perc_test_mask 0.1
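Since the test scripts expect the pre-trained files in specific folders, a quick check before running them can save a failed launch. The sketch below is a hypothetical helper, not part of this repo. It assumes the {} placeholder in the checkpoint names stands for the value passed to --dataset (for example celeba), which is a guess; adjust the names if your downloaded files differ.

```python
from pathlib import Path

def expected_checkpoints(dataset, regular, root="."):
    """List the checkpoint paths expected for one dataset and mask type.

    Assumption: the `{}` in the file names is the dataset name.
    """
    stage1_prefix = "regular" if regular else "irregular"
    stage2_prefix = "random_bbox" if regular else "irregular"
    root = Path(root)
    return [
        root / "logs" / f"{stage1_prefix}_{dataset}_net.pth",
        root / "L1_adv_fft_results" / f"{stage2_prefix}_{dataset}_generator.h5f",
        root / "L1_adv_fft_results" / f"{stage2_prefix}_{dataset}_discriminator.h5f",
    ]

def missing_checkpoints(dataset, regular, root="."):
    """Return the expected checkpoint paths that do not exist on disk."""
    return [p for p in expected_checkpoints(dataset, regular, root) if not p.exists()]

for path in missing_checkpoints("celeba", regular=True):
    print(f"missing: {path}")
```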
To evaluate the model, first run the model in test mode against your validation set and save the results on disk.
Then run metrics.py to evaluate the model using PSNR, SSIM and Mean Absolute Error:
CUDA_VISIBLE_DEVICES=9 python CEEC/metric_cal/metrics.py --data-path [path to validation set] --output-path [path to model output]
Example:
CUDA_VISIBLE_DEVICES=9 python CEEC/metric_cal/metrics.py --data-path ./CEEC_fft_infer_results/dtd_images/clean/ --output-path ./CEEC_fft_infer_results/dtd_images/reconstructed/
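For reference, two of the reported metrics are simple enough to sketch in a few lines of NumPy. This is a minimal illustration of the standard definitions of Mean Absolute Error and PSNR, assuming 8-bit images with a peak value of 255; it is not the code in metrics.py, which may normalise or aggregate differently. SSIM is omitted because it needs a windowed computation.

```python
import numpy as np

def mean_absolute_error(reference, output):
    """Average absolute pixel difference between two images of the same shape."""
    diff = reference.astype(np.float64) - output.astype(np.float64)
    return float(np.mean(np.abs(diff)))

def psnr(reference, output, max_value=255.0):
    """Peak signal-to-noise ratio in dB; `max_value` is the assumed peak pixel value."""
    diff = reference.astype(np.float64) - output.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))

# Toy example: every pixel of the output is off by 10 grey levels.
reference = np.full((8, 8, 3), 100, dtype=np.uint8)
output = np.full((8, 8, 3), 110, dtype=np.uint8)
mae_value = mean_absolute_error(reference, output)
psnr_value = psnr(reference, output)
```

Casting to float64 before subtracting avoids the wrap-around that unsigned 8-bit arithmetic would otherwise produce.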
If you use this code for your research, please cite our paper, Image inpainting using frequency-domain priors:
@article{roy2021image,
title={Image inpainting using frequency-domain priors},
author={Roy, Hiya and Chaudhury, Subhajit and Yamasaki, Toshihiko and Hashimoto, Tatsuaki},
journal={Journal of Electronic Imaging},
volume={30},
number={2},
pages={023016},
year={2021},
publisher={International Society for Optics and Photonics}
}