Skip to content

Commit

Permalink
address issue with encoding first frame separately #30
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 12, 2024
1 parent eef687c commit 434f36e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
18 changes: 11 additions & 7 deletions magvit2_pytorch/magvit2_pytorch.py
@@ -815,7 +815,7 @@ def forward(self, x):
x = rearrange(x, 'b c t h w -> b h w c t')
x, ps = pack_one(x, '* c t')

- x = F.pad(x, self.time_casual_padding)
+ x = F.pad(x, self.time_causal_padding)
out = self.conv(x)

out = unpack_one(out, ps, '* c t')
@@ -1544,7 +1544,10 @@ def encode(
# whether to pad video or not

if video_contains_first_frame:
+ video_len = video.shape[2]
+
video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2)
+ video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]

# conditioning, if needed

@@ -1560,13 +1563,14 @@ def encode(
# taking into account whether to encode first frame separately

if encode_first_frame_separately:
- first_frame, video = video[:, :, 0], video[:, :, 1:]
- xff = self.conv_in_first_frame(first_frame)
+ pad, first_frame, video = unpack(video, video_packed_shape, 'b c * h w')
+ first_frame = self.conv_in_first_frame(first_frame)

- x = self.conv_in(video)
+ video = self.conv_in(video)

if encode_first_frame_separately:
- x, _ = pack([xff, x], 'b c * h w')
+ video, _ = pack([first_frame, video], 'b c * h w')
+ video = pad_at_dim(video, (self.time_padding, 0), dim = 2)

# encoder layers

@@ -1577,11 +1581,11 @@ def encode(
if has_cond:
layer_kwargs = cond_kwargs

- x = fn(x, **layer_kwargs)
+ video = fn(video, **layer_kwargs)

maybe_quantize = identity if not quantize else self.quantizers

- return maybe_quantize(x)
+ return maybe_quantize(video)

@beartype
def decode_from_code_indices(
Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
@@ -1 +1 @@
- __version__ = '0.3.2'
+ __version__ = '0.3.4'

0 comments on commit 434f36e

Please sign in to comment.