style_transfer.py (Renew...) #42

Open · adetion opened this issue Sep 10, 2022 · 0 comments

adetion commented Sep 10, 2022

import argparse
import os
from argparse import Namespace

import numpy as np
import torch
import torchvision
from torch.nn import functional as F
from torchvision import transforms

from model.dualstylegan import DualStyleGAN
from model.encoder.psp import pSp
from util import save_image, load_image

from PIL import Image

class TestOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser(description="Exemplar-Based Style Transfer")

        # Path and file name of the image to be transferred
        pic_path = './data/content/'
        pic_name = 'wowowo.jpeg'

        self.parser.add_argument("--content", type=str, default=pic_path + pic_name,
                                 help="path of the content image")

        # Style setting / selection
        style_types = ['cartoon', 'caricature', 'anime', 'arcane', 'comic', 'pixar', 'slamdunk']
        # style_type: the default is 'cartoon' (style_types[0]); here style_types[1] ('caricature') is selected
        style_type = style_types[1]

        # Valid style_id ranges: [cartoon]: 0-316   [caricature]: 0-198   [anime]: 0-173
        #                        [arcane]: 0-99     [comic]: 0-100        [pixar]: 0-121   [slamdunk]: 0-119
        # Note: style_id is an integer.
        #       For the portrait style comparison table, please refer to the ./doc_images directory.
        style_id = 52

        self.parser.add_argument("--style", type=str, default=style_type, help="target style type")
        self.parser.add_argument("--style_id", type=int, default=style_id, help="the id of the style image")

        self.parser.add_argument("--truncation", type=float, default=0.75,
                                 help="truncation for intrinsic style code (content)")
        self.parser.add_argument("--weight", type=float, nargs=18, default=[0.75] * 7 + [1] * 11,
                                 help="weight of the extrinsic style")

        # Output filename prefix (includes the style name)
        self.parser.add_argument("--name", type=str, default=style_type + '_transfer',
                                 help="filename to save the generated images")

        self.parser.add_argument("--preserve_color", action="store_true",
                                 help="preserve the color of the content image")
        self.parser.add_argument("--model_path", type=str, default='./checkpoint/', help="path of the saved models")
        self.parser.add_argument("--model_name", type=str, default='generator.pt',
                                 help="name of the saved dualstylegan")
        self.parser.add_argument("--output_path", type=str, default='./output/', help="path of the output images")
        self.parser.add_argument("--data_path", type=str, default='./data/', help="path of dataset")
        self.parser.add_argument("--align_face", action="store_true", help="apply face alignment to the content image")
        self.parser.add_argument("--exstyle_name", type=str, default=None, help="name of the extrinsic style codes")

        # Example invocation:
        # python style_transfer.py --content ./data/content/unsplash-rDEOVtE7vOs.jpg --align_face --preserve_color \
        #   --style arcane --name arcane_transfer --style_id 13 \
        #   --weight 0.6 0.6 0.6 0.6 0.6 0.6 0.6 0.6 0.6 0.6 0.6 1 1 1 1 1 1 1

    def parse(self):
        self.opt = self.parser.parse_args()
        if self.opt.exstyle_name is None:
            if os.path.exists(os.path.join(self.opt.model_path, self.opt.style, 'refined_exstyle_code.npy')):
                self.opt.exstyle_name = 'refined_exstyle_code.npy'
            else:
                self.opt.exstyle_name = 'exstyle_code.npy'
        args = vars(self.opt)
        print('Load options')
        for name, value in sorted(args.items()):
            print('%s: %s' % (str(name), str(value)))
        return self.opt

def run_alignment(args):
    import dlib
    from model.encoder.align_all_parallel import align_face
    # Download dlib's 68-point landmark predictor on first use
    modelname = os.path.join(args.model_path, 'shape_predictor_68_face_landmarks.dat')
    if not os.path.exists(modelname):
        import wget, bz2
        wget.download('http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', modelname + '.bz2')
        zipfile = bz2.BZ2File(modelname + '.bz2')
        data = zipfile.read()
        open(modelname, 'wb').write(data)
    predictor = dlib.shape_predictor(modelname)
    aligned_image = align_face(filepath=args.content, predictor=predictor)
    return aligned_image

if __name__ == "__main__":

# device = "cuda" # Change to CPU running  9-10-2022 
# At the same time, modify the 11 lines in the root directory util.py
# from model.stylegan.op import conv2d_gradfix
# To:
# from model.stylegan.op_cpu import conv2d_gradfix
#
# At the same time, modify the 11 lines in model.py in the model/stylegan directory
# from model.stylegan.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
# To:
# from model.stylegan.op_cpu import FusedLeakyReLU,fused_leaky_relu, upfirdn2d, conv2d_gradfix
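    # A device-agnostic alternative for those imports (a sketch, assuming
    # model.stylegan.op_cpu mirrors the API of model.stylegan.op):
    #     if torch.cuda.is_available():
    #         from model.stylegan.op import conv2d_gradfix
    #     else:
    #         from model.stylegan.op_cpu import conv2d_gradfix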

    # device = "cuda"
    device = "cpu"

    parser = TestOptions()
    args = parser.parse()
    print('*' * 98)

    # Determine the size of the input picture
    # print(args.content)
    # print(os.path.basename(args.content).split('.')[0] + '.' + os.path.basename(args.content).split('.')[1])

    img = Image.open(args.content)
    imgSize = img.size
    w = img.width
    h = img.height
    f = img.format
    out_file = None  # path of the temporary resized copy, if one is created
    if h != 1024:
        # print('Picture size: ' + str(imgSize))
        # print('Picture size: Width ' + str(w), 'Height ' + str(h), f)
        out_file = args.data_path + 'content/w_' + os.path.basename(args.content).split('.')[0] + '.' + os.path.basename(args.content).split('.')[1]
        # print(out_file)

        # Resize to height 1024, preserving the aspect ratio
        width = img.width
        height = img.height
        new_width = int(1024 * width / height)
        # print(new_width)
        out = img.resize((new_width, 1024), Image.Resampling.LANCZOS)
        out.save(out_file)
        # Use the resized copy as the content image; it is removed at the end of the script
        args.content = out_file
    # else:
    #     print('The picture is normal without modification')
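    # Note: the generator is instantiated at 1024 below, so content images of
    # any other height are brought to height 1024 before encoding.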

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
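    # Normalize(mean=0.5, std=0.5) maps image tensors from [0, 1] to [-1, 1],
    # the value range the generator uses (outputs are clamped to [-1, 1] below)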

    generator = DualStyleGAN(1024, 512, 8, 2, res_index=6)
    generator.eval()

    # ckpt = torch.load(os.path.join(args.model_path, args.style, args.model_name), map_location=lambda storage, loc: storage)
    ckpt = torch.load(os.path.join(args.model_path, args.style, args.model_name), map_location=device)

    generator.load_state_dict(ckpt["g_ema"])
    generator = generator.to(device)

    model_path = os.path.join(args.model_path, 'encoder.pt')
    ckpt = torch.load(model_path, map_location=device)
    opts = ckpt['opts']
    opts['checkpoint_path'] = model_path
    opts = Namespace(**opts)
    opts.device = device
    encoder = pSp(opts)
    encoder.eval()
    encoder.to(device)

    exstyles = np.load(os.path.join(args.model_path, args.style, args.exstyle_name), allow_pickle=True).item()
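    # exstyles maps style image filenames to their extrinsic style codes
    # (assumed here to be 1x18x512 arrays, matching the 18 --weight entries);
    # args.style_id indexes into its keys below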

    print('Load models successfully!')

    with torch.no_grad():
        viz = []
        # load content image
        if args.align_face:
            I = transform(run_alignment(args)).unsqueeze(dim=0).to(device)
            I = F.adaptive_avg_pool2d(I, 1024)
        else:
            I = load_image(args.content).to(device)
        viz += [I]

        # reconstructed content image and its intrinsic style code
        img_rec, instyle = encoder(F.adaptive_avg_pool2d(I, 256), randomize_noise=False, return_latents=True,
                                   z_plus_latent=True, return_z_plus_latent=True, resize=False)
        img_rec = torch.clamp(img_rec.detach(), -1, 1)
        viz += [img_rec]

        stylename = list(exstyles.keys())[args.style_id]
        latent = torch.tensor(exstyles[stylename]).to(device)
        if args.preserve_color:
            latent[:, 7:18] = instyle[:, 7:18]
        # extrinsic style code
        exstyle = generator.generator.style(latent.reshape(latent.shape[0] * latent.shape[1], latent.shape[2])).reshape(
            latent.shape)

        # load style image if it exists
        S = None
        if os.path.exists(os.path.join(args.data_path, args.style, 'images/train', stylename)):
            S = load_image(os.path.join(args.data_path, args.style, 'images/train', stylename)).to(device)
            viz += [S]

        # style transfer
        # input_is_latent: instyle is not in W space
        # z_plus_latent: instyle is in Z+ space
        # use_res: use the extrinsic style path; otherwise the style is not transferred
        # interp_weights: weight vector for style combination of the two paths
        img_gen, _ = generator([instyle], exstyle, input_is_latent=False, z_plus_latent=True,
                               truncation=args.truncation, truncation_latent=0, use_res=True,
                               interp_weights=args.weight)
        img_gen = torch.clamp(img_gen.detach(), -1, 1)
        viz += [img_gen]
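        # With the default --weight of [0.75]*7 + [1]*11, the first 7 layers mix
        # intrinsic and extrinsic styles at 0.75 while the remaining 11 layers
        # take the extrinsic style in full; --preserve_color instead copies the
        # content's codes into entries 7:18 of the style latent above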

    print('Generate images successfully!')


    save_name = args.name + '_%d_%s' % (args.style_id, os.path.basename(args.content).split('.')[0])
    # Generate the preview (overview) image
    # save_image(torchvision.utils.make_grid(F.adaptive_avg_pool2d(torch.cat(viz, dim=0), 256), 4, 2).cpu(),
    #            os.path.join(args.output_path, save_name + '_overview.jpg'))
    # Generate the final image
    save_image(img_gen[0].cpu(), os.path.join(args.output_path, save_name + '.jpg'))

    print('Save images successfully!')

    # Clean up the temporary resized copy, if one was created
    if out_file is not None and os.path.exists(out_file):
        os.remove(out_file)
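With the defaults hard-coded above (style 'caricature', style_id 52, device 'cpu'), a minimal run looks like:

python style_transfer.py --content ./data/content/wowowo.jpeg --align_face

assuming the pretrained models (encoder.pt, plus generator.pt and the extrinsic style codes for the chosen style) have been downloaded under ./checkpoint/, which is the layout the script reads from.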