forked from TencentARC/GFPGAN
/
app.py
203 lines (179 loc) · 7.73 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import io
import json
import argparse
import cv2
import glob
import numpy as np
import os
from basicsr.utils import imwrite
import sys
from os import listdir
from os.path import isfile, join
from gfpgan import GFPGANer
from werkzeug.utils import secure_filename
import shutil
import torch
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request, render_template, redirect, url_for
import sys
# Directory where uploaded images are stored before they are restored.
UPLOAD_FOLDER = 'inputs/whole_imgs'
# Only image extensions are accepted: every upload is later decoded with
# cv2.imread, which cannot read text or PDF documents.
# NOTE(review): GIF decoding presumably depends on the installed OpenCV
# build -- confirm, and drop 'gif' if it is not supported.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
# The 'results' directory (where restored images are written) doubles as the
# static folder, presumably so the result template can link to the output.
app = Flask(__name__, static_folder='results')
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
@app.route("/", methods=['GET', 'POST'])
def index():
    """Send any request for the site root to the upload page."""
    upload_url = url_for('upload_file')
    return redirect(upload_url)
def allowed_file(filename):
    """Tell whether *filename* carries an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot and is compared
    case-insensitively. A name without any dot is never allowed.
    """
    if '.' not in filename:
        return False
    _, _, extension = filename.rpartition('.')
    return extension.lower() in ALLOWED_EXTENSIONS
def _clear_directory(path):
    """Make sure *path* exists and delete the files directly inside it."""
    os.makedirs(path, exist_ok=True)
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        # os.remove cannot delete directories, so sub-directories are left alone.
        if os.path.isfile(entry_path) or os.path.islink(entry_path):
            os.remove(entry_path)


@app.route("/upload", methods=['GET', 'POST'])
def upload_file():
    """Show the upload form and store an uploaded image for restoration.

    Every request (GET or POST) first empties the input, saved and restored
    image directories, so only the most recent upload is ever processed.
    NOTE(review): because these directories are shared, concurrent users
    would wipe each other's files.

    Returns:
        On a POST with an accepted file: a redirect to the ``main`` view.
        On a POST without a file part or with an empty filename: a redirect
        back to this URL.
        Otherwise (GET, or a rejected file): the HTML upload form.
    """
    for directory in ('inputs/whole_imgs/', 'inputs/saved/', 'results/restored_imgs'):
        _clear_directory(directory)
    # file.save() does not create missing parent directories.
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

    if request.method == 'POST':
        # check if the post request has the file part
        if 'file' not in request.files:
            print('No file part')
            return redirect(request.url)
        file = request.files['file']
        # If the user does not select a file, the browser submits an
        # empty file without a filename.
        if file.filename == '':
            print('No selected file')
            return redirect(request.url)
        # Validate the sanitized name: secure_filename may strip characters
        # (or the whole stem), leaving a name without a usable extension.
        filename = secure_filename(file.filename)
        if file and allowed_file(filename):
            file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
            return redirect(url_for('main', name=filename))
    return '''
<!doctype html>
<title>Upload new File</title>
<h1>Upload new File</h1>
<body>
<div>
After you upload the image, click submit to run the model
</div>
<br>
<div>
<form method=post enctype=multipart/form-data>
<input type=file name=file>
<input type=submit value=Upload>
</form>
</div>
</body>
'''
# Upper bound on the number of uploaded images restored per request.
_MAX_IMAGES_PER_REQUEST = 3


def _build_arg_parser():
    """Create the parser that holds the GFPGAN inference options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image')
    parser.add_argument('--arch', type=str, default='clean', help='The GFPGAN architecture. Option: clean | original')
    parser.add_argument('--channel', type=int, default=2, help='Channel multiplier for large networks of StyleGAN2')
    parser.add_argument('--model_path', type=str, default='GFPGANCleanv1-NoCE-C2.pth')
    parser.add_argument('--bg_upsampler', type=str, default='realesrgan', help='background upsampler')
    parser.add_argument(
        '--bg_tile', type=int, default=400, help='Tile size for background sampler, 0 for no tile during testing')
    parser.add_argument('--test_path', type=str, default='upload/', help='Input folder')
    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
    parser.add_argument('--aligned', action='store_true', help='Input are aligned faces')
    parser.add_argument('--paste_back', action='store_false', help='Paste the restored faces back to images')
    parser.add_argument('--save_root', type=str, default='results', help='Path to save root')
    parser.add_argument(
        '--ext',
        type=str,
        default='auto',
        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
    return parser


def _create_bg_upsampler(args):
    """Return the RealESRGAN background upsampler, or None when it is not used."""
    if args.bg_upsampler != 'realesrgan':
        return None
    if not torch.cuda.is_available():  # CPU
        import warnings
        warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. '
                      'If you really want to use it, please modify the corresponding codes.')
        return None
    from realesrgan import RealESRGANer
    return RealESRGANer(
        scale=2,
        model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        tile=args.bg_tile,
        tile_pad=10,
        pre_pad=0,
        half=True)  # need to set False in CPU mode


def _restore_image(restorer, img_path, args):
    """Restore one image and write the faces, comparisons and result to disk."""
    img_name = os.path.basename(img_path)
    print(f'Processing {img_name} ...')
    basename, ext = os.path.splitext(img_name)
    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if input_img is None:
        # cv2.imread reports an unreadable or unsupported file by returning None.
        print(f'Skipping {img_name}: not a readable image')
        return

    # restore faces and background if necessary
    cropped_faces, restored_faces, restored_img = restorer.enhance(
        input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)

    # save faces
    for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
        face_stem = f'{basename}_{idx:02d}'
        # save cropped face
        imwrite(cropped_face, os.path.join(args.save_root, 'cropped_faces', f'{face_stem}.png'))
        # save restored face
        if args.suffix is not None:
            save_face_name = f'{face_stem}_{args.suffix}.png'
        else:
            save_face_name = f'{face_stem}.png'
        imwrite(restored_face, os.path.join(args.save_root, 'restored_faces', save_face_name))
        # save comparison image (cropped and restored side by side)
        cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
        imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{face_stem}.png'))

    # save restored img
    if restored_img is None:
        return
    extension = ext[1:] if args.ext == 'auto' else args.ext
    if args.suffix is not None:
        restored_name = f'{basename}_{args.suffix}.{extension}'
    else:
        restored_name = f'{basename}.{extension}'
    imwrite(restored_img, os.path.join(args.save_root, 'restored_imgs', restored_name))


@app.route('/main', methods=['POST', 'GET'])
def main():
    """Inference demo for GFPGAN.

    Restores the images found in ``inputs/whole_imgs`` (at most
    ``_MAX_IMAGES_PER_REQUEST`` of them), writes the outputs under
    ``results`` and renders ``index2.html`` with the name of a restored image.

    The restorer is built on every request.
    NOTE(review): this presumably reloads the model weights each time --
    consider caching the restorer.

    Returns:
        The rendered ``index2.html`` template, or a plain-text message with
        HTTP status 400 when no restored image is available.
    """
    # The options are fixed for the web app. Passing them explicitly keeps
    # argparse away from the server's own command line and leaves sys.argv
    # untouched.
    args = _build_arg_parser().parse_args([
        '--model_path', 'GFPGANCleanv1-NoCE-C2.pth',
        '--upscale', '2',
        '--test_path', 'inputs/whole_imgs',
        '--save_root', 'results',
        '--bg_upsampler', 'realesrgan',
    ])
    if args.test_path.endswith('/'):
        args.test_path = args.test_path[:-1]
    os.makedirs(args.save_root, exist_ok=True)
    restored_dir = os.path.join(args.save_root, 'restored_imgs')
    # Created up front so that listing it below cannot fail.
    os.makedirs(restored_dir, exist_ok=True)

    # set up GFPGAN restorer
    restorer = GFPGANer(
        model_path=args.model_path,
        upscale=args.upscale,
        arch=args.arch,
        channel_multiplier=args.channel,
        bg_upsampler=_create_bg_upsampler(args))

    img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
    print(img_list, '**')
    for img_path in img_list[:_MAX_IMAGES_PER_REQUEST]:
        _restore_image(restorer, img_path, args)

    # '.DS_Store' is a macOS metadata file, not a restored image.
    restored_files = sorted(
        f for f in listdir(restored_dir)
        if f != '.DS_Store' and isfile(join(restored_dir, f)))
    if not restored_files:
        return 'No restored image was produced. Please upload an image first.', 400
    return render_template("index2.html", variable=restored_files[0])
if __name__ == '__main__':
    # Start the Flask development server, listening on all interfaces.
    # The ``main`` view is invoked by Flask while handling a request to
    # /main; it must not be called directly here, because rendering its
    # template requires an active application context.
    app.run(host="0.0.0.0")