How to train/load S/M/L CNN models in tensorflow? #220

Open
kristoftunner opened this issue Dec 19, 2023 · 33 comments

Comments

@kristoftunner

Is there a way to train/load S/M/L CNN models in TensorFlow? I am interested in experimenting a bit with these models in TensorFlow or ONNX Runtime. I see that there is one specific model in the tensorflow directory, but I am not sure which one it is.

@Tama47
Contributor

Tama47 commented Feb 2, 2024

Is there a way to train/load S/M/L CNN models in tensorflow?

Yes, you would need to load the original models in TensorFlow.

I see that there is one specific model in the tensorflow directory, but I am not sure which one it is.

Someone has converted the original Anime4K models into Core ML models. I can provide you the link.

The ones you're looking for are under Models:
  • model-sr-s.wifm / model-sr-m.wifm / model-sr-l.wifm for the upscale models
  • model-restore-s.wifm / model-restore-m.wifm / model-restore-l.wifm for the restore models

You would need to convert them to TensorFlow, then create a Python script or Jupyter notebook to load the models and their weights. You can then use them to fine-tune and train your own, better model.

Note: I have not converted or trained the models myself, and cannot guarantee success. I can only provide general steps, and you will need to do your own research. Supposedly, the steps to convert between Core ML and TensorFlow should be relatively straightforward. The training process itself should be more or less the same as training any other TensorFlow or ESRGAN model.

Sample Python Script:
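import coremltools as ct

# Load the Core ML model (assumes coremltools can open the file directly;
# see the renaming steps further down the thread if it cannot)
coreml_model = ct.models.MLModel('model-sr-s.wifm')

# Note: coremltools converts *into* Core ML only. ct.convert() has no
# source='mlmodel' / target='tensorflow' option, so the layers and weights
# have to be read from the model spec and rebuilt in TensorFlow by hand.
spec = coreml_model.get_spec()
print(spec.description)  # inputs/outputs of the model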

@arianaa30

@Tama47 Is the training code in the \tensorflow directory for the restore model or the upscale model? And if it is for restore, is it easy to change it to train the "upscale" model?

@Fannovel16
Contributor

Fannovel16 commented Feb 4, 2024

@Tama47 From what I've researched so far, there is no way to convert the current version of MLModel to TF2 or ONNX. However, I managed to open the models, including their weights, in Netron:

  1. Change the file extension from .whml to .zip
  2. Compress all files inside the *.mlpackage folder (not the folder itself) into a zip file

Preset-a-hq:
preset-a-hq zip
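Step 2 of the list above can also be scripted. The sketch below uses only Python's standard zipfile and pathlib modules to zip the *contents* of an .mlpackage folder, so the folder itself is not a top-level entry in the archive. The function name and paths are hypothetical, and the result has not been checked against Netron here.

```python
import zipfile
from pathlib import Path


def zip_mlpackage_contents(package_dir, out_zip):
    """Zip every file inside package_dir, excluding the folder itself."""
    package_dir = Path(package_dir)
    with zipfile.ZipFile(out_zip, "w", zipfile.ZIP_DEFLATED) as zf:
        for path in sorted(package_dir.rglob("*")):
            if path.is_file():
                # Store paths relative to the package folder, so the archive
                # root holds its contents (e.g. Manifest.json, Data/...).
                zf.write(path, arcname=path.relative_to(package_dir).as_posix())
    return out_zip
```

For example, `zip_mlpackage_contents("preset-a-hq.mlpackage", "preset-a-hq.zip")`, with placeholder file names.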

@Fannovel16
Contributor

Fannovel16 commented Feb 6, 2024

I can now convert some GLSL files to PyTorch, but I'm still stuck on converting the weights. Here is the code if anyone is interested:
https://colab.research.google.com/drive/11xAn4fyAUJPZOjrxwnL2ipl_1DGGegkB

@Fannovel16
Copy link
Contributor

Fannovel16 commented Feb 6, 2024

@arianaa30

The training code located in \tensorflow dir is for the restore or upscale model?

It contains both. In Gen_Shader.ipynb, SR1Model generates restore models while SR2Model generates upscale ones.

@arianaa30

@arianaa30

The training code located in \tensorflow dir is for the restore or upscale model?

It contains both. In Gen_Shader.ipynb, SR1Model generates restore models while SR2Model generates upscale ones.

Thanks. Apparently the models there are for M-size shaders. Do you happen to know what parameter values (block_depth, etc.) I should use to get the S/L sizes?

Also, the code uses epochs=1 (3 times). Should I change it to something like 100? I noticed the loss doesn't really decrease.

@Fannovel16
Copy link
Contributor

Fannovel16 commented Feb 7, 2024

@arianaa30

Thanks. Apparently the models there are for M-size shaders. Do you happen to know what parameter values (block_depth, etc.) I should use to get the S/L sizes?

I guess you can figure out the block_depth from a model's components.
Conv2d(3, 4) means it has 3 input channels and 4 output channels. The CReLU() activation function doubles the channel count, e.g. (1, 128, 128, 4) -> (1, 128, 128, 8) in NHWC layout.
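A minimal PyTorch sketch of the CReLU channel doubling described above. It uses the same definition as TensorFlow's tf.nn.crelu (concatenate relu(x) and relu(-x)), but on dim=1 because PyTorch tensors are NCHW. The parameter counts it prints match the Param # column of the Keras summary posted later in the thread (112 for the first conv, 292 for each 8 -> 4 conv).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CReLU(nn.Module):
    """Concatenated ReLU: [relu(x), relu(-x)] along the channel axis (dim=1 for NCHW)."""

    def forward(self, x):
        return torch.cat([F.relu(x), F.relu(-x)], dim=1)


# First block of every size: Conv2d(3, 4) followed by CReLU -> 8 channels
block0 = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), CReLU())
x = torch.rand(1, 3, 128, 128)
print(block0(x).shape)  # torch.Size([1, 8, 128, 128])


def n_params(module):
    return sum(p.numel() for p in module.parameters())


print(n_params(nn.Conv2d(3, 4, 3)))  # 112 = 3*4*3*3 weights + 4 biases
print(n_params(nn.Conv2d(8, 4, 3)))  # 292 = 8*4*3*3 weights + 4 biases
```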

Size S:

Anime4K(
  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_last_0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
)

Size M:

Anime4K(
  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_0): Sequential(
    (0): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_last_0): Conv2d(56, 4, kernel_size=(1, 1), stride=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
)

Size L:

Anime4K(
  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_last_0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_last_1): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2d_last_2): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
)

Size VL:

Anime4K(
  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_0): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_1): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_last_0): Conv2d(112, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_1): Conv2d(112, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_2): Conv2d(112, 4, kernel_size=(1, 1), stride=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
)

Size UL:

Anime4K(
  (conv2d_0): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_1_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_2_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_3_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_4_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_5_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_0): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_1): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_6_2): Sequential(
    (0): Conv2d(24, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): CReLU()
  )
  (conv2d_last_0): Conv2d(120, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_1): Conv2d(120, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2d_last_2): Conv2d(120, 4, kernel_size=(1, 1), stride=(1, 1))
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (depth_to_space): PixelShuffle(upscale_factor=2)
)
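As a cross-check, the sketch below rebuilds S, M, L and VL in PyTorch from their printed channel counts alone. The forward wiring is an inference, since the printout does not show it: each layer's parallel convs ("streams", a hypothetical name) read the concatenated CReLU outputs of the previous layer; M and VL feed all seven layer outputs (7 × 8 = 56 and 7 × 16 = 112 channels) into 1×1 last convs, while S and L apply a 3×3 last conv to the final layer only; and the PixelShuffle output is assumed to be added to the bilinear Upsample of the input, with the single residual channel of S/M broadcast over RGB. UL is not covered, because its last convs take 120 channels rather than 7 × 24 = 168.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CReLU(nn.Module):
    def forward(self, x):
        # [relu(x), relu(-x)] along the channel axis (NCHW): doubles the channels
        return torch.cat([F.relu(x), F.relu(-x)], dim=1)


class Anime4KSketch(nn.Module):
    """Hypothetical rebuild of the printed S/M/L/VL modules.

    streams: parallel 4-filter convs per layer (1 for S/M, 2 for L/VL)
    layers:  number of conv+CReLU layers (3 for S/L, 7 for M/VL)
    dense:   True -> 1x1 last conv over all layer outputs (M/VL),
             False -> 3x3 last conv over the final layer only (S/L)
    """

    def __init__(self, streams=1, layers=3, dense=False):
        super().__init__()
        width = 8 * streams  # each stream's CReLU emits 8 channels
        self.dense = dense
        self.layers = nn.ModuleList()
        for i in range(layers):
            in_ch = 3 if i == 0 else width
            self.layers.append(nn.ModuleList(
                [nn.Sequential(nn.Conv2d(in_ch, 4, 3, padding=1), CReLU())
                 for _ in range(streams)]))
        last_in = width * layers if dense else width
        k = 1 if dense else 3
        # one 4-channel head -> 1 residual channel (S/M); three heads -> RGB (L/VL)
        n_heads = 1 if streams == 1 else 3
        self.last = nn.ModuleList(
            [nn.Conv2d(last_in, 4, k, padding=k // 2) for _ in range(n_heads)])
        self.depth_to_space = nn.PixelShuffle(2)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    def forward(self, x):
        h, feats = x, []
        for layer in self.layers:
            h = torch.cat([branch(h) for branch in layer], dim=1)
            feats.append(h)
        z = torch.cat(feats, dim=1) if self.dense else h
        resid = self.depth_to_space(torch.cat([conv(z) for conv in self.last], dim=1))
        # a 1-channel residual (S/M) broadcasts over the three RGB channels
        return self.upsample(x) + resid


# (streams, layers, dense) for each printed size
CONFIGS = {"S": (1, 3, False), "M": (1, 7, True), "L": (2, 3, False), "VL": (2, 7, True)}
```

Counting parameters with `sum(p.numel() for p in model.parameters())` gives 988, 2092, 4284 and 8540 for the four configurations of this sketch, which is a cheap first check against the tensor shapes in a converted checkpoint.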

@arianaa30

arianaa30 commented Feb 7, 2024

Thanks. I have those architectures. But do you know what to pass to this function to get each of the S, L and VL sizes? I need it for training.

def SR2Model(input_depth=3, highway_depth=4, block_depth=4, init='he_normal', init_last=RandomNormal(mean=0.0, stddev=0.001)):

@Fannovel16
Contributor

@arianaa30 My main library is PyTorch so Idk tbh

@arianaa30

@Fannovel16 Btw, do you know how to measure the SSIM/PSNR of what the Anime4K shaders give me (the upscaled version of a low-res image) against the original high-resolution image?

@Fannovel16
Copy link
Contributor

Fannovel16 commented Feb 10, 2024

@arianaa30 You can pass images to mpv, or to ffmpeg compiled with libplacebo, from the command line and save the upscaled images. I'm not sure of the exact commands though, since I'm not very familiar with ffmpeg or mpv. Alternatively, you can try Anime4K-GPU and use canvas.toDataURL("image/png") to save the results.
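Once the shader output has been saved as an image (for example with one of the methods above), PSNR against the original high-resolution frame is straightforward to compute, provided both images have the same size. Below is a minimal NumPy sketch; the function name is hypothetical, and converting to float64 first avoids uint8 wraparound when subtracting. For SSIM, scikit-image already provides `skimage.metrics.structural_similarity` (pass `channel_axis=-1` and `data_range=255` for 8-bit RGB), and the same module also has `peak_signal_noise_ratio`.

```python
import numpy as np


def psnr(reference, test, data_range=255.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    ref = np.asarray(reference, dtype=np.float64)
    out = np.asarray(test, dtype=np.float64)
    if ref.shape != out.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {out.shape}")
    mse = np.mean((ref - out) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))
```

For example, with Pillow: `psnr(np.asarray(Image.open("original.png").convert("RGB")), np.asarray(Image.open("anime4k.png").convert("RGB")))`, where the file names are placeholders.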

@Fannovel16
Contributor

Btw, I've had no luck converting the GLSL shaders into actual PyTorch models. I believe the problem is transforming the weights.

@arianaa30

Hmm ok, thanks. The problem is that we apply multiple Anime4K shaders (restore, upscale, restore, ...). Not sure if we can do that.

@Fannovel16
Contributor

Fannovel16 commented Feb 11, 2024

@arianaa30 It's possible: mpv-player/mpv#9589. But now that you mention it, I kinda wonder how the A4K shaders were actually trained.

@arianaa30

arianaa30 commented Feb 14, 2024

@Fannovel16 Yeah, the training has some unknowns. Using the TensorFlow script, I trained a model/shader by calling the SR2Model() function, and it works. When I train SR1Model (which should be the restore model), training the .h5 model also works, but converting it with Gen_Shader.py fails with a "Shape mismatch" error. Have you run into this before?

  • Adding @Tama47 in case you have insights on this.
 Layer (type)                                Output Shape                                 Param #        Connected to
======================================================================================================================================================
 input.MAIN (InputLayer)                     [(None, None, None, 3)]                      0              []

 conv2d (Conv2D)                             (None, None, None, 4)                        112            ['input.MAIN[0][0]']

 tf.compat.v1.nn.crelu (TFOpLambda)          (None, None, None, 8)                        0              ['conv2d[0][0]']

 conv2d_1 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu[0][0]']

 tf.compat.v1.nn.crelu_1 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_1[0][0]']

 conv2d_2 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_1[0][0]']

 tf.compat.v1.nn.crelu_2 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_2[0][0]']

 conv2d_3 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_2[0][0]']

 tf.compat.v1.nn.crelu_3 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_3[0][0]']

 conv2d_4 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_3[0][0]']

 tf.compat.v1.nn.crelu_4 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_4[0][0]']

 conv2d_5 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_4[0][0]']

 tf.compat.v1.nn.crelu_5 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_5[0][0]']

 conv2d_6 (Conv2D)                           (None, None, None, 4)                        292            ['tf.compat.v1.nn.crelu_5[0][0]']

 tf.compat.v1.nn.crelu_6 (TFOpLambda)        (None, None, None, 8)                        0              ['conv2d_6[0][0]']

 concatenate (Concatenate)                   (None, None, None, 56)                       0              ['tf.compat.v1.nn.crelu[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_1[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_2[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_3[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_4[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_5[0][0]',
                                                                                                          'tf.compat.v1.nn.crelu_6[0][0]']

 conv2d_lastresid.MAIN (Conv2D)              (None, None, None, 3)                        171            ['concatenate[0][0]']

 add.ignore.MAIN (Add)                       (None, None, None, 3)                        0              ['conv2d_lastresid.MAIN[0][0]',
                                                                                                          'input.MAIN[0][0]']

======================================================================================================================================================
Total params: 2035 (7.95 KB)
Trainable params: 2035 (7.95 KB)
Non-trainable params: 0 (0.00 Byte)
______________________________________________________________________________________________________________________________________________________
Traceback (most recent call last):
  File "Gen_Shader.py", line 141, in <module>
    model.load_weights("model-checkpoint-new.h5")
  File "/opt/tensorflow/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/opt/tensorflow/lib/python3.10/site-packages/keras/src/backend.py", line 4361, in _assign_value_to_variable
    variable.assign(value)
ValueError: Cannot assign value to variable ' conv2d_lastresid.MAIN/kernel:0': Shape mismatch.The variable shape (1, 1, 56, 3), and the assigned value shape (12, 56, 1, 1) are incompatible.
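Keras stores Conv2D kernels as (kernel_h, kernel_w, in_channels, out_channels), so the model's variable (1, 1, 56, 3) is a 1×1 conv from 56 to 3 channels. The assigned value (12, 56, 1, 1) has 12 output filters written in (out, in, h, w) order. One plausible reading, not confirmed in this thread, is that model-checkpoint-new.h5 was saved from a model with a different last layer (12 = 3 × 4 is what a 2× depth-to-space head with RGB output would need). A quick way to check is to list the weight shapes stored in the checkpoint and compare them with `model.summary()`. This is a minimal sketch using h5py; the function name is hypothetical.

```python
import h5py


def list_h5_weights(path):
    """Return {dataset_path: shape} for every array stored in an .h5 file."""
    shapes = {}
    with h5py.File(path, "r") as f:
        def visit(name, obj):
            # Groups mirror layer names; datasets hold the actual kernels/biases.
            if isinstance(obj, h5py.Dataset):
                shapes[name] = obj.shape
        f.visititems(visit)
    return shapes


# e.g. for name, shape in list_h5_weights("model-checkpoint-new.h5").items():
#          print(name, shape)
```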

@kato-megumi

kato-megumi commented Feb 14, 2024

Btw, I've had no luck converting the GLSL shaders into actual PyTorch models. I believe the problem is transforming the weights.

@Fannovel16
This is my script to convert GLSL shaders to PyTorch models.
https://gist.github.com/kato-megumi/44e52b4cc0e082e94d452a7df04243e0

@arianaa30

Is the displayed image the upscaled output? Can we apply multiple shaders as well?

@kato-megumi

Of course, just string the models together like this: model2(model1(image))

@Fannovel16
Copy link
Contributor

@kato-megumi Thanks! It seems like I got the CReLU formula wrong.

@arianaa30

arianaa30 commented Feb 15, 2024

Of course, just string the models together like this: model2(model1(image))

Great, thanks. Can we simply add other shaders to the list as well? I want to use Anime4K_Clamp_Highlights.glsl too. The instructions strongly recommend including it in the list, as it greatly increases the quality.

@Fannovel16
Copy link
Contributor

Fannovel16 commented Feb 15, 2024

@arianaa30 Here it is
P.S. I made some changes based on kato's advice.

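import torch
import torch.nn as nn
import torch.nn.functional as F

def get_luma(x):
    # BT.601 luma from an (N, 3, H, W) RGB tensor -> (N, 1, H, W)
    x = x[:, 0] * 0.299 + x[:, 1] * 0.587 + x[:, 2] * 0.114
    x = x.unsqueeze(1)
    return x

class MaxPoolKeepShape(nn.Module):
    # Max pooling with "same" padding: with stride=1 the output keeps the input's H and W
    def __init__(self, kernel_size, stride=None):
        super(MaxPoolKeepShape, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        kernel_height, kernel_width = self.kernel_size
        # Total padding so that ceil(size / stride) windows fit along each axis
        pad_height = (((height - 1) // self.stride + 1) - 1) * self.stride + kernel_height - height
        pad_width = (((width - 1) // self.stride + 1) - 1) * self.stride + kernel_width - width

        # F.pad takes (left, right, top, bottom) for the last two dimensions
        x = F.pad(x, (pad_width // 2, pad_width - pad_width // 2, pad_height // 2, pad_height - pad_height // 2))
        x = F.max_pool2d(x, kernel_size=self.kernel_size, stride=self.stride)
        return x

class ClampHighlight(nn.Module):
    def __init__(self):
        super(ClampHighlight, self).__init__()
        # 5x5 neighbourhood maximum of the original image's luma
        self.max_pool = MaxPoolKeepShape(kernel_size=(5, 5), stride=1)
    def forward(self, shader_img, orig_img):
        curr_luma = get_luma(shader_img)
        statsmax = self.max_pool(get_luma(orig_img))
        # The shader output may be upscaled, so resize the statistics to match
        if statsmax.shape != curr_luma.shape:
            statsmax = F.interpolate(statsmax, curr_luma.shape[2:4])
        # Clamp the luma to the local maximum and shift all channels by the removed amount
        new_luma = torch.min(curr_luma, statsmax)
        return shader_img - (curr_luma - new_luma)

# `out` (a shader output without batch dim), `image2` (the original image batch),
# `to_pil` and `display` are defined earlier in the notebook (not shown here)
new_img = ClampHighlight()(out[None], image2)
display(to_pil(new_img[0]))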

@kato-megumi

  • ClampHighlight clamps the output of another shader using the original image's luminance, so it requires two images as input.
  • The kernel size should be (5, 5).
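Combining the two points above with the earlier "model2(model1(image))" chaining, the original image has to be kept around so the clamp can compare against it. The sketch below is a compact, self-contained variant of Fannovel's ClampHighlight above. It replaces the custom MaxPoolKeepShape with F.max_pool2d(padding=2), which gives the same result for non-negative luma because every window contains its center pixel. It also assumes Clamp_Highlights sits at the start of the shader chain, so its statistics come from the pipeline input. `run_pipeline` and `clamp_highlights` are hypothetical helper names, and `models` can be any callables, such as converted restore/upscale modules.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def luma(x):
    # BT.601 weights: (N, 3, H, W) -> (N, 1, H, W)
    w = torch.tensor([0.299, 0.587, 0.114], dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
    return (x * w).sum(dim=1, keepdim=True)


def clamp_highlights(shader_img, orig_img):
    # Local 5x5 max of the original luma (the saved "statistics")
    stats_max = F.max_pool2d(luma(orig_img), kernel_size=5, stride=1, padding=2)
    curr = luma(shader_img)
    if stats_max.shape[-2:] != curr.shape[-2:]:
        # The chain may have upscaled the image; resize the statistics to match
        stats_max = F.interpolate(stats_max, size=curr.shape[-2:])
    # Lower the luma to the local max, shifting all channels by the same amount
    return shader_img - (curr - torch.minimum(curr, stats_max))


def run_pipeline(image, models, clamp=True):
    """Apply the models in order (e.g. restore -> upscale), then clamp against the input."""
    out = image
    for model in models:
        out = model(out)
    return clamp_highlights(out, image) if clamp else out
```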

@Fannovel16
Contributor

Fannovel16 commented Feb 15, 2024

@kato-megumi

The kernel size should be (5, 5).

Oh, so the first block iterates along the x-axis while the second block iterates along the y-axis.

ClampHighlight clamps the output of another shader using the original image's luminance, so it requires two images as input.

What is PREKERNEL? I assumed it is the same as MAIN, since the mpv docs are a bit ambiguous.

@kato-megumi

kato-megumi commented Feb 15, 2024

Oh, so the first block iterates along the x-axis while the second block iterates along the y-axis.

Yeah, it reduces the computation cost compared to finding the max of 25 pixels in a single pass.
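The two-pass trick can be checked numerically in PyTorch. A 5×5 max equals a horizontal 1×5 max followed by a vertical 5×1 max, so each output pixel reads 5 + 5 = 10 samples instead of 25. This is a minimal sketch; the function names are hypothetical, and the padding keeps the output the same size as the input.

```python
import torch
import torch.nn.functional as F


def max5x5_direct(x):
    # single pass over all 25 neighbours
    return F.max_pool2d(x, kernel_size=5, stride=1, padding=2)


def max5x5_separable(x):
    # pass 1: 5 taps along the width (x-axis)
    x = F.max_pool2d(x, kernel_size=(1, 5), stride=1, padding=(0, 2))
    # pass 2: 5 taps along the height (y-axis)
    return F.max_pool2d(x, kernel_size=(5, 1), stride=1, padding=(2, 0))
```

Because max is associative, the separable version produces exactly the same values as the direct version, not an approximation.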

What is PREKERNEL? I assumed it is the same as MAIN, since the mpv docs are a bit ambiguous.

In anime4k doc about ClampHighlight: "Computes and saves image statistics at the location it is placed in the shader stage, then clamps the image highlights at the end after all the shaders to prevent overshoot and reduce ringing."

PREKERNEL
The image immediately before the scaler kernel runs.

I think it refers to the image right before mpv performs internal scaling.
Other shaders are hooked to MAIN, which comes before PREKERNEL in mpv's rendering process, so those should run first.

@Fannovel16
Contributor

Fannovel16 commented Feb 15, 2024

I added ClampHighlight, AutoDownscalePre, automatic GLSL downloading and a pipeline class for convenience:
https://colab.research.google.com/drive/11xAn4fyAUJPZOjrxwnL2ipl_1DGGegkB

@arianaa30

Great, I will try it.

@kato-megumi

I recommend using https://github.com/muslll/neosr/ to train the model.
Just put the PyTorch model in the arch/ folder, tweak the config in the .yml file and train.

@arianaa30

@Fannovel16 Btw, do you have training code for the PyTorch models? Would you be able to share it?

@Fannovel16
Contributor

Fannovel16 commented Feb 26, 2024

@arianaa30 No, but you can use my notebook to get the model, then randomize its parameters and train it.

@arianaa30

@arianaa30 No but you can use my notebook to get the model and randomize its parameters to train

Should I fine-tune it (only train the last layers) or train the whole network?
Btw, your notebook throws errors in the convert() function and in the use of combination() when I run the pipeline code. Maybe something changed recently.

@Fannovel16
Contributor

Fannovel16 commented Feb 26, 2024

@arianaa30 I forgot to test it 😅. It works now.

Should I fine-tune it (only train the last layers) or train the whole network?

Anime4K's CNN networks are pretty small, so training from scratch is a better choice, imo.
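A minimal sketch of "randomize its parameters and train from scratch" in plain PyTorch, assuming the converted model is an ordinary nn.Module that maps a low-resolution batch to a 2× batch. The L1 loss, the Adam optimizer, the stand-in model and the toy nearest-neighbour target are all illustrative assumptions. The thread does not say what loss or data Anime4K was trained with, and a real run (for example with neosr, as suggested above) would use an actual LR/HR dataset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def reset_parameters(model):
    """Re-initialize every submodule that defines reset_parameters() (Conv2d, Linear, ...)."""
    for m in model.modules():
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()


def train_step(model, optimizer, lr_img, hr_img):
    """One optimization step with an L1 loss between the model output and the target."""
    model.train()
    optimizer.zero_grad()
    loss = F.l1_loss(model(lr_img), hr_img)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
# Hypothetical stand-in for a converted Anime4K x2 model: 12 channels -> PixelShuffle -> RGB
model = nn.Sequential(nn.Conv2d(3, 12, 3, padding=1), nn.PixelShuffle(2))
reset_parameters(model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_img = torch.rand(4, 3, 32, 32)
hr_img = F.interpolate(lr_img, scale_factor=2, mode="nearest")  # toy target only
for step in range(200):
    loss = train_step(model, optimizer, lr_img, hr_img)
print(f"final L1 loss: {loss:.4f}")
```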

@arianaa30

@Fannovel16 @kato-megumi bumping up this thread:
So I was trying to train the PyTorch Anime4K models using NeoSR. I trained and tested the resulting model, and I don't know why, but it generates a bad-quality image with a lot of grain, with an SSIM of around 0.65-0.70. Here are more details.

Have any of you had any success training the PyTorch models, even just as a test? It is so weird.

@kato-megumi

@arianaa30
The grain artifacts might be caused by a flawed dataset or an improper loss configuration. You'd likely get better help from the folks at Enhance Everything!


5 participants