[Bug] Training MViT with RawframeDataset type #2834

Open
3 tasks done
maximefuchs opened this issue Apr 26, 2024 · 0 comments
maximefuchs commented Apr 26, 2024

Branch

main branch (1.x version, such as v1.0.0, or dev-1.x branch)

Prerequisite

Environment

sys.platform: linux
Python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0: NVIDIA RTX A4000
CUDA_HOME: /usr
NVCC: Cuda compilation tools, release 11.5, V11.5.119
GCC: x86_64-linux-gnu-gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
PyTorch: 2.2.2+cu121
PyTorch compiling details: PyTorch built with:

  • GCC 9.3
  • C++ Version: 201703
  • Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)
  • OpenMP 201511 (a.k.a. OpenMP 4.5)
  • LAPACK is enabled (usually provided by MKL)
  • NNPACK is enabled
  • CPU capability usage: AVX2
  • CUDA Runtime 12.1
  • NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  • CuDNN 8.9.2
  • Magma 2.6.1
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.2.2, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

TorchVision: 0.17.2+cu121
OpenCV: 4.9.0
MMEngine: 0.10.4
MMAction2: 1.2.0+
MMCV: 2.1.0
MMDetection: 3.3.0

Describe the bug

I'm trying to set up training for MViT on frames that have already been extracted, so I'm using the RawframeDataset dataset type.
Here is my config file:

_base_ = [
    "../mmaction2/configs/_base_/models/mvit_small.py",
    "../mmaction2/configs/_base_/default_runtime.py",
]
# dataset settings
classes = ("cl1", "cl2", "cl3", "cl4", "cl5", "cl6", "cl7")
num_class = len(classes) 
dataset_type = "RawframeDataset"
data_root = "/home/maxime/Documents/DATA/dataset_classifier/"
ann_file_train = "train.txt"
ann_file_val = "val.txt"
ann_file_test = "test.txt"
# hyperparameters
clip_len = 8
batch_size = 4
num_workers = 1

metainfo = dict(classes=classes)
model = dict(
    backbone=dict(
        arch="base",
        temporal_size=clip_len,
        drop_path_rate=0.3,
    ),
    data_preprocessor=dict(
        type="ActionDataPreprocessor",
        mean=[114.75, 114.75, 114.75],
        std=[57.375, 57.375, 57.375],
        blending=dict(
            type="RandomBatchAugment",
            augments=[
                dict(type="MixupBlending", alpha=0.8, num_classes=num_class),
                dict(type="CutmixBlending", alpha=1, num_classes=num_class),
            ],
        ),
        format_shape="NCTHW",
    ),
    cls_head=dict(num_classes=num_class),
)


train_pipeline = [
    dict(type="SampleFrames", clip_len=clip_len, frame_interval=1, num_clips=3),
    dict(type="RawFrameDecode"),
    dict(type="Resize", scale=(-1, 256)),
    dict(type="RandomResizedCrop"),
    dict(type="Resize", scale=(224, 224), keep_ratio=False),
    dict(type="Flip", flip_ratio=0.5),
    dict(type="FormatShape", input_format="NCTHW"),
    dict(type="PackActionInputs"),
]
val_pipeline = [
    dict(
        type="SampleFrames",
        clip_len=clip_len,
        frame_interval=1,
        num_clips=3,
        test_mode=True,
    ),
    dict(type="RawFrameDecode"),
    dict(type="Resize", scale=(-1, 256)),
    dict(type="CenterCrop", crop_size=224),
    dict(type="FormatShape", input_format="NCTHW"),
    dict(type="PackActionInputs"),
]
test_pipeline = [
    dict(
        type="SampleFrames",
        clip_len=clip_len,
        frame_interval=1,
        num_clips=25,
        test_mode=True,
    ),
    dict(type="RawFrameDecode"),
    dict(type="Resize", scale=(-1, 256)),
    dict(type="TenCrop", crop_size=224),
    dict(type="FormatShape", input_format="NCTHW"),
    dict(type="PackActionInputs"),
]

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
    sampler=dict(type="DefaultSampler", shuffle=True),
    dataset=dict(
        type=dataset_type,
        metainfo=metainfo,
        ann_file=data_root + ann_file_train,
        filename_tmpl="img_{:05}.png",  # id of images has to start at 1
        # modality="Flow",
        data_prefix=dict(img=data_root),
        pipeline=train_pipeline,
    ),
)
val_dataloader = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
    sampler=dict(type="DefaultSampler", shuffle=False),
    dataset=dict(
        type=dataset_type,
        metainfo=metainfo,
        ann_file=data_root + ann_file_val,
        filename_tmpl="img_{:05}.png",  # id of images has to start at 1
        # modality="Flow",
        data_prefix=dict(img=data_root),
        pipeline=val_pipeline,
        test_mode=True,
    ),
)
test_dataloader = dict(
    batch_size=1,
    num_workers=num_workers,
    persistent_workers=True,
    sampler=dict(type="DefaultSampler", shuffle=False),
    dataset=dict(
        type=dataset_type,
        metainfo=metainfo,
        ann_file=data_root + ann_file_test,
        filename_tmpl="img_{:05}.png",  # id of images has to start at 1
        # modality="Flow",
        data_prefix=dict(img=data_root),
        pipeline=test_pipeline,
        test_mode=True,
    ),
)

val_evaluator = dict(type="AccMetric")
test_evaluator = val_evaluator

train_cfg = dict(
    type="EpochBasedTrainLoop", max_epochs=200, val_begin=1, val_interval=1
)
val_cfg = dict(type="ValLoop")
test_cfg = dict(type="TestLoop")

base_lr = 1.6e-3
optim_wrapper = dict(
    optimizer=dict(type="AdamW", lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05),
    paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0),
    clip_grad=dict(max_norm=1, norm_type=2),
)

param_scheduler = [
    dict(
        type="LinearLR",
        start_factor=0.01,
        by_epoch=True,
        begin=0,
        end=30,
        convert_to_iter_based=True,
    ),
    dict(
        type="CosineAnnealingLR",
        T_max=200,
        eta_min=base_lr / 100,
        by_epoch=True,
        begin=30,
        end=200,
        convert_to_iter_based=True,
    ),
]

default_hooks = dict(
    checkpoint=dict(interval=1, max_keep_ckpts=5), logger=dict(interval=100)
)

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#       or not by default.
#   - `base_batch_size` = total batch size the base LR was tuned for (GPUs x samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=256)
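To see where the 12 in the error below comes from, here is a rough NumPy model of the tensor shapes this config should produce for one training batch. It is a sketch under stated assumptions: `format_ncthw` is a hypothetical, simplified stand-in for `FormatShape(input_format="NCTHW")` that ignores multi-crop transforms, and the step that folds the clip axis into the batch axis before the backbone is an assumption about the 3D recognizer, not a quote of mmaction2 code.

```python
import numpy as np

# Hyperparameters copied from the config above.
batch_size = 4
clip_len = 8
num_clips = 3    # SampleFrames in train_pipeline
num_classes = 7  # len(classes)
C = 3            # RGB channels
H = W = 4        # tiny dummy size; the real pipeline crops to 224 x 224


def format_ncthw(frames, num_clips, clip_len):
    """Hypothetical, simplified model of FormatShape(input_format="NCTHW").

    Ignores multi-crop transforms. frames: (num_clips * clip_len, H, W, C).
    Returns an array of shape (num_clips, C, clip_len, H, W).
    """
    x = frames.reshape((num_clips, clip_len) + frames.shape[1:])  # M, T, H, W, C
    return x.transpose(0, 4, 1, 2, 3)                              # M, C, T, H, W


# SampleFrames yields num_clips * clip_len frames per video (3 * 8 = 24).
frames = np.zeros((num_clips * clip_len, H, W, C), dtype=np.float32)
sample = format_ncthw(frames, num_clips, clip_len)

# Collating batch_size samples gives (N, M, C, T, H, W).
batch = np.stack([sample] * batch_size)

# Assumption: the recognizer folds the clip axis into the batch axis
# before the backbone, so the head sees N * M = 12 clips.
flat = batch.reshape((-1,) + batch.shape[2:])

cls_scores = np.zeros((flat.shape[0], num_classes))  # one score row per clip
labels = np.zeros((batch_size, num_classes))         # one soft label per video

print(sample.shape)                    # (3, 3, 8, 4, 4)
print(flat.shape)                      # (12, 3, 8, 4, 4)
print(cls_scores.shape, labels.shape)  # (12, 7) (4, 7)
```

Under these assumptions the head produces 12 score rows for only 4 label rows, which matches the (12,7) vs. (4,7) mismatch described below.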

When launching the training, I get the following error:

Traceback (most recent call last):
  File "/home/maxime/Documents/classification/mmaction2/tools/train.py", line 143, in <module>
    main()
  File "/home/maxime/Documents/classification/mmaction2/tools/train.py", line 139, in main
    runner.train()
  File "/home/maxime/Documents/classification/.venv/lib/python3.10/site-packages/mmengine/runner/runner.py", line 1777, in train
    model = self.train_loop.run()  # type: ignore
  File "/home/maxime/Documents/classification/.venv/lib/python3.10/site-packages/mmengine/runner/loops.py", line 96, in run
    self.run_epoch()
  File "/home/maxime/Documents/classification/.venv/lib/python3.10/site-packages/mmengine/runner/loops.py", line 114, in run_epoch
    self.run_iter(idx, data_batch)
  File "/home/maxime/Documents/classification/.venv/lib/python3.10/site-packages/mmengine/runner/loops.py", line 130, in run_iter
    outputs = self.runner.model.train_step(
  File "/home/maxime/Documents/classification/.venv/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 118, in train_step
    losses = self._run_forward(data, mode='loss')  # type: ignore
  File "/home/maxime/Documents/classification/.venv/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 365, in _run_forward
    results = self(**data, mode=mode)
  File "/home/maxime/Documents/classification/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/maxime/Documents/classification/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/maxime/Documents/classification/mmaction2/mmaction/models/recognizers/base.py", line 264, in forward
    return self.loss(inputs, data_samples, **kwargs)
  File "/home/maxime/Documents/classification/mmaction2/mmaction/models/recognizers/base.py", line 177, in loss
    loss_cls = self.cls_head.loss(feats, data_samples, **loss_kwargs)
  File "/home/maxime/Documents/classification/mmaction2/mmaction/models/heads/base.py", line 104, in loss
    return self.loss_by_feat(cls_scores, data_samples)
  File "/home/maxime/Documents/classification/mmaction2/mmaction/models/heads/base.py", line 136, in loss_by_feat
    top_k_acc = top_k_accuracy(
  File "/home/maxime/Documents/classification/mmaction2/mmaction/evaluation/functional/accuracy.py", line 148, in top_k_accuracy
    match_array = np.logical_or.reduce(max_k_preds == labels, axis=1)
ValueError: operands could not be broadcast together with shapes (12,5) (4,1,7) 

After some investigation, the problem appears to come from num_clips=3 in the training pipeline. In the top_k_accuracy function, the scores have shape (12,7) (4 videos x 3 clips, 7 classes) while the ground truth has shape (4,7). It looks like the ground-truth labels are not expanded to account for the num_clips parameter.
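The broadcast failure can be reproduced with NumPy alone. This is a minimal sketch, assuming the labels reaching top_k_accuracy are the (4,7) soft labels produced by the MixupBlending/CutmixBlending augments in the config, and that the function expands them with `[:, np.newaxis]` (inferred from the (4,1,7) shape in the error). `top_k_accuracy_sketch` is a hypothetical stand-in that only mirrors the lines visible in the traceback; it is not mmaction2's implementation.

```python
import numpy as np


def top_k_accuracy_sketch(scores, labels, topk=(1,)):
    """Hypothetical top-k accuracy that expects one integer label per row.

    scores: (num_samples, num_classes); labels: (num_samples,).
    Mirrors only the lines shown in the traceback.
    """
    # (N,) -> (N, 1); with (4, 7) soft labels this becomes (4, 1, 7).
    labels = np.asarray(labels)[:, np.newaxis]
    res = []
    for k in topk:
        # Indices of the k highest scores per row, best first: (num_samples, k).
        max_k_preds = np.argsort(scores, axis=1)[:, -k:][:, ::-1]
        # Same elementwise comparison as `max_k_preds == labels`.
        match = np.logical_or.reduce(np.equal(max_k_preds, labels), axis=1)
        res.append(float(match.sum() / match.shape[0]))
    return res


rng = np.random.default_rng(0)
cls_scores = rng.random((12, 7))      # 4 videos x 3 clips, 7 classes
soft_labels = np.full((4, 7), 1 / 7)  # blended labels: one 7-way vector per video

# k=1 does not raise: (12, 1) and (4, 1, 7) broadcast to (4, 12, 7),
# so the result is silently meaningless.
print(top_k_accuracy_sketch(cls_scores, soft_labels, topk=(1,)))  # [0.0]

error_message = None
try:
    top_k_accuracy_sketch(cls_scores, soft_labels, topk=(1, 5))
except ValueError as err:
    error_message = str(err)
    print(error_message)  # operands could not be broadcast together with shapes (12,5) (4,1,7)

# Hypothetical workaround, for illustration only: assuming rows are ordered
# video-major, clip-minor, average each video's 3 clip scores into one row.
per_video = cls_scores.reshape(4, 3, 7).mean(axis=1)
print(per_video.shape == soft_labels.shape)  # True
```

Two things stand out: the failing pass is k=5 (hence the (12,5) operand), while k=1 broadcasts without raising; and once the scores are reduced to one row per video, their shape matches the labels again. Whether averaging clip scores during training is the intended behavior is not something this sketch settles.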

Reproduces the problem - code sample

No response

Reproduces the problem - command or script

No response

Reproduces the problem - error message

No response

Additional information

No response
