Commit

able to customize attention heads and dim head differently across stages
lucidrains committed Jul 14, 2023
1 parent fb22e07 commit a8d1582
Showing 2 changed files with 26 additions and 17 deletions.
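This change allows attn_heads and attn_dim_head (like dim_mults, num_blocks_per_stage, num_self_attn_per_stage and nested_unet_depths) to be passed either as a single int or as one value per resolution stage. A minimal usage sketch of that option; the concrete values below are illustrative and not taken from the commit:

import torch
from x_unet import XUnet

# four resolution stages (len(dim_mults) == 4), so attention settings can be given per stage
unet = XUnet(
    dim = 64,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    num_self_attn_per_stage = (0, 0, 0, 1),
    attn_heads = (4, 4, 8, 16),       # per-stage head counts; the last entry also sets the middle attention
    attn_dim_head = (32, 32, 32, 64)  # per-stage dimension per head
)

img = torch.randn(1, 3, 256, 256)
out = unet(img)  # should come back at the same spatial resolution as the input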
3 changes: 2 additions & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-unet',
   packages = find_packages(exclude=[]),
-  version = '0.3.0',
+  version = '0.3.1',
   license='MIT',
   description = 'X-Unet',
   long_description_content_type = 'text/markdown',
@@ -18,6 +18,7 @@
     'unets',
   ],
   install_requires=[
+    'beartype',
     'einops>=0.4',
     'torch>=1.6',
   ],
40 changes: 24 additions & 16 deletions x_unet/x_unet.py
@@ -8,6 +8,9 @@
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange

+from beartype import beartype
+from beartype.typing import Tuple, Union, Optional
+
 # helper functions

 def exists(val):
@@ -291,30 +294,35 @@ def forward(self, x, fmaps = None):

 # unet

+def MaybeTuple(type):
+    return Union[type, Tuple[type, ...]]
+
 def kernel_and_same_pad(*kernel_size):
     paddings = tuple(map(lambda k: k // 2, kernel_size))
     return dict(kernel_size = kernel_size, padding = paddings)

 class XUnet(nn.Module):

+    @beartype
     def __init__(
         self,
         dim,
         init_dim = None,
         out_dim = None,
         frame_kernel_size = 1,
-        dim_mults = (1, 2, 4, 8),
-        num_blocks_per_stage = (2, 2, 2, 2),
-        num_self_attn_per_stage = (0, 0, 0, 1),
-        nested_unet_depths = (0, 0, 0, 0),
+        dim_mults: MaybeTuple(int) = (1, 2, 4, 8),
+        num_blocks_per_stage: MaybeTuple(int) = (2, 2, 2, 2),
+        num_self_attn_per_stage: MaybeTuple(int) = (0, 0, 0, 1),
+        nested_unet_depths: MaybeTuple(int) = (0, 0, 0, 0),
         nested_unet_dim = 32,
         channels = 3,
         use_convnext = False,
         resnet_groups = 8,
         consolidate_upsample_fmaps = True,
         skip_scale = 2 ** -0.5,
         weight_standardize = False,
-        attn_heads = 8,
-        attn_dim_head = 32
+        attn_heads: MaybeTuple(int) = 8,
+        attn_dim_head: MaybeTuple(int) = 32
     ):
         super().__init__()

@@ -354,10 +362,8 @@ def __init__(

         # attn kwargs

-        attn_kwargs = dict(
-            heads = attn_heads,
-            dim_head = attn_dim_head
-        )
+        attn_heads = cast_tuple(attn_heads, num_resolutions)
+        attn_dim_head = cast_tuple(attn_dim_head, num_resolutions)

         # modules for all layers

@@ … @@
             in_out,
             nested_unet_depths,
             num_blocks_per_stage,
-            num_self_attn_per_stage
+            num_self_attn_per_stage,
+            attn_heads,
+            attn_dim_head
         ]

         up_stage_parameters = [reversed(params[:-1]) for params in down_stage_parameters]

         # downs

-        for ind, ((dim_in, dim_out), nested_unet_depth, num_blocks, self_attn_blocks) in enumerate(zip(*down_stage_parameters)):
+        for ind, ((dim_in, dim_out), nested_unet_depth, num_blocks, self_attn_blocks, heads, dim_head) in enumerate(zip(*down_stage_parameters)):
             is_last = ind >= (num_resolutions - 1)
             skip_dims.append(dim_in)

             self.downs.append(nn.ModuleList([
                 blocks(dim_in, dim_in, nested_unet_depth = nested_unet_depth, nested_unet_dim = nested_unet_dim),
                 nn.ModuleList([blocks(dim_in, dim_in, nested_unet_depth = nested_unet_depth, nested_unet_dim = nested_unet_dim) for _ in range(num_blocks - 1)]),
-                nn.ModuleList([TransformerBlock(dim_in, depth = self_attn_blocks, **attn_kwargs) for _ in range(self_attn_blocks)]),
+                nn.ModuleList([TransformerBlock(dim_in, depth = self_attn_blocks, heads = heads, dim_head = dim_head) for _ in range(self_attn_blocks)]),
                 Downsample(dim_in, dim_out)
             ]))

@@ -391,20 +399,20 @@ def __init__(
         mid_nested_unet_depth = nested_unet_depths[-1]

         self.mid = blocks(mid_dim, mid_dim, nested_unet_depth = mid_nested_unet_depth, nested_unet_dim = nested_unet_dim)
-        self.mid_attn = Attention(mid_dim)
+        self.mid_attn = Attention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1])
         self.mid_after = blocks(mid_dim, mid_dim, nested_unet_depth = mid_nested_unet_depth, nested_unet_dim = nested_unet_dim)

         self.mid_upsample = Upsample(mid_dim, dims[-2])

         # ups

-        for ind, ((dim_in, dim_out), nested_unet_depth, num_blocks, self_attn_blocks) in enumerate(zip(*up_stage_parameters)):
+        for ind, ((dim_in, dim_out), nested_unet_depth, num_blocks, self_attn_blocks, heads, dim_head) in enumerate(zip(*up_stage_parameters)):
             is_last = ind >= (num_resolutions - 1)

             self.ups.append(nn.ModuleList([
                 blocks(dim_out + skip_dims.pop(), dim_out, nested_unet_depth = nested_unet_depth, nested_unet_dim = nested_unet_dim),
                 nn.ModuleList([blocks(dim_out, dim_out, nested_unet_depth = nested_unet_depth, nested_unet_dim = nested_unet_dim) for _ in range(num_blocks - 1)]),
-                nn.ModuleList([TransformerBlock(dim_out, depth = self_attn_blocks, **attn_kwargs) for _ in range(self_attn_blocks)]),
+                nn.ModuleList([TransformerBlock(dim_out, depth = self_attn_blocks, heads = heads, dim_head = dim_head) for _ in range(self_attn_blocks)]),
                 Upsample(dim_out, dim_in) if not is_last else nn.Identity()
             ]))

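Two details about the new code are worth spelling out. First, MaybeTuple(type) simply expands to Union[type, Tuple[type, ...]], so with the @beartype decorator each annotated argument accepts either a bare value or a tuple of values, and anything else fails validation when __init__ is called. A self-contained sketch of that pattern (the function f below is invented for illustration):

from beartype import beartype
from beartype.typing import Tuple, Union

def MaybeTuple(type):
    # a bare value of `type`, or a variable-length tuple of them
    return Union[type, Tuple[type, ...]]

@beartype
def f(attn_heads: MaybeTuple(int) = 8):
    return attn_heads

f(8)          # accepted: plain int
f((4, 4, 8))  # accepted: tuple of ints
# f('8')      # rejected: beartype raises a call-time type-check error

Second, the constructor now calls a cast_tuple helper that is defined elsewhere in x_unet.py and does not appear in this diff. Assuming it follows the usual convention in this codebase, it broadcasts a scalar to a tuple of the requested length and passes tuples through unchanged, so attn_heads = 8 becomes (8, 8, 8, 8) across four stages. A sketch of that assumed behavior:

def cast_tuple(val, length = 1):
    # pass tuples through untouched, otherwise repeat the scalar `length` times
    return val if isinstance(val, tuple) else ((val,) * length)

cast_tuple(8, 4)              # (8, 8, 8, 8)
cast_tuple((4, 4, 8, 16), 4)  # (4, 4, 8, 16)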
