
GroupNorm & InstanceNorm does not handle channels_last correctly #111824

Open
PeterL1n opened this issue Oct 23, 2023 · 6 comments · May be fixed by #126635
Labels
actionable
module: cuda - Related to torch.cuda, and CUDA support in general
module: memory format - Memory format/layout related issues/changes (channels_last, nhwc)
module: nn - Related to torch.nn
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

PeterL1n (Contributor) commented Oct 23, 2023

🐛 Describe the bug

GroupNorm does not return a channels_last tensor.

import torch
from torch import nn

device = "cuda"  # the reporter later confirms this ran on an NVIDIA A100 GPU

norm = nn.GroupNorm(8, 32).to(device, memory_format=torch.channels_last)
x = torch.randn([4, 32, 24, 24], device=device).to(memory_format=torch.channels_last)

print(x.stride())
assert x.is_contiguous(memory_format=torch.channels_last)  # Pass

y = norm(x)

print(y.stride())
assert y.is_contiguous(memory_format=torch.channels_last)  # Fail

I had to implement GroupNorm manually. Somehow this manual version works great and is fast.

from einops import rearrange
from torch import nn


class GroupNorm(nn.GroupNorm):
    def forward(self, x):
        dtype = x.dtype
        # compute statistics in float32 for numerical stability
        x = x.float()
        # split the channel dim into groups: (B, C, H, W) -> (B, G, C/G, H, W)
        x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups)

        # per-group mean and population variance (unbiased=False matches nn.GroupNorm)
        mean = x.mean(dim=[2, 3, 4], keepdim=True)
        var = x.var(dim=[2, 3, 4], keepdim=True, unbiased=False)

        x = (x - mean) * (var + self.eps).rsqrt()
        # merge the groups back: (B, G, C/G, H, W) -> (B, C, H, W)
        x = rearrange(x, "b g c h w -> b (g c) h w")

        if self.affine:
            weight = rearrange(self.weight, "c -> 1 c 1 1")
            bias = rearrange(self.bias, "c -> 1 c 1 1")
            x = x * weight + bias

        x = x.type(dtype)
        return x
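
For reference, a quick sanity check that this drop-in replacement keeps the channels_last layout (an editor's sketch reusing x and device from the repro above, and assuming einops is installed):

norm = GroupNorm(8, 32).to(device, memory_format=torch.channels_last)
y = norm(x)
print(y.stride())
# the report above implies this stays channels_last, unlike nn.GroupNorm on CUDA
print(y.is_contiguous(memory_format=torch.channels_last))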

Versions

Collecting environment information...
PyTorch version: 2.1.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31

Python version: 3.9.2 (default, Feb 28 2021, 17:03:44) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-5.4.56.bsk.11-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
Nvidia driver version: 470.129.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 3599.940
BogoMIPS: 5999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Vulnerable
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Vulnerable, STIBP: disabled
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] byted-torch==2.1.0.post0
[pip3] byted-torch-monitor==0.0.1
[pip3] numpy==1.26.1
[pip3] torch==2.1.0
[pip3] torchaudio==2.1.0+cu121
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0
[conda] Could not collect

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @ptrblck @jamesr66a @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

@soulitzer added the module: nn and triaged labels Oct 23, 2023
@mikaylagawarecki added the intel and actionable labels, then removed the intel label Oct 26, 2023
mikaylagawarecki (Contributor) commented Oct 26, 2023

We would accept a PR that adds channels_last support to these functions on CUDA
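
Until native CUDA support lands, a possible interim workaround (an editor's sketch using the public memory-format API, reusing norm and x from the repro above; not something proposed in this thread) is to restore the layout after the call, at the cost of an extra copy:

y = norm(x)
# force the output back to channels_last; this materializes a copy when the
# layout was lost, so it costs an extra pass over memory
y = y.contiguous(memory_format=torch.channels_last)
assert y.is_contiguous(memory_format=torch.channels_last)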

mingfeima (Collaborator) commented

We do have a native channels_last implementation for group norm: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/group_norm_kernel.cpp#L284

@CaoE could you please help check why this does not work?

CaoE (Collaborator) commented Oct 27, 2023

> We do have a native channels_last implementation for group norm: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/group_norm_kernel.cpp#L284

There are channels_last implementations for group norm on CPU. It seems that there is no channels_last support on CUDA.

@PeterL1n What is the device in this example?

PeterL1n (Contributor, Author) commented

> We do have a native channels_last implementation for group norm: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/group_norm_kernel.cpp#L284
>
> There are channels_last implementations for group norm on CPU. It seems that there is no channels_last support on CUDA.
>
> @PeterL1n What is the device in this example?

A100 GPU
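
For context, the CPU and CUDA paths can be compared directly; a minimal sketch (an editor's addition, assuming a CUDA-capable machine and reusing the shapes from the original repro):

import torch
from torch import nn

def check(device):
    norm = nn.GroupNorm(8, 32).to(device, memory_format=torch.channels_last)
    x = torch.randn(4, 32, 24, 24, device=device).to(memory_format=torch.channels_last)
    y = norm(x)
    print(device, y.stride(), y.is_contiguous(memory_format=torch.channels_last))

check("cpu")   # CPU has a native channels_last group-norm kernel (see the link above)
check("cuda")  # CUDA reportedly does not preserve channels_last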

@CaoE removed their assignment Oct 27, 2023
@CaoE added the module: cuda and module: memory format labels Oct 27, 2023
ZelboK (Contributor) commented Apr 29, 2024

@mikaylagawarecki Hi, I can work on this. Is it OK if you assign me to the issue?

mikaylagawarecki (Contributor) commented

@ZelboK sure, assigned you

Projects
Status: To pick up
Development

Successfully merging a pull request may close this issue.

6 participants