torch.ao.nn.quantized.Conv2d get error result in some intel cpu #126521

Closed
Novelfor opened this issue May 17, 2024 · 2 comments
Labels: oncall: quantization (Quantization support in PyTorch)

Comments

@Novelfor

Novelfor commented May 17, 2024

🐛 Describe the bug

import torch.ao.nn.quantized as nnq
import torch.nn as nn
import torch

if __name__ == "__main__":
    print(torch.__version__)
    qconv_module = nnq.Conv2d(3, 1, 3)
    conv_module = nn.Conv2d(3, 1, 3)

    in_channels = 3
    out_channels = 1
    inp_scale = 0.007812488358467817
    inp_zero_point = 128
    X = torch.Tensor([[[[61., 61., 59.],
        [70., 69., 65.],
        [79., 81., 75.]],

        [[61., 61., 59.],
        [70., 69., 65.],
        [79., 81., 75.]],

        [[61., 61., 59.],
        [70., 69., 65.],
        [79., 81., 75.]]]]) * inp_scale

    X_q = torch.quantize_per_tensor(X, 
        inp_scale, inp_zero_point, torch.quint8)

    W_scale = 0.046079955995082855
    W_zero_point = 0
    W = torch.Tensor([[[[  12,  -89,   17],
        [ -24, -116, -102],
        [ -18,   78,  -83]],

        [[  -8,  127,   92],
        [  15,  -11,  127],
        [  43,  -45,   -5]],

        [[ -44,   40,  -35],
        [ -60,  -35,  -44],
        [  90,  -36,   85]]]]) * W_scale
    W_q = torch.quantize_per_tensor(W, W_scale, W_zero_point, torch.qint8)
    b = torch.Tensor([1.193039894104004]).float()

    example_input = [X, ]
    example_input_q = [X_q, ]


    # Configure the quantized module: weight, bias, and output quantization parameters
    Y_scale = 0.09463921934366226
    Y_zero_point = 0
    qconv_module.set_weight_bias(W_q, b)
    qconv_module.scale = Y_scale
    qconv_module.zero_point = Y_zero_point

    raw_conv_module = conv_module
    raw_conv_module.weight.data = W
    raw_conv_module.bias.data = b

    # Test forward
    Y_exp = conv_module(*example_input)
    Y_exp = torch.quantize_per_tensor(
        Y_exp, scale=Y_scale, zero_point=Y_zero_point, dtype=torch.quint8)
    Y_act = qconv_module(*example_input_q)

    print(Y_act.dequantize().item(), Y_exp.item())

Running this code on an Intel(R) Xeon(R) CPU E3-1240 v6 @ 3.70GHz gives an incorrect result: 0, 0.4731960892677307 (the quantized convolution outputs 0 instead of the expected value).
On an Intel(R) Xeon(R) Gold 6230 CPU @ 2.10GHz, the result is correct: 0.4731960892677307 0.4731960892677307
PyTorch version: 2.3.0+cu121

Versions

CORRECT CPU:

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.26.1
Libc version: glibc-2.35

Python version: 3.10.4 (main, Jun 29 2022, 12:14:53) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.25-nvidia-gpu-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: 
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          25
On-line CPU(s) list:             0-24
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Gold 6230 CPU @ 2.10GHz
CPU family:                      6
Model:                           85
Thread(s) per core:              1
Core(s) per socket:              25
Socket(s):                       1
Stepping:                        7
BogoMIPS:                        4190.15
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat umip pku avx512_vnni md_clear arch_capabilities
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       416 KiB (13 instances)
L1i cache:                       416 KiB (13 instances)
L2 cache:                        13 MiB (13 instances)
L3 cache:                        27.5 MiB (1 instance)
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Mitigation; Clear CPU buffers; SMT Host state unknown

Versions of relevant libraries:
[pip3] aimet-torch==0.0.2
[pip3] efficientnet-pytorch==0.7.1
[pip3] executorch==0.1.2
[pip3] msgpack-numpy==0.4.8
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.4
[pip3] onnx==1.15.0
[pip3] onnxconverter-common==1.14.0
[pip3] onnxruntime==1.15.0
[pip3] onnxscript==0.1.0.dev20240325
[pip3] optree==0.9.1
[pip3] segmentation-models-pytorch==0.3.3
[pip3] torch==2.3.0
[pip3] torch-tensorrt==1.3.0
[pip3] torchaudio==2.3.0
[pip3] torchvision==0.18.0
[pip3] triton==2.3.0
[conda] blas                      1.0                         mkl    conda-forge
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] numpy                     1.25.1                   pypi_0    pypi
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.3.0               py310_cu121    pytorch
[conda] torchtriton               2.3.0                     py310    pytorch

ERROR CPU:

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.26.1
Libc version: glibc-2.35

Python version: 3.10.4 (main, Jun 29 2022, 12:14:53) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.15-1.el7.elrepo.x86_64-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: 
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   39 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          8
On-line CPU(s) list:             0
Off-line CPU(s) list:            1-7
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) CPU E3-1240 v6 @ 3.70GHz
CPU family:                      6
Model:                           158
Thread(s) per core:              2
Core(s) per socket:              4
Socket(s):                       1
Stepping:                        9
CPU max MHz:                     4100.0000
CPU min MHz:                     800.0000
BogoMIPS:                        7399.70
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single pti ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp
Virtualization:                  VT-x
L1d cache:                       128 KiB (4 instances)
L1i cache:                       128 KiB (4 instances)
L2 cache:                        1 MiB (4 instances)
L3 cache:                        8 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-7
Vulnerability Itlb multihit:     KVM: Mitigation: Split huge pages
Vulnerability L1tf:              Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds:               Vulnerable: Clear CPU buffers attempted, no microcode; SMT vulnerable
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Full generic retpoline, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Tsx async abort:   Vulnerable: Clear CPU buffers attempted, no microcode; SMT vulnerable

Versions of relevant libraries:
[pip3] aimet-torch==0.0.2
[pip3] efficientnet-pytorch==0.7.1
[pip3] executorch==0.1.2
[pip3] msgpack-numpy==0.4.8
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.4
[pip3] onnx==1.15.0
[pip3] onnxconverter-common==1.14.0
[pip3] onnxruntime==1.15.0
[pip3] onnxscript==0.1.0.dev20240325
[pip3] optree==0.9.1
[pip3] segmentation-models-pytorch==0.3.3
[pip3] torch==2.3.0
[pip3] torch-tensorrt==1.3.0
[pip3] torchaudio==2.3.0
[pip3] torchvision==0.18.0
[pip3] triton==2.3.0
[conda] blas                      1.0                         mkl    conda-forge
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] numpy                     1.25.1                   pypi_0    pypi
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.3.0               py310_cu121    pytorch
[conda] torchtriton               2.3.0                     py310    pytorch

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

@mikaylagawarecki added the "oncall: quantization" (Quantization support in PyTorch) label on May 20, 2024
@leslie-fang-intel self-assigned this on May 21, 2024
@leslie-fang-intel
Collaborator

Hi @Novelfor, thanks for reporting the issue.
The first CPU (Xeon Gold 6230) supports the Vector Neural Network Instructions (VNNI), as shown by its avx512_vnni flag; the second one (Xeon E3-1240 v6) does not.
On CPUs without VNNI, oneDNN requires activations to use 7 bits, i.e. the range [0, 127], instead of 8 bits, as pointed out in https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html#inputs-of-mixed-type-u8-and-s8.
BTW: if you are using FX Quantization with the oneDNN backend, a warning message is also emitted about this potential error.
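
To illustrate why the activation range matters, here is a minimal pure-Python sketch of the arithmetic described in the linked oneDNN guide: without VNNI, products of u8 activations and s8 weights are summed in adjacent pairs into an int16 intermediate that saturates. The helper names (`saturate_int16`, `dot_exact`, `dot_paired_int16`) and the pairing order are assumptions made for illustration; they do not reproduce oneDNN's actual kernel layout, so this is not expected to reproduce the exact output from the report. Note that the stored quint8 activations in the reproducer are 189 to 209 (61 + 128 up to 81 + 128), which is outside [0, 127].

```python
INT16_MIN, INT16_MAX = -32768, 32767


def saturate_int16(value):
    """Clamp an integer to the signed 16-bit range."""
    return max(INT16_MIN, min(INT16_MAX, value))


def dot_exact(activations, weights):
    """Reference dot product with unbounded integer accumulation."""
    return sum(a * w for a, w in zip(activations, weights))


def dot_paired_int16(activations, weights):
    """Hypothetical model of the non-VNNI path: each pair of adjacent
    u8*s8 products is added into a saturating int16 intermediate,
    then the pair sums are accumulated without further saturation."""
    total = 0
    for i in range(0, len(activations) - 1, 2):
        pair = activations[i] * weights[i] + activations[i + 1] * weights[i + 1]
        total += saturate_int16(pair)
    if len(activations) % 2:  # odd length: the last product has no partner
        total += activations[-1] * weights[-1]
    return total


# 8-bit activations: the pair sum 2 * 200 * 127 = 50800 exceeds INT16_MAX.
print(dot_exact([200, 200], [127, 127]))         # 50800
print(dot_paired_int16([200, 200], [127, 127]))  # 32767 (saturated)

# 7-bit activations: the pair sum always fits in int16.
print(dot_paired_int16([100, 100], [127, 127]))  # 25400

# Worst-case pair sums for each activation range.
print(2 * 255 * 127 > INT16_MAX)    # True: 8-bit activations can overflow
print(2 * 127 * 127 <= INT16_MAX)   # True: 7-bit activations cannot
print(2 * 127 * -128 >= INT16_MIN)  # True: also safe on the negative side
```

Under these assumptions, restricting activations to [0, 127] bounds every pair sum by 2 * 127 * 128 = 32512 in magnitude, which is why the 7-bit range avoids saturation on CPUs without VNNI.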

@leslie-fang-intel
Collaborator

Hi @Novelfor, I am inclined to close this issue, since this is expected behavior. Feel free to reopen it if you have any further questions.
