
MEAD results #3

Open
jetory12 opened this issue Apr 9, 2022 · 3 comments

Comments

@jetory12

jetory12 commented Apr 9, 2022

Hi,
You report MEAD results in the text of the paper and compare them to MoCoGAN-HD, but I cannot find MoCoGAN-HD's numbers in the paper. Can you share the full results for MEAD? Did you also evaluate DIGAN on MEAD?

Also, I want to use your new evaluation protocol and your evaluation scripts to compare against your model, DIGAN and MoCoGAN-HD. Would you mind sharing the metric scripts / modified dataset files you used for the DIGAN and MoCoGAN-HD repos?

@universome
Owner

Hi! Yeah, we do not report the MEAD results in the main table since we could only run StyleGAN-V and MoCoGAN-HD on it (DIGAN seems to be expensive to run on it, but we didn't actually try). Our FID/FVD scores are given in the text of the paper. We do not report FID/FVD for MoCoGAN-HD generations, only for StyleGAN-V:
[Screenshot (2022-04-14): the paper text reporting StyleGAN-V's FID/FVD scores on MEAD]

The metric scripts are src/scripts/calc_metrics_for_dataset.py and src/scripts/calc_metrics.py, for computing metrics between two datasets or between a dataset and a generator, respectively. As for the datasets, I will share the link tomorrow (it is 170G, so it takes time to find suitable cloud storage for it). Usage is described in the Evaluation section of README.md.
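As background on what "computing metrics between two datasets" ends in: FID and FVD are both Fréchet distances between Gaussians fitted to network features (per-frame Inception features for FID, per-clip I3D features for FVD). Below is a minimal NumPy sketch of only that final distance step, assuming the features are already extracted; the function names are hypothetical and this is not the code in calc_metrics_for_dataset.py.

```python
import numpy as np


def _sqrtm_psd(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    eigvals = np.clip(eigvals, 0.0, None)  # clip tiny negative eigenvalues caused by round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two [N, D] feature matrices."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    sigma_a = np.cov(feats_a, rowvar=False)
    sigma_b = np.cov(feats_b, rowvar=False)
    # tr(sqrt(sigma_a @ sigma_b)) == tr(sqrt(S @ sigma_b @ S)) with S = sqrt(sigma_a);
    # the right-hand matrix is symmetric PSD, so eigh-based sqrt applies.
    sqrt_a = _sqrtm_psd(sigma_a)
    middle = sqrt_a @ sigma_b @ sqrt_a
    middle = (middle + middle.T) / 2.0  # remove floating-point asymmetry
    covmean_trace = np.trace(_sqrtm_psd(middle))
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(sigma_a) + np.trace(sigma_b) - 2.0 * covmean_trace)
```

As a sanity check, identical feature sets give a distance of about 0, and shifting every one of D feature dimensions by 1 adds exactly D (the covariance terms cancel).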

Sorry for the late reply. If you have any other questions or need additional information, feel free to ask!

@jetory12
Author

Hi, thanks a lot for the reply.

I still have a question about the second part. You said that you used the official DIGAN and MoCoGAN-HD repos to evaluate with your new FVD protocol. I tried combining your protocol with their repos, but my numbers are far off. Since they use different dataloaders etc. than you do, could you share your modified versions of the DIGAN and MoCoGAN-HD repos, where you integrated your FVD protocol? Or, more generally, the exact procedure you used to evaluate their re-trained models. I want to make sure that I evaluate their models correctly. This would help me a lot to reproduce all the numbers, since I want to apply your new protocol across all models as you did. No need to upload any datasets; just uploading the modified DIGAN and MoCoGAN-HD repos somewhere would be great, if that is possible.

@universome
Owner

Hi! We didn't integrate our FVD evaluation into their repos. Instead, we sampled from the models to construct a dataset of fake videos and then ran our src/scripts/calc_metrics_for_dataset.py script on it. So the only things we changed were the sampling procedures (to generate long videos, videos starting at some t, etc.). For DIGAN, we also changed its data loading strategy to select videos uniformly at random (simply by using our dataset class). We do not have any repos for this, just a big mess of our infrastructure bindings and bash scripts. In case you need our sampling scripts, here are our versions of them (sorry for the code quality):

`evaluate.py` from MoCoGAN-HD
import os

# import cv2
import numpy as np
import torch
from torchvision.io import write_video
from tqdm import tqdm
from PIL import Image
import torchvision.transforms.functional as TVF
from torchvision import utils

from options.test_options import TestOptions
from models.models import create_model


def test():
    opt = TestOptions().parse(save=False)

    ### initialize models
    modelG = create_model(opt)

    if opt.static_z_path is None:
        z = torch.randn(opt.num_test_videos, opt.latent_dimension).cuda()
    else:
        z = torch.load(opt.static_z_path)[:opt.num_test_videos].cuda()

    def create_video(z, modelG, opt, use_noise: bool, prefix: str, save: bool=False) -> torch.Tensor:
        x_fake, _, _ = modelG(styles=[z],
                              n_frames=opt.n_frames_G,
                              use_noise=use_noise,
                              frame_rate_increase=opt.frame_rate_increase,
                              time_offset=opt.time_offset)
        x_fake = x_fake.view(1, -1, 3, opt.style_gan_size, opt.style_gan_size).data
        x_fake = x_fake.clamp(-1, 1)

        video = x_fake[0].cpu()
        video = ((video + 1.) / 2. * 255).type(torch.uint8).permute(0, 2, 3, 1) # [t, h, w, 3]

        if save:
            if opt.as_frames:
                save_dir = os.path.join(opt.output_path, prefix)
                os.makedirs(save_dir, exist_ok=True)
                for i, frame in enumerate(video):
                    Image.fromarray(frame.numpy()).save(os.path.join(save_dir, f'{i + opt.time_offset:05d}.jpg'), quality=95)
            else:
                write_video(os.path.join(opt.output_path, prefix + '.mp4'), video, fps=opt.fps)

        return video

    if opt.as_grid:
        results_dir = os.path.dirname(opt.output_path)
        if results_dir != '':
            os.makedirs(results_dir, exist_ok=True)
    else:
        os.makedirs(opt.output_path, exist_ok=True)

    with torch.no_grad():
        videos = []
        for j in tqdm(range(opt.num_test_videos)):
            # prefix = opt.name + '_' + str(opt.load_pretrain_epoch) + '_' + str(j) + '_noise'
            # create_video(curr_z, modelG, opt=opt, use_noise=True, prefix=prefix, save=not opt.as_grid)

            prefix = f'{opt.name}_epoch{opt.load_pretrain_epoch:05d}_{j:06d}'
            video = create_video(z[[j]], modelG, opt=opt, use_noise=False, prefix=prefix, save=not opt.as_grid) # [t, h, w, 3]

            if opt.as_grid:
                videos.append(video)

        if opt.as_grid:
            videos = torch.stack(videos) # [num_videos, t, h, w, 3]
            videos = videos.permute(1, 0, 4, 2, 3) # [t, num_videos, 3, h, w]
            grids = [utils.make_grid(vs, nrow=np.ceil(np.sqrt(opt.num_test_videos)).astype(int).item()) for vs in videos] # [t, 3, gh, gw]
            grid_video = np.array([grid.permute(1, 2, 0).numpy() for grid in grids]) # [t, gh, gw, 3]
            write_video(opt.output_path, grid_video, fps=opt.fps)

        print(opt.name + ' Finished!')


# def save_video_frames_as_mp4(frames: torch.Tensor, fps: int, save_path: os.PathLike, verbose: bool=False):
#     # Load data
#     frame_w, frame_h = frames[0].shape[1:]
#     fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
#     video = cv2.VideoWriter(save_path, fourcc, fps, (frame_w, frame_h))
#     frames = tqdm(frames, desc='Saving videos') if verbose else frames
#     for frame in frames:
#         assert frame.shape[0] == 3, "RGBA/grayscale images are not supported"
#         frame = np.array(TVF.to_pil_image(frame))
#         video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

#     # Uncomment this line to release the memory.
#     # It didn't work for me on centos and complained about installing additional libraries (which requires root access)
#     # cv2.destroyAllWindows()
#     video.release()


if __name__ == "__main__":
    test()
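The frame post-processing in create_video() above comes down to a few array operations. Here is a NumPy sketch of just those steps, with no model involved; the helper names are hypothetical, and NumPy stands in for the torch/torchvision calls in the script.

```python
import math
from typing import List

import numpy as np


def to_uint8_frames(x_fake: np.ndarray) -> np.ndarray:
    """[t, 3, h, w] generator output -> [t, h, w, 3] uint8 frames, mirroring create_video()."""
    x = np.clip(x_fake, -1.0, 1.0)
    # Same affine map [-1, 1] -> [0, 255] as the script; the cast truncates,
    # like .type(torch.uint8) does for values in range.
    x = ((x + 1.0) / 2.0 * 255.0).astype(np.uint8)
    return x.transpose(0, 2, 3, 1)


def frame_file_names(num_frames: int, time_offset: int = 0) -> List[str]:
    """File names following the f'{i + opt.time_offset:05d}.jpg' pattern used with as_frames."""
    return [f"{i + time_offset:05d}.jpg" for i in range(num_frames)]


def grid_nrow(num_videos: int) -> int:
    """Videos per grid row when as_grid is set: ceil(sqrt(num_videos))."""
    return math.ceil(math.sqrt(num_videos))
```

Because the frame index includes the time offset, frames sampled with a non-zero offset keep their absolute time index in the file name, for example `frame_file_names(3, time_offset=16)` starts at `00016.jpg`.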
`src/scripts/generate_videos.py` from DIGAN
"""Generates a dataset of images using pretrained network pickle."""
import math
import sys; sys.path.extend(['.', 'src'])
import re
import json
import os
import random
from typing import List

import click
import dnnlib
import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm
import torchvision.transforms.functional as TVF

import legacy
from training.networks import Generator
from scripts import save_image_grid
from einops import rearrange


torch.set_grad_enabled(False)


@click.command()
@click.pass_context
@click.option('--network_pkl', help='Network pickle filename', required=False)
@click.option('--experiment_dir', help='A directory with the experiment output', required=False)
@click.option('--video_len', type=int, help='Number of frames per video', default=16, show_default=True)
@click.option('--num_videos', type=int, help='Number of videos to generate', default=100, show_default=True)
@click.option('--batch_size', type=int, help='Batch size for video generation', default=16, show_default=True)
@click.option('--seed', type=int, help='Random seed', default=42, metavar='INT')
@click.option('--output_path', help='Where to save the output images', type=str, required=True, metavar='DIR')
@click.option('--fps', help='FPS to save video with', type=int, required=False, metavar='INT')
@click.option('--as_frames', help='Should we save videos as frames?', type=bool, default=False, metavar='BOOL')
@click.option('--num_z', help='Number of different z to use when generating the videos', type=int, default=None, metavar='INT')
@click.option('--time_offset', help='Time offset for generation', type=int, default=0, metavar='INT')
def generate_videos(
    ctx: click.Context,
    network_pkl: str,
    experiment_dir: str,
    video_len: int,
    num_videos: int,
    batch_size: int,
    seed: int,
    output_path: str,
    fps: int,
    as_frames: bool,
    num_z: int,
    time_offset: int,
):
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')

    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device).eval() # type: ignore
        G.forward = Generator.forward.__get__(G, Generator)
        print("Done. ")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if as_frames:
        os.makedirs(output_path, exist_ok=True)
        curr_video_idx = 0
        all_z = torch.randn(num_videos, G.z_dim, device=device) # [num_videos, z_dim]
        all_z_motion = torch.randn_like(all_z) # [num_videos, z_motion_dim]

        for batch_idx in tqdm(range((num_videos + batch_size - 1) // batch_size), desc='Generating videos'):
            z = all_z[batch_idx * batch_size : (batch_idx + 1) * batch_size] # [batch_size, z_dim]
            z_motion = all_z_motion[batch_idx * batch_size : (batch_idx + 1) * batch_size] # [batch_size, z_motion_dim]
            videos = lean_generation(G=G, z=z, video_len=video_len, z_motion=z_motion, noise_mode='const', time_offset=time_offset) # [b, c, t, h, w]
            videos = videos.permute(0, 2, 1, 3, 4) # [b, t, c, h, w]
            videos = (videos * 0.5 + 0.5).clamp(0, 1) # [b, t, c, h, w]

            for video in videos:
                save_video_frames_as_frames(video, os.path.join(output_path, f'{curr_video_idx:06d}'), time_offset=time_offset)
                curr_video_idx += 1

                if curr_video_idx == num_videos:
                    break
    else:
        if os.path.dirname(output_path) != '':
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

        grid_size = (int(math.sqrt(num_videos)), int(math.sqrt(num_videos)))
        if num_z is None:
            grid_z = torch.randn([int(grid_size[0] * grid_size[1]), G.z_dim], device=device).split(1)
        else:
            assert grid_size[0] * grid_size[1] % num_z == 0
            grid_z = torch.randn(num_z, G.z_dim, device=device).repeat_interleave(grid_size[0] * grid_size[1] // num_z, dim=0).split(1)
        # videos = [G(z, None, timesteps=timesteps, noise_mode='const')[0].cpu() for z in grid_z]
        videos = [lean_generation(G=G, z=z, video_len=video_len, noise_mode='const', time_offset=time_offset) for z in tqdm(grid_z, desc='Generating videos')]
        images = torch.cat(videos).numpy()

        save_image_grid(images, output_path, drange=[-1, 1], grid_size=grid_size, fps=fps)


def lean_generation(G: torch.nn.Module, z: Tensor, video_len: int, frames_batch_size: int=32, z_motion: Tensor=None, time_offset: int=0, **kwargs):
    if z_motion is None:
        z_motion = torch.randn(z.shape[0], 512).to(z.device) # [num_videos, z_motion_dim]
    Ts = torch.linspace(0, video_len / 16.0, steps=video_len).view(video_len, 1, 1).unsqueeze(0) + time_offset / 16.0 # [1, video_len, 1, 1]
    Ts = Ts.repeat(z.shape[0], 1, 1, 1).to(z.device) # [num_videos, video_len, 1, 1]
    all_frames = []

    for curr_batch_idx in range((video_len + frames_batch_size - 1) // frames_batch_size):
        curr_ts = Ts[:, curr_batch_idx * frames_batch_size : (curr_batch_idx + 1) * frames_batch_size, :, :] # [num_videos, frames_batch_size, 1, 1]
        curr_ts = curr_ts.reshape(-1, 1, 1, 1) # [num_videos * frames_batch_size, 1, 1, 1]
        frames = G(z=z, c=None, z_motion=z_motion, timesteps=video_len, Ts=curr_ts, **kwargs)[0].cpu() # [num_videos * frames_batch_size, c, h, w]
        frames = frames.view(len(z), -1, *frames.shape[1:]) # [num_videos, frames_batch_size, c, h, w]
        all_frames.append(frames)

    videos = torch.cat(all_frames, dim=1) # [num_videos, video_len, c, h, w]
    videos = videos.permute(0, 2, 1, 3, 4) # [num_videos, c, video_len, h, w]

    return videos


def save_video_frames_as_frames(frames: List[Tensor], save_dir: os.PathLike, time_offset: int=0):
    os.makedirs(save_dir, exist_ok=True)

    for i, frame in enumerate(frames):
        save_path = os.path.join(save_dir, f'{i + time_offset:06d}.jpg')
        TVF.to_pil_image(frame).save(save_path, quality=95)


if __name__ == "__main__":
    generate_videos()
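lean_generation() above renders each video from explicit per-frame timestamps, in chunks of frames_batch_size frames so memory stays bounded for long videos, and the grid branch can share latents across cells via --num_z. Below is a NumPy sketch of only that bookkeeping, with no generator; the helper names are hypothetical and NumPy stands in for torch.

```python
import math
from typing import List, Tuple

import numpy as np


def timestamps(video_len: int, time_offset: int = 0) -> np.ndarray:
    """Per-frame times as in lean_generation(): linspace(0, video_len / 16, video_len) + time_offset / 16."""
    # With steps=video_len, neighbouring frames are video_len / (16 * (video_len - 1)) apart.
    return np.linspace(0.0, video_len / 16.0, num=video_len) + time_offset / 16.0


def frame_chunks(video_len: int, frames_batch_size: int = 32) -> List[Tuple[int, int]]:
    """Frame ranges rendered per generator call (ceil division; the last chunk may be shorter)."""
    num_chunks = (video_len + frames_batch_size - 1) // frames_batch_size
    return [(i * frames_batch_size, min((i + 1) * frames_batch_size, video_len)) for i in range(num_chunks)]


def grid_latents(num_videos: int, num_z: int, z_dim: int, seed: int = 0) -> np.ndarray:
    """num_z latents, each repeated over consecutive grid cells, like the num_z branch."""
    side = int(math.sqrt(num_videos))
    num_cells = side * side
    assert num_cells % num_z == 0
    z = np.random.default_rng(seed).standard_normal((num_z, z_dim))
    return np.repeat(z, num_cells // num_z, axis=0)  # same layout as torch.repeat_interleave(dim=0)
```

For example, `timestamps(16)` spans [0, 1] in steps of 1/15, and a time offset of 16 shifts the whole range by 1 (that is, 16 / 16).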

(Also, in src/training/networks.py, replace the line `if timesteps > 2:` with `if timesteps > 2 or Ts is not None:`.)

Also, I have uploaded our version of the MEAD 1024 dataset: https://disk.yandex.ru/d/PACh_RRsVJ93AA (Yandex.Disk split the archive into parts; it is 170G in total).
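On the DIGAN data-loading change mentioned earlier in this thread (selecting videos uniformly at random): here is a minimal sketch of what uniform-over-videos clip sampling means, contrasted with sampling uniformly over all possible clips, which favours long videos. The function names are hypothetical; this illustrates the idea and is not the StyleGAN-V dataset class.

```python
import random
from typing import List, Tuple


def sample_clip_uniform_over_videos(video_lengths: List[int], clip_len: int, rng: random.Random) -> Tuple[int, int]:
    """Pick a video with equal probability, then a clip start uniformly inside it."""
    video_idx = rng.randrange(len(video_lengths))
    max_start = video_lengths[video_idx] - clip_len
    if max_start < 0:
        raise ValueError(f"video {video_idx} is shorter than clip_len={clip_len}")
    return video_idx, rng.randint(0, max_start)  # randint includes both endpoints


def sample_clip_uniform_over_clips(video_lengths: List[int], clip_len: int, rng: random.Random) -> Tuple[int, int]:
    """For contrast: every valid clip is equally likely, so a video's weight grows with its length."""
    num_starts = [max(n - clip_len + 1, 0) for n in video_lengths]
    video_idx = rng.choices(range(len(video_lengths)), weights=num_starts)[0]
    return video_idx, rng.randrange(num_starts[video_idx])
```

With video lengths of 16 and 160 frames and 16-frame clips, the first function picks the short video about half the time, while the second picks it with probability 1/146.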
