Lightning stalls with 2 GPUs on 1 node with SLURM (and apptainer) #19883

Open
sorenwacker opened this issue May 20, 2024 · 1 comment
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers

sorenwacker commented May 20, 2024

Bug description

When using 2 GPUs on a single node, or multiple GPUs across multiple nodes, training does not start, yet the job keeps running. I deploy the environment with a container (Apptainer) and submit the job through SLURM. Is a specific cluster/SLURM configuration required to make this work?

What version are you seeing the problem on?

v2.2

How to reproduce the bug

script.py

# Lightning implementation

import os
import torch
import torchvision as tv
import lightning as L
from torch.utils.data import DataLoader


class CIFAR10DataModule(L.LightningDataModule):
    def __init__(self, batch_size=64, num_workers=1):
        super().__init__()
        self.data_dir = os.getenv("DATASETS_ROOT", "./data")
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        # Ensure CIFAR10 is downloaded once
        tv.datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        tv.datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Apply transformations
        transform = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Setup training and validation datasets
        if stage in (None, 'fit', 'train'):
            self.train_dataset = tv.datasets.CIFAR10(root=self.data_dir, train=True, transform=transform)
        if stage in (None, 'fit', 'validate', 'test'):
            self.val_dataset = tv.datasets.CIFAR10(root=self.data_dir, train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        # Check if val_dataset is defined
        if self.val_dataset is None:
            self.setup('validate')
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)


class CIFAR10Model(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Use the new weights parameter with None to signify no pretrained weights
        self.model = tv.models.resnet18(weights=None)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.001)

def main():
    # Initialize LightningDataModule
    data_module = CIFAR10DataModule()

    # Initialize LightningModel
    model = CIFAR10Model()

    # Get cluster configuration from environment variables
    SLURM_NNODES = int(os.getenv('SLURM_NNODES', '1'))
    SLURM_NTASKS_PER_NODE = int(os.getenv('SLURM_NTASKS_PER_NODE', '1'))

    # Debugging
    print(f"SLURM_NNODES: {SLURM_NNODES}")
    print(f"SLURM_NTASKS_PER_NODE: {SLURM_NTASKS_PER_NODE}")

    # Initialize Trainer
    trainer = L.Trainer(
        max_epochs=5, 
        devices=SLURM_NTASKS_PER_NODE, 
        accelerator='gpu',
        num_nodes=SLURM_NNODES,
        strategy='ddp',
        enable_progress_bar=True
    )
    
    # Train the model
    trainer.fit(model, datamodule=data_module)

    # Evaluate on the validation set: neither test_step() nor test_dataloader() is defined above
    val_result = trainer.validate(model, datamodule=data_module)
    print(val_result)

if __name__ == "__main__":
    main()

sbatch submission script

#!/bin/sh
#SBATCH --job-name=cifar-lit-2GPU
...
#SBATCH --time=0:30:00                
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=20             
#SBATCH --gres=gpu:2                  
#SBATCH --mem=4GB                   
#SBATCH --output=slurm-%x-%j.out
#SBATCH --error=slurm-%x-%j.err   

# Start measuring execution time
start_time=$(date +%s)

export APPTAINER_HOME=/tudelft.net/staff-umbrella/reit/apptainer
export APPTAINER_NAME=pytorch2.2.1-cuda12.1.sif

# Check that container file exists
if [ ! -f $APPTAINER_HOME/$APPTAINER_NAME ]; then
    ls $APPTAINER_HOME/$APPTAINER_NAME
    exit 1
fi 

# Load CUDA that is compatible to container libraries
module use /opt/insy/modulefiles
module load cuda/12.1

# Run script
srun apptainer exec \
    --nv \
    --env-file ~/.env \
    -B /home/:/home/ \
    -B /tudelft.net/:/tudelft.net/ \
    $APPTAINER_HOME/$APPTAINER_NAME \
    python script.py

# End measuring execution time
end_time=$(date +%s)

elapsed_time=$((end_time - start_time))
echo "Elapsed time: $elapsed_time seconds"
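
For anyone comparing their own setup: the sbatch header above requests 1 node, 2 tasks per node and 2 GPUs, and script.py passes SLURM_NNODES and SLURM_NTASKS_PER_NODE on to the Trainer. As far as I know, Lightning's SLURM docs expect `--nodes` to match `num_nodes` and `--ntasks-per-node` to match `devices`. Below is a minimal sketch of a sanity check that could be called before building the Trainer. The helper name `check_slurm_layout` is hypothetical (not part of Lightning or SLURM), and it only compares environment variables; it does not inspect the GPUs.

```python
import os
from typing import List, Mapping


def check_slurm_layout(env: Mapping[str, str], devices: int, num_nodes: int) -> List[str]:
    """Hypothetical helper: list mismatches between SLURM's env and the Trainer arguments."""
    problems = []

    # --nodes should match Trainer(num_nodes=...)
    nnodes = int(env.get("SLURM_NNODES", "1"))
    if nnodes != num_nodes:
        problems.append(f"SLURM_NNODES={nnodes} but num_nodes={num_nodes}")

    # --ntasks-per-node should match Trainer(devices=...)
    ntasks_per_node = env.get("SLURM_NTASKS_PER_NODE")
    if ntasks_per_node is None:
        problems.append("SLURM_NTASKS_PER_NODE is not set (was --ntasks-per-node given?)")
    elif int(ntasks_per_node) != devices:
        problems.append(f"SLURM_NTASKS_PER_NODE={ntasks_per_node} but devices={devices}")

    # The total task count should equal the DDP world size
    ntasks = env.get("SLURM_NTASKS")
    if ntasks is not None and int(ntasks) != devices * num_nodes:
        problems.append(f"SLURM_NTASKS={ntasks} but devices * num_nodes = {devices * num_nodes}")

    return problems


if __name__ == "__main__":
    # Example: validate the current job against the intended Trainer arguments
    for line in check_slurm_layout(os.environ, devices=2, num_nodes=1):
        print("layout mismatch:", line, flush=True)
```

In this issue the printed values (SLURM_NNODES: 1, SLURM_NTASKS_PER_NODE: 2) already agree with the Trainer arguments, so a check like this would mainly rule out the layout as the cause rather than find it.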

Error messages and logs

==> slurm-cifar-lit-2GPU-10065210.err <==
HPU available: False, using: 0 HPUs
/opt/conda/envs/__apptainer__/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
==> slurm-cifar-lit-2GPU-10065210.out <==
SLURM_NNODES: 1
SLURM_NTASKS_PER_NODE: 2
SLURM_NNODES: 1
SLURM_NTASKS_PER_NODE: 2
Files already downloaded and verified
Files already downloaded and verified
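
The log stops right after both ranks register with the nccl backend. That might mean the stall happens in the first NCCL collective rather than in the SLURM task launch, but this is only a guess from the output above. A minimal sketch for getting more information out of the next run is shown here. `NCCL_DEBUG` and `NCCL_P2P_DISABLE` are real NCCL environment variables; the helper names `with_nccl_diagnostics` and `rank_summary` are hypothetical, and disabling P2P is an experiment to narrow things down, not a confirmed fix.

```python
import os
from typing import Dict, Mapping


def with_nccl_diagnostics(env: Mapping[str, str], disable_p2p: bool = False) -> Dict[str, str]:
    """Hypothetical helper: return a copy of env with NCCL diagnostics switched on."""
    out = dict(env)
    # Keep a value the user already chose; otherwise ask NCCL for verbose logs
    out.setdefault("NCCL_DEBUG", "INFO")
    if disable_p2p:
        # Assumption to test: direct GPU-to-GPU transfers are what hangs
        out["NCCL_P2P_DISABLE"] = "1"
    return out


def rank_summary(env: Mapping[str, str]) -> str:
    """Hypothetical helper: one line describing what this SLURM task can see."""
    keys = ("SLURM_PROCID", "SLURM_LOCALID", "SLURM_NODEID", "CUDA_VISIBLE_DEVICES")
    return " ".join(f"{key}={env.get(key, '<unset>')}" for key in keys)


if __name__ == "__main__":
    # Apply before the Trainer is created so the settings exist when NCCL initializes
    os.environ.update(with_nccl_diagnostics(os.environ))
    print(rank_summary(os.environ), flush=True)
```

If the verbose NCCL output ends at a particular transport or interface, that would be worth adding to this issue.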

Environment

name: __apptainer__
channels:
  - pytorch
  - anaconda
  - nvidia
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1
  - _openmp_mutex=4.5
  - accelerate=0.30.1
  - aiohttp=3.9.5
  - aiosignal=1.3.1
  - alsa-lib=1.2.11
  - annotated-types=0.6.0
  - anyio=4.3.0
  - aom=3.9.0
  - appdirs=1.4.4
  - argon2-cffi=23.1.0
  - argon2-cffi-bindings=21.2.0
  - arrow=1.3.0
  - asttokens=2.4.1
  - async-lru=2.0.4
  - attr=2.5.1
  - attrs=23.2.0
  - aws-c-auth=0.7.20
  - aws-c-cal=0.6.12
  - aws-c-common=0.9.17
  - aws-c-compression=0.2.18
  - aws-c-event-stream=0.4.2
  - aws-c-http=0.8.1
  - aws-c-io=0.14.8
  - aws-c-mqtt=0.10.4
  - aws-c-s3=0.5.8
  - aws-c-sdkutils=0.1.16
  - aws-checksums=0.1.18
  - aws-crt-cpp=0.26.8
  - aws-sdk-cpp=1.11.267
  - babel=2.14.0
  - beautifulsoup4=4.12.3
  - blas=2.116
  - blas-devel=3.9.0
  - bleach=6.1.0
  - bokeh=3.4.1
  - brotli=1.1.0
  - brotli-bin=1.1.0
  - brotli-python=1.1.0
  - bzip2=1.0.8
  - c-ares=1.28.1
  - ca-certificates=2024.2.2
  - cached-property=1.5.2
  - cached_property=1.5.2
  - cairo=1.18.0
  - catalogue=2.0.10
  - certifi=2024.2.2
  - cffi=1.16.0
  - charset-normalizer=3.3.2
  - click=8.1.7
  - cloudpathlib=0.16.0
  - cloudpickle=3.0.0
  - colorama=0.4.6
  - comm=0.2.2
  - confection=0.1.4
  - contourpy=1.2.1
  - cuda-cudart=12.1.105
  - cuda-cudart_linux-64=12.1.105
  - cuda-cupti=12.1.105
  - cuda-libraries=12.1.0
  - cuda-nvrtc=12.1.105
  - cuda-nvtx=12.1.105
  - cuda-opencl=12.1.105
  - cuda-runtime=12.1.0
  - cuda-version=12.1
  - cycler=0.12.1
  - cymem=2.0.8
  - cython-blis=0.7.10
  - cytoolz=0.12.3
  - dask=2024.5.0
  - dask-core=2024.5.0
  - dask-expr=1.1.0
  - datasets=2.19.1
  - dav1d=1.2.1
  - dbus=1.13.6
  - debugpy=1.8.1
  - decorator=5.1.1
  - deepspeed=0.14.0
  - defusedxml=0.7.1
  - diffusers=0.27.2
  - dill=0.3.8
  - distributed=2024.5.0
  - docker-pycreds=0.4.0
  - double-conversion=3.3.0
  - entrypoints=0.4
  - exceptiongroup=1.2.0
  - executing=2.0.1
  - expat=2.6.2
  - fastai=2.7.15
  - fastcore=1.5.35
  - fastdownload=0.0.7
  - fastprogress=1.0.3
  - ffmpeg=6.1.1
  - filelock=3.14.0
  - font-ttf-dejavu-sans-mono=2.37
  - font-ttf-inconsolata=3.000
  - font-ttf-source-code-pro=2.038
  - font-ttf-ubuntu=0.83
  - fontconfig=2.14.2
  - fonts-conda-ecosystem=1
  - fonts-conda-forge=1
  - fonttools=4.51.0
  - fqdn=1.5.1
  - freeglut=3.2.2
  - freetype=2.12.1
  - fribidi=1.0.10
  - frozenlist=1.4.1
  - fsspec=2024.3.1
  - gettext=0.22.5
  - gettext-tools=0.22.5
  - gflags=2.2.2
  - gitdb=4.0.11
  - gitpython=3.1.43
  - glib=2.80.2
  - glib-tools=2.80.2
  - glog=0.7.0
  - gmp=6.3.0
  - gmpy2=2.1.5
  - gnutls=3.7.9
  - graphite2=1.3.13
  - gst-plugins-base=1.22.9
  - gstreamer=1.22.9
  - h11=0.14.0
  - h2=4.1.0
  - harfbuzz=8.4.0
  - hdf5=1.14.3
  - hjson-py=3.1.0
  - hpack=4.0.0
  - httpcore=1.0.5
  - httpx=0.27.0
  - huggingface_hub=0.23.0
  - hyperframe=6.0.1
  - icu=73.2
  - idna=3.7
  - imath=3.1.11
  - importlib-metadata=7.1.0
  - importlib_metadata=7.1.0
  - importlib_resources=6.4.0
  - ipykernel=6.29.3
  - ipython=8.24.0
  - ipywidgets=8.1.2
  - isoduration=20.11.0
  - jasper=4.2.4
  - jedi=0.19.1
  - jinja2=3.1.4
  - joblib=1.4.2
  - json5=0.9.25
  - jsonpointer=2.4
  - jsonschema=4.22.0
  - jsonschema-specifications=2023.12.1
  - jsonschema-with-format-nongpl=4.22.0
  - jupyter-lsp=2.2.5
  - jupyter_client=8.6.1
  - jupyter_core=5.7.2
  - jupyter_events=0.10.0
  - jupyter_server=2.14.0
  - jupyter_server_terminals=0.5.3
  - jupyterlab=4.2.0
  - jupyterlab_pygments=0.3.0
  - jupyterlab_server=2.27.1
  - jupyterlab_widgets=3.0.10
  - keyutils=1.6.1
  - kiwisolver=1.4.5
  - krb5=1.21.2
  - lame=3.100
  - langcodes=3.4.0
  - language-data=1.2.0
  - lcms2=2.16
  - ld_impl_linux-64=2.40
  - lerc=4.0.0
  - libabseil=20240116.2
  - libaec=1.1.3
  - libaio=0.3.113
  - libarrow=16.0.0
  - libarrow-acero=16.0.0
  - libarrow-dataset=16.0.0
  - libarrow-substrait=16.0.0
  - libasprintf=0.22.5
  - libasprintf-devel=0.22.5
  - libass=0.17.1
  - libblas=3.9.0
  - libbrotlicommon=1.1.0
  - libbrotlidec=1.1.0
  - libbrotlienc=1.1.0
  - libcap=2.69
  - libcblas=3.9.0
  - libclang-cpp18.1=18.1.5
  - libclang13=18.1.5
  - libcrc32c=1.1.2
  - libcublas=12.1.0.26
  - libcufft=11.0.2.4
  - libcufile=1.6.1.9
  - libcups=2.3.3
  - libcurand=10.3.2.106
  - libcurl=8.7.1
  - libcusolver=11.4.4.55
  - libcusparse=12.0.2.55
  - libdeflate=1.20
  - libdrm=2.4.120
  - libedit=3.1.20191231
  - libev=4.33
  - libevent=2.1.12
  - libexpat=2.6.2
  - libffi=3.4.2
  - libflac=1.4.3
  - libgcc-ng=13.2.0
  - libgcrypt=1.10.3
  - libgettextpo=0.22.5
  - libgettextpo-devel=0.22.5
  - libgfortran-ng=13.2.0
  - libgfortran5=13.2.0
  - libglib=2.80.2
  - libglu=9.0.0
  - libgoogle-cloud=2.23.0
  - libgoogle-cloud-storage=2.23.0
  - libgpg-error=1.49
  - libgrpc=1.62.2
  - libhwloc=2.10.0
  - libiconv=1.17
  - libidn2=2.3.7
  - libjpeg-turbo=3.0.0
  - liblapack=3.9.0
  - liblapacke=3.9.0
  - libllvm18=18.1.5
  - libnghttp2=1.58.0
  - libnpp=12.0.2.50
  - libnsl=2.0.1
  - libnvjitlink=12.1.105
  - libnvjpeg=12.1.1.14
  - libogg=1.3.4
  - libopencv=4.9.0
  - libopenvino=2024.0.0
  - libopenvino-auto-batch-plugin=2024.0.0
  - libopenvino-auto-plugin=2024.0.0
  - libopenvino-hetero-plugin=2024.0.0
  - libopenvino-intel-cpu-plugin=2024.0.0
  - libopenvino-intel-gpu-plugin=2024.0.0
  - libopenvino-ir-frontend=2024.0.0
  - libopenvino-onnx-frontend=2024.0.0
  - libopenvino-paddle-frontend=2024.0.0
  - libopenvino-pytorch-frontend=2024.0.0
  - libopenvino-tensorflow-frontend=2024.0.0
  - libopenvino-tensorflow-lite-frontend=2024.0.0
  - libopus=1.3.1
  - libparquet=16.0.0
  - libpciaccess=0.18
  - libpng=1.6.43
  - libpq=16.3
  - libprotobuf=4.25.3
  - libre2-11=2023.09.01
  - libsndfile=1.2.2
  - libsodium=1.0.18
  - libsqlite=3.45.3
  - libssh2=1.11.0
  - libstdcxx-ng=13.2.0
  - libsystemd0=255
  - libtasn1=4.19.0
  - libthrift=0.19.0
  - libtiff=4.6.0
  - libunistring=0.9.10
  - libutf8proc=2.8.0
  - libuuid=2.38.1
  - libva=2.21.0
  - libvorbis=1.3.7
  - libvpx=1.14.0
  - libwebp-base=1.4.0
  - libxcb=1.15
  - libxcrypt=4.4.36
  - libxkbcommon=1.7.0
  - libxml2=2.12.7
  - libzlib=1.2.13
  - lightning=2.2.4
  - lightning-utilities=0.11.2
  - llvm-openmp=15.0.7
  - locket=1.0.0
  - lz4=4.3.3
  - lz4-c=1.9.4
  - marisa-trie=1.1.0
  - markdown-it-py=3.0.0
  - markupsafe=2.1.5
  - matplotlib-base=3.8.4
  - matplotlib-inline=0.1.7
  - mdurl=0.1.2
  - mistune=3.0.2
  - mkl=2022.1.0
  - mkl-devel=2022.1.0
  - mkl-include=2022.1.0
  - mpc=1.3.1
  - mpfr=4.2.1
  - mpg123=1.32.6
  - mpmath=1.3.0
  - msgpack-python=1.0.8
  - multidict=6.0.5
  - multiprocess=0.70.16
  - munkres=1.1.4
  - murmurhash=1.0.10
  - mysql-common=8.3.0
  - mysql-libs=8.3.0
  - nbclient=0.10.0
  - nbconvert-core=7.16.4
  - nbformat=5.10.4
  - ncurses=6.5
  - nest-asyncio=1.6.0
  - nettle=3.9.1
  - networkx=3.3
  - notebook-shim=0.2.4
  - numpy=1.26.4
  - ocl-icd=2.3.2
  - onnx=1.16.0
  - opencv=4.9.0
  - openexr=3.2.2
  - openh264=2.4.1
  - openjpeg=2.5.2
  - openssl=3.3.0
  - orc=2.0.0
  - overrides=7.7.0
  - p11-kit=0.24.1
  - packaging=24.0
  - pandas=2.2.2
  - pandocfilters=1.5.0
  - parso=0.8.4
  - partd=1.4.2
  - pathtools=0.1.2
  - pathy=0.10.1
  - patsy=0.5.6
  - pcre2=10.43
  - pexpect=4.9.0
  - pickleshare=0.7.5
  - pillow=10.3.0
  - pip=24.0
  - pixman=0.43.2
  - pkgutil-resolve-name=1.3.10
  - platformdirs=4.2.1
  - plotly=5.22.0
  - preshed=3.0.9
  - pretty_errors=1.2.25
  - prometheus_client=0.20.0
  - prompt-toolkit=3.0.42
  - protobuf=4.25.3
  - psutil=5.9.8
  - pthread-stubs=0.4
  - ptyprocess=0.7.0
  - pugixml=1.14
  - pulseaudio-client=17.0
  - pure_eval=0.2.2
  - py-cpuinfo=9.0.0
  - py-opencv=4.9.0
  - pyarrow=16.0.0
  - pyarrow-core=16.0.0
  - pyarrow-hotfix=0.6
  - pycparser=2.22
  - pydantic=2.7.1
  - pydantic-core=2.18.2
  - pygments=2.18.0
  - pynvml=11.5.0
  - pyparsing=3.1.2
  - pysocks=1.7.1
  - python=3.11.9
  - python-dateutil=2.9.0
  - python-fastjsonschema=2.19.1
  - python-json-logger=2.0.7
  - python-tzdata=2024.1
  - python-xxhash=3.4.1
  - python_abi=3.11
  - pytorch=2.2.1
  - pytorch-cuda=12.1
  - pytorch-lightning=2.2.2
  - pytorch-mutex=1.0
  - pytz=2024.1
  - pyyaml=6.0.1
  - pyzmq=26.0.3
  - qt6-main=6.6.3
  - re2=2023.09.01
  - readline=8.2
  - referencing=0.35.1
  - regex=2024.5.10
  - requests=2.31.0
  - rfc3339-validator=0.1.4
  - rfc3986-validator=0.1.1
  - rich=13.7.1
  - rpds-py=0.18.1
  - s2n=1.4.13
  - safetensors=0.4.3
  - scikit-learn=1.4.2
  - scipy=1.13.0
  - seaborn=0.13.2
  - seaborn-base=0.13.2
  - send2trash=1.8.3
  - sentry-sdk=2.1.1
  - setproctitle=1.3.3
  - setuptools=69.5.1
  - shellingham=1.5.4
  - six=1.16.0
  - smart_open=6.4.0
  - smmap=5.0.0
  - snappy=1.2.0
  - sniffio=1.3.1
  - sortedcontainers=2.4.0
  - soupsieve=2.5
  - spacy=3.7.3
  - spacy-legacy=3.0.12
  - spacy-loggers=1.0.5
  - srsly=2.4.8
  - stack_data=0.6.2
  - statsmodels=0.14.1
  - svt-av1=2.0.0
  - sympy=1.12
  - tbb=2021.12.0
  - tblib=3.0.0
  - tenacity=8.3.0
  - terminado=0.18.1
  - thinc=8.2.3
  - threadpoolctl=3.5.0
  - tinycss2=1.3.0
  - tk=8.6.13
  - tokenizers=0.19.1
  - tomli=2.0.1
  - toolz=0.12.1
  - torchaudio=2.2.1
  - torchmetrics=1.4.0
  - torchtriton=2.2.0
  - torchvision=0.17.1
  - tornado=6.4
  - tqdm=4.66.4
  - traitlets=5.14.3
  - transformers=4.40.2
  - typer=0.9.4
  - types-python-dateutil=2.9.0.20240316
  - typing-extensions=4.11.0
  - typing_extensions=4.11.0
  - typing_utils=0.1.0
  - tzdata=2024a
  - uri-template=1.3.0
  - urllib3=2.2.1
  - wandb=0.16.5
  - wasabi=1.1.2
  - wcwidth=0.2.13
  - weasel=0.3.4
  - webcolors=1.13
  - webencodings=0.5.1
  - websocket-client=1.8.0
  - wheel=0.43.0
  - widgetsnbextension=4.0.10
  - x264=1!164.3095
  - x265=3.5
  - xcb-util=0.4.0
  - xcb-util-cursor=0.1.4
  - xcb-util-image=0.4.0
  - xcb-util-keysyms=0.4.0
  - xcb-util-renderutil=0.3.9
  - xcb-util-wm=0.4.1
  - xkeyboard-config=2.41
  - xorg-fixesproto=5.0
  - xorg-inputproto=2.3.2
  - xorg-kbproto=1.0.7
  - xorg-libice=1.1.1
  - xorg-libsm=1.2.4
  - xorg-libx11=1.8.9
  - xorg-libxau=1.0.11
  - xorg-libxdmcp=1.1.3
  - xorg-libxext=1.3.4
  - xorg-libxfixes=5.0.3
  - xorg-libxi=1.7.10
  - xorg-libxrender=0.9.11
  - xorg-renderproto=0.11.1
  - xorg-xextproto=7.3.0
  - xorg-xproto=7.0.31
  - xxhash=0.8.2
  - xyzservices=2024.4.0
  - xz=5.2.6
  - yaml=0.2.5
  - yarl=1.9.4
  - zeromq=4.3.5
  - zict=3.0.0
  - zipp=3.17.0
  - zlib=1.2.13
  - zstd=1.5.6
  - pip:
      - mpi4py==3.1.6
      - watermark==2.4.3
prefix: /opt/conda/envs/__apptainer__
@enesmsahin
Same issue
