Error in running the command nnUNetv2_train: cannot access local variable 'region_labels' where it is not associated with a value #2196

Open
sinaziaee opened this issue May 16, 2024 · 6 comments
@sinaziaee

Hi, I am new to nnU-Net and want to use it on the KiTS23 dataset. To start, I randomly selected 5 cases from the dataset and put them in a folder.
Then I set the paths with:
export nnUNet_raw="mypath/playground/nnUNet_raw/"
export nnUNet_preprocessed="mypath/playground/processed"
export nnUNet_results="mypath/playground/nnUNet_results"
Then I ran python Dataset220_KiTS2023.py <dataset_folder_with_5_cases>
Then I preprocessed the data with nnUNetv2_plan_and_preprocess -d 220 --verify_dataset_integrity -c 2d 3d_fullres 3d_lowres
Then I tried to run nnUNetv2_train 220 2d 0 or nnUNetv2_train 220 3d_lowres 0
but I always get this error:
cannot access local variable 'region_labels' where it is not associated with a value

I'm stuck and don't know how to proceed from here.
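
For context, the message above is how Python 3.11 and later word an UnboundLocalError: a function reads a local variable on a code path where it was never assigned. The sketch below reproduces the same kind of failure with a hypothetical function; `pick_region_labels` and its arguments are made up for illustration and are not taken from nnU-Net or batchgeneratorsv2.

```python
def pick_region_labels(regions, use_regions):
    # Hypothetical helper: 'region_labels' is only assigned in one branch.
    if use_regions:
        region_labels = list(regions)
    # When use_regions is False, the name was never bound in this call,
    # so reading it here raises UnboundLocalError.
    return region_labels


def try_pick(regions, use_regions):
    # Return the result, or the exception so its message can be inspected.
    try:
        return pick_region_labels(regions, use_regions)
    except UnboundLocalError as err:
        return err


ok = try_pick([(1, 2, 3), (2, 3), (3,)], use_regions=True)
failed = try_pick([(1, 2, 3), (2, 3), (3,)], use_regions=False)
print(ok)
print(type(failed).__name__, failed)
```

An error of this shape usually points to a code path that skips an assignment, which is why it tends to be fixed in the library rather than in the user's commands.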

@sinaziaee changed the title from "cannot access local variable 'region_labels' where it is not associated with a value" to "Error in running the command nnUNetv2_train: cannot access local variable 'region_labels' where it is not associated with a value" on May 16, 2024
@jiashizuo

Excuse me, have you solved it? I also encountered this error.

@sinaziaee
Author

> Excuse me, have you solved it? I also encountered this error.

No, unfortunately.

@htcwf89

htcwf89 commented May 27, 2024

I'm running into the same error when trying to run region-based training. I traced the issue to the ConvertSegmentationToRegionsTransform class in the batchgeneratorsv2 package (batchgeneratorsv2/batchgeneratorsv2/transforms/utils/seg_to_regions.py).
It appears that the latest update by @FabianIsensee, pushed in response to #2136 (comment), has broken region-based training in nnunetv2.
I tried manually reverting the ConvertSegmentationToRegionsTransform class to previous versions from the file's commit history, but those introduced different errors instead.
@FabianIsensee & @GregorKoehler could you please take a look?

P.S. I'm using nnunetv2 V2.5, torch 2.1.2+cu118, and batchgeneratorsv2 0.1.1.
My OS is Linux and my GPUs are RTX A6000s.
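
Since the fix depends on which versions are installed, here is a minimal sketch for collecting them in one go before reporting. It uses only the standard library; the distribution names are assumed to match the package names mentioned in this thread.

```python
from importlib import metadata


def installed_versions(packages):
    # Map each distribution name to its installed version, or None if absent.
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names assumed to match the package names used in this thread.
report = installed_versions(["nnunetv2", "batchgeneratorsv2", "torch"])
for name, version in report.items():
    print(f"{name}: {version if version is not None else 'not installed'}")
```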

@FabianIsensee
Member

This should be fixed if you install the current master of both batchgeneratorsv2 and nnunetv2.

@htcwf89

htcwf89 commented May 29, 2024

Thanks for the super quick response. I tried it, but unfortunately it doesn't work. I installed the current master of both packages by cloning from git, but the loss function doesn't seem to like the bool. The full traceback is below. I haven't gotten around to experimenting with potential solutions yet, but I will report back if I find a fix.

In the meantime, I installed nnunetv2 V2.4 and was able to run region-based training, in case others want to keep their work moving @sinaziaee @jiashizuo

Here's the traceback:

"(/home/my_user/nnunet2_envNew) my_user@01:~$ CUDA_VISIBLE_DEVICES=3
############################
INFO: You are using the old nnU-Net default plans. We have updated our recommendations. Please consider using those instead! Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md
#######################################################################
Please cite the following paper when using nnU-Net:
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
#######################################################################
2024-05-28 14:24:29.521258: do_dummy_2d_data_aug: True
2024-05-28 14:24:29.534906: Using splits from existing split file: /mnt/Stuff/MultiModal/nnUNet_preprocessed/Dataset727 /splits_final.json
2024-05-28 14:24:29.540675: The split file contains 5 splits.
2024-05-28 14:24:29.544157: Desired fold for training: 0
2024-05-28 14:24:29.547901: This split has 40 training and 10 validation cases.
using pin_memory on device 0
using pin_memory on device 0
2024-05-28 14:24:35.571942: Using torch.compile...

This is the configuration used by this training:
Configuration name: 3d_fullres
{'data_identifier': 'nnUNetPlans_3d_fullres', 'preprocessor_name': 'DefaultPreprocessor', 'batch_size': 2, 'patch_size': [32, 256, 224], 'median_image_size_in_voxels': [33.0, 256.0, 252.0], 'spacing': [3.0, 0.5, 0.5], 'normalization_schemes': ['ZScoreNormalization'], 'use_mask_for_norm': [False], 'resampling_fn_data': 'resample_data_or_seg_to_shape', 'resampling_fn_seg': 'resample_data_or_seg_to_shape', 'resampling_fn_data_kwargs': {'is_seg': False, 'order': 3, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_seg_kwargs': {'is_seg': True, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_probabilities': 'resample_data_or_seg_to_shape', 'resampling_fn_probabilities_kwargs': {'is_seg': False, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'architecture': {'network_class_name': 'dynamic_network_architectures.architectures.unet.PlainConvUNet', 'arch_kwargs': {'n_stages': 6, 'features_per_stage': [32, 64, 128, 256, 320, 320], 'conv_op': 'torch.nn.modules.conv.Conv3d', 'kernel_sizes': [[1, 3, 3], [1, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], 'strides': [[1, 1, 1], [1, 2, 2], [1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 'n_conv_per_stage': [2, 2, 2, 2, 2, 2], 'n_conv_per_stage_decoder': [2, 2, 2, 2, 2], 'conv_bias': True, 'norm_op': 'torch.nn.modules.instancenorm.InstanceNorm3d', 'norm_op_kwargs': {'eps': 1e-05, 'affine': True}, 'dropout_op': None, 'dropout_op_kwargs': None, 'nonlin': 'torch.nn.LeakyReLU', 'nonlin_kwargs': {'inplace': True}, 'deep_supervision': True}, '_kw_requires_import': ['conv_op', 'norm_op', 'dropout_op', 'nonlin']}, 'batch_dice': False}
These are the global plan.json settings:
{'dataset_name': 'Dataset727, 'plans_name': 'nnUNetPlans', 'original_median_spacing_after_transp': [3.0, 0.5, 0.5], 'original_median_shape_after_transp': [33, 256, 252], 'image_reader_writer': 'SimpleITKIO', 'transpose_forward': [0, 1, 2], 'transpose_backward': [0, 1, 2], 'experiment_planner_used': 'ExperimentPlanner', 'label_manager': 'LabelManager', 'foreground_intensity_properties_per_channel': {'0': {'max': 300.6848449707031, 'mean': 57.38395309448242, 'median': 55.44499206542969, 'min': -34.248321533203125, 'percentile_00_5': -0.029295789077878, 'percentile_99_5': 161.9104461669922, 'std': 34.14013671875}}}
2024-05-28 14:24:38.170208: unpacking dataset...
2024-05-28 14:24:43.407785: unpacking done...
2024-05-28 14:24:43.435786: Unable to plot network architecture: nnUNet_compile is enabled!
2024-05-28 14:24:43.458476:
2024-05-28 14:24:43.462629: Epoch 0
2024-05-28 14:24:43.466364: Current learning rate: 0.01
Traceback (most recent call last):
File "/home/my_user/nnunet2_envNew/bin/nnUNetv2_train", line 8, in <module>
sys.exit(run_training_entry())
^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnUNetNew/nnunetv2/run/run_training.py", line 275, in run_training_entry
run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
File "/home/my_user/nnUNetNew/nnunetv2/run/run_training.py", line 211, in run_training
nnunet_trainer.run_training()
File "/home/my_user/nnUNetNew/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 1361, in run_training
train_outputs.append(self.train_step(next(self.dataloader_train)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnUNetNew/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 987, in train_step
l = self.loss(output, target)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnUNetNew/nnunetv2/training/loss/deep_supervision.py", line 30, in forward
return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnUNetNew/nnunetv2/training/loss/deep_supervision.py", line 30, in <listcomp>
return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0])
^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnUNetNew/nnunetv2/training/loss/compound_losses.py", line 98, in forward
dc_loss = self.dc(net_output, target_regions, loss_mask=mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
tracer.run()
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2069, in run
super().run()
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
and self.step()
^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
getattr(self, inst.opname)(inst)
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 364, in inner
eval_result = value.evaluate_expr(self.output)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 703, in evaluate_expr
return guard_scalar(self.sym_num)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 298, in guard_scalar
return guard_bool(a)
^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 500, in guard_bool
return a.node.guard_bool("", 0) # NB: uses Python backtrace
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 954, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3536, in evaluate_expr
self._maybe_guard_eq(sympy.Eq(expr, concrete_val), True)
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 3331, in _maybe_guard_eq
assert len(free) > 0, f"The expression should not be static by this point: {expr}"
AssertionError: The expression should not be static by this point: False

from user code:
File "/home/my_user/nnUNetNew/nnunetv2/training/loss/dice.py", line 83, in forward
if x.shape == y.shape:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True

Exception in thread Thread-1 (results_loop):
Traceback (most recent call last):
File "/home/my_user/nnunet2_envNew/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
self.run()
File "/home/my_user/nnunet2_envNew/lib/python3.11/threading.py", line 982, in run
self._target(*self._args, **self._kwargs)
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py", line 125, in results_loop
raise e
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py", line 103, in results_loop
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
RuntimeError: One or more background workers are no longer alive. Exiting. Please check the print statements above for the actual error message
Exception in thread Thread-2 (results_loop):
Traceback (most recent call last):
File "/home/my_user/nnunet2_envNew/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
self.run()
File "/home/my_user/nnunet2_envNew/lib/python3.11/threading.py", line 982, in run
self._target(*self._args, **self._kwargs)
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py", line 125, in results_loop
raise e
File "/home/my_user/nnunet2_envNew/lib/python3.11/site-packages/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py", line 103, in results_loop
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
RuntimeError: One or more background workers are no longer alive. Exiting. Please check the print statements above for the actual error message
"

@FabianIsensee
Member

This is a different problem. Please upgrade torch to the latest version. It is not an nnU-Net-related error. If upgrading doesn't work, you can always disable torch.compile by setting nnUNet_compile=f in your environment variables.
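
For anyone launching training from a Python script rather than a shell, the environment-variable workaround could look roughly like the sketch below. `training_call` is a hypothetical helper, not part of nnU-Net; the command and the value `f` are the ones quoted in this thread.

```python
import os


def training_call(dataset_id, configuration, fold):
    # Copy the current environment so the parent process is left untouched.
    env = dict(os.environ)
    # 'f' is the value suggested in this thread for switching torch.compile off.
    env["nnUNet_compile"] = "f"
    cmd = ["nnUNetv2_train", str(dataset_id), configuration, str(fold)]
    return cmd, env


cmd, env = training_call(220, "2d", 0)
print(" ".join(cmd), "with nnUNet_compile =", env["nnUNet_compile"])
# To launch for real: subprocess.run(cmd, env=env, check=True)
```

Passing the modified copy through `env=` keeps the setting scoped to that one training run instead of changing it for the whole session.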
