Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding "Image Align with Rife" and "Wavelet Color Fix" Nodes (#2714)
* adding "Image Align with Rife" and "Wavelet Color Fix" nodes
* Create test
* adding additional files for the "Image Align with Rife" node and the Rife model
* Delete backend/src/packages/chaiNNer_pytorch/pytorch/processing/rife/test
* Update IFNet_HDv3_v4_14_align.py
* Update image_align_rife.py
* Update wavelet_color_fix.py
* Delete rife model
* update image_align_rife.py to download rife model only when needed
* cosmetic fixes image_align_rife.py
* added minimums, removed comments, changed download, changed name, ruff
* added minimum for wavelet number
* removed commented out code
* removed commented out code
* cosmetics
* Run ruff formatting
* fixes, ignore a ton of stuff
* move some files and reorder nodes in list
* fix pyright errors

--------

Co-authored-by: Joey Ballentine <joeyjballentine@gmail.com>
- Loading branch information
1 parent
3f1fbdf
commit a4be51d
Showing
5 changed files
with
664 additions
and
0 deletions.
There are no files selected for viewing
271 changes: 271 additions & 0 deletions
271
backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,271 @@ | ||
# type: ignore | ||
# Original Rife Frame Interpolation by hzwer | ||
# https://github.com/megvii-research/ECCV2022-RIFE | ||
# https://github.com/hzwer/Practical-RIFE | ||
|
||
# Modifications to use Rife for Image Alignment by tepete/pifroggi ('Enhance Everything!' Discord Server) | ||
|
||
# Additional helpful github issues | ||
# https://github.com/megvii-research/ECCV2022-RIFE/issues/278 | ||
# https://github.com/megvii-research/ECCV2022-RIFE/issues/344 | ||
|
||
import torch | ||
import torch.nn.functional as F # noqa: N812 | ||
from torch import nn | ||
from torchvision import transforms | ||
|
||
from .warplayer import warp | ||
|
||
|
||
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):  # noqa: ANN001
    """Build a biased Conv2d followed by an in-place LeakyReLU(0.2)."""
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    activation = nn.LeakyReLU(0.2, True)
    return nn.Sequential(convolution, activation)
|
||
|
||
def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):  # noqa: ANN001
    """Build an unbiased Conv2d + BatchNorm2d + in-place LeakyReLU(0.2).

    The convolution carries no bias because the following batch norm has
    its own learnable shift.
    """
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=False,
    )
    normalization = nn.BatchNorm2d(out_planes)
    activation = nn.LeakyReLU(0.2, True)
    return nn.Sequential(convolution, normalization, activation)
|
||
|
||
class Head(nn.Module):
    """Shallow image encoder: downsample RGB by 2, refine, upsample back.

    Produces an 8-channel feature map at the input resolution; with
    ``feat=True`` it also exposes the intermediate activations.
    """

    def __init__(self):
        super().__init__()
        self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
        self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):  # noqa: ANN001
        # NOTE: self.relu is in-place, so x0/x1/x2 hold the activated values
        # by the time they are returned in the feature list.
        x0 = self.cnn0(x)
        x1 = self.cnn1(self.relu(x0))
        x2 = self.cnn2(self.relu(x1))
        x3 = self.cnn3(self.relu(x2))
        return [x0, x1, x2, x3] if feat else x3
|
||
|
||
class ResConv(nn.Module):
    """Residual 3x3 conv unit: ``relu(conv(x) * beta + x)``.

    ``beta`` is a learnable per-channel scale on the convolution branch,
    initialized to ones.
    """

    def __init__(self, c, dilation=1):  # noqa: ANN001
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):  # noqa: ANN001
        residual = self.conv(x) * self.beta + x
        return self.relu(residual)
|
||
|
||
class IFBlock(nn.Module):
    """One coarse-to-fine flow refinement stage.

    Downsamples the input stack by ``scale``, runs it through a small
    residual CNN, and returns a 4-channel flow field plus a 1-channel mask
    at the original resolution.
    """

    def __init__(self, in_planes, c=64):  # noqa: ANN001
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        # eight identical residual units at feature width c
        self.convblock = nn.Sequential(*[ResConv(c) for _ in range(8)])
        self.lastconv = nn.Sequential(
            nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2)
        )

    def forward(self, x, flow=None, scale=1):  # noqa: ANN001
        # process at 1/scale resolution, then upsample the prediction back
        x = F.interpolate(
            x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
        )
        if flow is not None:
            # flow vectors are in pixels, so they shrink with the resolution
            flow = (
                F.interpolate(
                    flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
                )
                * 1.0
                / scale
            )
            x = torch.cat((x, flow), 1)
        prediction = self.lastconv(self.convblock(self.conv0(x)))
        prediction = F.interpolate(
            prediction, scale_factor=scale, mode="bilinear", align_corners=False
        )
        # channels 0-3: flow update (rescaled back to pixels); channel 4: mask
        return prediction[:, :4] * scale, prediction[:, 4:5]
|
||
|
||
class IFNet(nn.Module):
    # Coarse-to-fine optical-flow network (RIFE HDv3 v4.14), modified here to
    # align img0 onto img1 instead of interpolating a middle frame.
    def __init__(self):
        super().__init__()
        # Four refinement stages at decreasing capacity. Stage 0 sees the two
        # RGB images (3+3), the timestep (1) and two 8-channel encodings (16);
        # later stages additionally receive the mask (+1) and current flow (+4).
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        self.encode = Head()

    def align_images(
        self,
        img0,  # noqa: ANN001
        img1,  # noqa: ANN001
        timestep,  # noqa: ANN001
        scale_list,  # noqa: ANN001
        blur_strength,  # noqa: ANN001
        ensemble,  # noqa: ANN001
        device,  # noqa: ANN001
    ):
        """Estimate flow from img0 to img1 and warp img0 accordingly.

        Returns ``(aligned_img0, flow)`` where ``flow`` is the final
        4-channel flow field of the last refinement stage.
        """
        # optional blur: flow is estimated on blurred copies (more robust),
        # but the final warp below is applied to the original img0
        if blur_strength is not None and blur_strength > 0:
            blur = transforms.GaussianBlur(
                kernel_size=(5, 5), sigma=(blur_strength, blur_strength)
            )
            img0_blurred = blur(img0)
            img1_blurred = blur(img1)
        else:
            img0_blurred = img0
            img1_blurred = img1

        # only the first 3 channels (RGB) are encoded / matched
        f0 = self.encode(img0_blurred[:, :3])
        f1 = self.encode(img1_blurred[:, :3])
        flow_list = []
        mask_list = []
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                # stage 0: predict an initial flow from scratch
                flow, mask = block[i](
                    torch.cat(
                        (img0_blurred[:, :3], img1_blurred[:, :3], f0, f1, timestep), 1
                    ),
                    None,
                    scale=scale_list[i],
                )
                if ensemble:
                    # ensemble: run the stage again with the images swapped and
                    # the timestep mirrored, then average the (flipped) result
                    f_, m_ = block[i](
                        torch.cat(
                            (
                                img1_blurred[:, :3],
                                img0_blurred[:, :3],
                                f1,
                                f0,
                                1 - timestep,
                            ),
                            1,
                        ),
                        None,
                        scale=scale_list[i],
                    )
                    # swapping inputs swaps the flow halves and negates the mask
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                # later stages: warp the encoded features by the current flow
                # and predict a residual flow update fd
                wf0 = warp(f0, flow[:, :2], device)
                wf1 = warp(f1, flow[:, 2:4], device)
                fd, m0 = block[i](
                    torch.cat(
                        (
                            img0_blurred[:, :3],
                            img1_blurred[:, :3],
                            wf0,
                            wf1,
                            timestep,
                            mask,
                        ),
                        1,
                    ),
                    flow,
                    scale=scale_list[i],
                )
                if ensemble:
                    f_, m_ = block[i](
                        torch.cat(
                            (
                                img1_blurred[:, :3],
                                img0_blurred[:, :3],
                                wf1,
                                wf0,
                                1 - timestep,
                                -mask,
                            ),
                            1,
                        ),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)

        # apply warp to original image
        aligned_img0 = warp(img0, flow_list[-1][:, :2], device)

        # add clamp here instead of in warplayer script, as it changes the output there
        aligned_img0 = aligned_img0.clamp(min=0.0, max=1.0)
        return aligned_img0, flow_list[-1]

    def forward(
        self,
        x,  # noqa: ANN001
        timestep=1,  # noqa: ANN001
        training=False,  # noqa: ANN001
        fastmode=True,  # noqa: ANN001
        ensemble=True,  # noqa: ANN001
        num_iterations=1,  # noqa: ANN001
        multiplier=0.5,  # noqa: ANN001
        blur_strength=0,  # noqa: ANN001
        device="cuda",  # noqa: ANN001
    ):
        """Align the first half of x's channels onto the second half.

        ``x`` is the channel-wise concatenation of img0 and img1; returns
        ``(aligned_img0, flow)`` from the last alignment iteration.

        NOTE(review): `fastmode` is accepted but never read here — presumably
        kept for signature compatibility with upstream RIFE; confirm.
        NOTE(review): with ``training=True`` the branch below never assigns
        img0/img1 and the code would raise NameError — the training path looks
        unused in this fork; confirm callers always use inference mode.
        NOTE(review): with ``num_iterations < 1`` the loop never runs and
        ``aligned_img0``/``flow`` are unbound at the return; callers appear to
        guarantee at least one iteration.
        """
        if not training:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]

        # per-stage processing scales, coarsest first (see align_images)
        scale_list = [multiplier * 8, multiplier * 4, multiplier * 2, multiplier]

        # broadcast a scalar timestep to a full-resolution single-channel map
        if not torch.is_tensor(timestep):
            timestep = (x[:, :1].clone() * 0 + 1) * timestep
        else:
            timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])  # type: ignore

        # iterative refinement: feed the aligned result back in as img0
        for _iteration in range(num_iterations):
            aligned_img0, flow = self.align_images(
                img0, img1, timestep, scale_list, blur_strength, ensemble, device
            )
            img0 = aligned_img0  # use the aligned image as img0 for the next iteration

        return aligned_img0, flow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# type: ignore | ||
import torch | ||
|
||
# Cache of normalized identity sampling grids, keyed by (device, flow size).
# NOTE(review): the cache grows without bound for varying input sizes —
# acceptable for the fixed-size use here, but worth confirming.
backwarp_tenGrid = {}  # noqa: N816


def warp(tenInput, tenFlow, device):  # noqa: ANN001, N803
    """Backward-warp tenInput by the per-pixel flow field tenFlow.

    The flow is given in pixels; it is converted to grid_sample's [-1, 1]
    coordinate range and added to a cached identity grid.
    """
    key = (str(tenFlow.device), str(tenFlow.size()))
    if key not in backwarp_tenGrid:
        batch, _, height, width = tenFlow.shape
        horizontal = (  # noqa: N806
            torch.linspace(-1.0, 1.0, width, device=device)
            .view(1, 1, 1, width)
            .expand(batch, -1, height, -1)
        )
        vertical = (  # noqa: N806
            torch.linspace(-1.0, 1.0, height, device=device)
            .view(1, 1, height, 1)
            .expand(batch, -1, -1, width)
        )
        backwarp_tenGrid[key] = torch.cat([horizontal, vertical], 1).to(device)

    # pixel offsets -> normalized offsets (half-extent of the input per axis)
    tenFlow = torch.cat(  # noqa: N806
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    grid = (backwarp_tenGrid[key] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=grid,
        mode="bicubic",
        padding_mode="border",
        align_corners=True,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.