Clarifications on dimensions #18

Open
anar-rzayev opened this issue Jan 12, 2024 · 7 comments

Comments

@anar-rzayev

anar-rzayev commented Jan 12, 2024

Hi, thank you for this amazing paper. I would like to ask a few questions and get some detailed clarification.

I have seen in multiple places (e.g. mri_dataset.py) that you define valid_mask = [10, 160]. Given your data size of (81, 106, 76, 160), is there a particular reason you chose val_volume_idx = 40 and valid_mask = [10, 160] in hardi150.json? I ask because I am working with (118, 118, 25, 56) 4D diffusion data, and I run into some issues when setting up mri_dataset.py as follows:

valid_mask = np.zeros(56,)
valid_mask[10:] += 1
valid_mask = valid_mask.astype(bool)  # np.bool8 is deprecated (removed in NumPy 2.0)
dataset = MRIDataset("/home/anar/DDM2/data/HARDI150.nii.gz", valid_mask = [10, 56],
                         phase='train', val_volume_idx=40, padding=3)
Traceback (most recent call last):
  File "/home/anar/DDM2/train_noise_model.py", line 92, in <module>
    for _,  val_data in enumerate(val_loader):
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anar/DDM2/data/mri_dataset.py", line 118, in __getitem__
    raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,[volume_idx+self.padding]]), axis=-1)
IndexError: index 41 is out of bounds for axis 3 with size 8

Do you have any idea where this could originate from?

@tiangexiang
Collaborator

Hi, thank you for your interest!

Let me clarify the data dimensions first. Consider 4D MRI data of shape [W, H, Z, T], where [W, H, Z] describe the MRI scan in 3D space, i.e. its extent along the x, y and z axes. For a 4D MRI scan with multiple acquisitions, the additional 4th dimension T is the number of observations of the same 3D structure. Because of the randomness in each acquisition, each 3D observation may be noisy in a different way. Denoising methods (including DDM2) are designed to learn the consistency across these differently noisy observations in order to recover a clean 3D structure.
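
A small NumPy sketch of the [W, H, Z, T] layout described above. It uses an all-zero stand-in array with the (118, 118, 25, 56) shape from the question; no real data is loaded, and the indices t and z are arbitrary choices for illustration.

```python
import numpy as np

# All-zero stand-in for a 4D diffusion MRI scan laid out as [W, H, Z, T]
W, H, Z, T = 118, 118, 25, 56
data = np.zeros((W, H, Z, T), dtype=np.float32)

t, z = 12, 7                        # arbitrary observation and slice indices
volume = data[..., t]               # one 3D observation: [W, H, Z]
slice_2d = data[:, :, z, t]         # one 2D slice of that observation: [W, H]
slice_all_obs = data[:, :, z, :]    # the same slice in every observation: [W, H, T]

print(volume.shape, slice_2d.shape, slice_all_obs.shape)
# (118, 118, 25) (118, 118) (118, 118, 56)
```

The last array, one slice seen across all T observations, is the kind of stack in which the noise differs from observation to observation while the anatomy stays the same.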

The choice of val_volume_idx = 40 is arbitrary: you may validate on any volume without affecting the training process. As for valid_mask = [10, 160], it specifies the interval of observations whose b-value is NOT 0. In the hardi150 dataset, the first 10 observations have b-value = 0, so we exclude them from training with valid_mask = [10, 160] (i.e. using the 160 - 10 = 150 observations along the T dimension).
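
If the b-values are available as a 1D array (for example read from a .bval file; how they are loaded is left out here), the [start, end) interval for valid_mask could be derived rather than hard-coded. The helper below is hypothetical and not part of DDM2. The b0 threshold of 50 is an assumption, since b0 volumes are sometimes stored with small non-zero values, and the b-values in the example are made up.

```python
import numpy as np

def valid_mask_from_bvals(bvals, b0_threshold=50.0):
    """Hypothetical helper: [start, end) range of the non-b0 volumes.

    Assumes the diffusion-weighted volumes form one contiguous block
    (e.g. all b0 volumes first, as described above for hardi150).
    b0_threshold is an assumed cut-off, not a DDM2 setting.
    """
    bvals = np.asarray(bvals, dtype=float)
    non_b0 = np.flatnonzero(bvals > b0_threshold)
    if non_b0.size == 0:
        raise ValueError("no diffusion-weighted volumes found")
    start, end = int(non_b0[0]), int(non_b0[-1]) + 1
    if end - start != non_b0.size:
        raise ValueError("b0 volumes are interleaved; one [start, end) range cannot exclude them")
    return [start, end]


# Made-up b-values: 10 b0 volumes followed by 150 diffusion-weighted ones
print(valid_mask_from_bvals([0] * 10 + [2000] * 150))  # [10, 160]
```

The returned pair matches how mri_dataset.py applies the mask, as raw_data[:,:,:,valid_mask[0]:valid_mask[1]].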

Back to your problem: I think it still makes sense to use the same settings for val_volume_idx and valid_mask. The error indicates that the raw data you loaded has only T = 8 observations instead of T = 56. I suggest double-checking the data first and making sure the loaded raw data has the expected size.
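
A hypothetical pre-flight check along these lines would catch both a wrongly loaded T and a val_volume_idx that no longer fits after masking. It assumes, as in the mri_dataset.py code pasted later in this thread, that val_volume_idx indexes the volumes that remain after valid_mask is applied; the function name and arguments are illustrative only.

```python
def check_volume_config(data_shape, valid_mask, val_volume_idx, expected_t=None):
    """Hypothetical pre-flight check; returns the number of volumes kept.

    data_shape:     shape of the loaded 4D array, [W, H, Z, T]
    valid_mask:     [start, end) range of volumes kept along T
    val_volume_idx: index into the volumes that remain after masking
    expected_t:     optional T you expect the file to contain
    """
    if len(data_shape) != 4:
        raise ValueError(f"expected 4D data [W, H, Z, T], got {tuple(data_shape)}")
    t = data_shape[3]
    if expected_t is not None and t != expected_t:
        raise ValueError(f"loaded T={t}, expected T={expected_t}")
    start, end = valid_mask
    if not 0 <= start < end <= t:
        raise ValueError(f"valid_mask {list(valid_mask)} does not fit T={t}")
    n_kept = end - start
    if not 0 <= val_volume_idx < n_kept:
        raise ValueError(
            f"val_volume_idx={val_volume_idx}, but only {n_kept} volumes remain after masking"
        )
    return n_kept


# In mri_dataset.py, data_shape would be raw_data.shape right after load_nifti
print(check_volume_config((118, 118, 25, 56), [10, 56], 40, expected_t=56))  # 46
```

With the (118, 118, 25, 56) data from the question, valid_mask = [10, 56] keeps 46 volumes, so val_volume_idx = 40 fits; a file that loads with T = 8 would be reported before any DataLoader worker starts.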

@anar-rzayev
Author

anar-rzayev commented Jan 14, 2024

Thanks so much, @tiangexiang, for the quick response and the detailed explanation of valid_mask and val_volume_idx. As you suggested, I changed the .json file and mri_dataset.py to use valid_mask = [20, 56], so that only the volumes with non-zero b-values in my dataset are used.

Along the way, I ran into the following minor bug:

24-01-14 20:28:08.899 - INFO: [Phase 1] Training noise model!
24-01-14 20:28:10.231 - INFO: MRI dataset [hardi] is created.
24-01-14 20:28:13.083 - INFO: MRI dataset [hardi] is created.
24-01-14 20:28:13.083 - INFO: Initial Dataset Finished
24-01-14 20:28:13.463 - INFO: Noise Model is created.
24-01-14 20:28:13.463 - INFO: Initial Model Finished
2.1.2+cu121 12.1
export CUDA_VISIBLE_DEVICES=0
Loaded data of size: (118, 118, 25, 56)
Loaded data of size: (118, 118, 25, 56)
dropout 0.0 encoder dropout 0.0
Traceback (most recent call last):
  File "/home/anar/DDM2/train_noise_model.py", line 72, in <module>
    trainer.optimize_parameters()
  File "/home/anar/DDM2/model/model_stage1.py", line 62, in optimize_parameters
    outputs = self.netG(self.data)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anar/DDM2/model/mri_modules/noise_model.py", line 44, in forward
    return self.p_losses(x, *args, **kwargs)
  File "/home/anar/DDM2/model/mri_modules/noise_model.py", line 36, in p_losses
    x_recon = self.denoise_fn(x_in['condition'])
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anar/DDM2/model/mri_modules/unet.py", line 286, in forward
    x = layer(x)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (double) and bias type (float) should be the same
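
A likely reason the double reaches the convolution: in the normalization line of mri_dataset.py (visible in the file pasted later in this thread), raw_data.astype(np.float32) / np.max(raw_data, axis=(0,1,2), keepdims=True), the maximum is taken on the uncast array. If the loaded data is float64, NumPy's type promotion turns the float32 / float64 division back into float64; transforms.ToTensor keeps that dtype, while nn.Conv2d parameters are float32 by default. The sketch uses a small random float64 array as a stand-in for the loaded data.

```python
import numpy as np

rng = np.random.default_rng(0)
raw_data = rng.random((4, 4, 3, 5))  # float64 stand-in for the loaded data

# Casting only the numerator: the max is still float64, so the result is float64
promoted = raw_data.astype(np.float32) / np.max(raw_data, axis=(0, 1, 2), keepdims=True)
print(promoted.dtype)  # float64

# Casting once right after loading keeps the whole pipeline in float32
raw_f32 = raw_data.astype(np.float32)
normalized = raw_f32 / np.max(raw_f32, axis=(0, 1, 2), keepdims=True)
print(normalized.dtype)  # float32
```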

I fixed this double/float mismatch by adding raw_data = raw_data.astype(np.float32) right after loading the NIfTI data, and then hit another issue while training stage 1:

24-01-14 20:29:55.122 - INFO: [Phase 1] Training noise model!
24-01-14 20:29:56.211 - INFO: MRI dataset [hardi] is created.
24-01-14 20:29:56.892 - INFO: MRI dataset [hardi] is created.
24-01-14 20:29:56.892 - INFO: Initial Dataset Finished
24-01-14 20:29:57.252 - INFO: Noise Model is created.
24-01-14 20:29:57.252 - INFO: Initial Model Finished
24-01-14 20:31:55.243 - INFO: <epoch: 35, iter:   1,000> l_pix: 3.8268e-03 
24-01-14 20:33:51.252 - INFO: <epoch: 69, iter:   2,000> l_pix: 3.1790e-03 
2.1.2+cu121 12.1
export CUDA_VISIBLE_DEVICES=0
Loaded data of size: (118, 118, 25, 56)
Loaded data of size: (118, 118, 25, 56)
dropout 0.0 encoder dropout 0.0
Validation
Traceback (most recent call last):
  File "/home/anar/DDM2/train_noise_model.py", line 92, in <module>
    for _,  val_data in enumerate(val_loader):
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anar/DDM2/data/mri_dataset.py", line 128, in __getitem__
    raw_input = raw_input[:,:,0]
IndexError: index 0 is out of bounds for axis 2 with size 0

Do you have any idea how to solve this? I have tried several modifications to mri_dataset.py, but none of them fixes the last IndexError.

@tiangexiang
Collaborator

Hi, it seems the error comes from the shape of raw_input. Can you make sure raw_input has the shape [W, H, -1] or [W, H, 1, -1] before that line of code?
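
One way to end up with exactly "index 0 is out of bounds for axis 2 with size 0" is a slice index beyond the number of slices: NumPy basic slicing past the end of an axis returns an empty axis instead of raising, so the problem only surfaces later at raw_input[:,:,0]. Below is a hypothetical guard on the slice window, mirroring the slicing in __getitem__; the zero stand-in array (25 slices, 36 kept volumes plus 2 wrap-padded ones) is an assumption for illustration.

```python
import numpy as np

def take_slice_window(raw, slice_idx, in_channel=1):
    """Hypothetical guard around the slice window taken in __getitem__.

    raw is the padded [W, H, Z, T] array; the window should contain
    2 * (in_channel // 2) + 1 slices along axis 2.
    """
    n_slices = raw.shape[2]
    expected = 2 * (in_channel // 2) + 1
    window = raw[:, :, slice_idx:slice_idx + expected]
    if window.shape[2] != expected:
        raise IndexError(
            f"slice_idx={slice_idx} gives {window.shape[2]} of {expected} slices "
            f"(axis 2 has {n_slices})"
        )
    return window


# Stand-in for the padded data: 25 slices, 36 kept volumes + 2 wrap-padded ones
raw = np.zeros((118, 118, 25, 38), dtype=np.float32)
print(raw[:, :, 40:41].shape)            # (118, 118, 0, 38): empty, no error
print(take_slice_window(raw, 10).shape)  # (118, 118, 1, 38)
```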

@anar-rzayev
Author

anar-rzayev commented Jan 17, 2024

To double-check the shapes, I added a few lines to mri_dataset.py:

# w, h, c, d = raw_input.shape
# raw_input = np.reshape(raw_input, (w, h, -1))
print("raw_input shape before slicing:", raw_input.shape)
if len(raw_input.shape) == 4:
    raw_input = raw_input[:,:,0]
    print("raw_input shape after slicing:", raw_input.shape)
raw_input = self.transforms(raw_input) # only support the first channel for now
# raw_input = raw_input.view(c, d, w, h)

And this is the output I get:

raw_input shape before slicing: (118, 118, 1, 3)
raw_input shape after slicing: (118, 118, 3)
Validation
Traceback (most recent call last):
  File "/home/anar/DDM2/train_noise_model.py", line 92, in <module>
    for _,  val_data in enumerate(val_loader):
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anar/DDM2/data/mri_dataset.py", line 129, in __getitem__
    raw_input = raw_input[:,:,0]
IndexError: index 0 is out of bounds for axis 2 with size 0

@tiangexiang
Collaborator

It seems that the data loaded for training is fine, but the data loaded for validation may not have the right shape. Either way, the problem must come from how the data is loaded. I can't give more specific suggestions without more information; I still suggest inspecting the tensor shape at every point where it could go wrong.
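
Following that suggestion, here is a hypothetical check that compares the validation indices with the data shape before padding (self.data_size_before_padding in the mri_dataset.py pasted in this thread). The example assumes valid_mask = [20, 56] on (118, 118, 25, 56) data, which leaves 36 volumes, and val_slice_idx = 40, the default of MRIDataset.__init__ in that file; the value actually set in the JSON config is not shown in the thread.

```python
def check_val_indices(shape_before_padding, val_volume_idx, val_slice_idx):
    """Hypothetical check of the validation indices against [W, H, Z, T]
    before padding; returns a list of human-readable problems."""
    _, _, n_slices, n_volumes = shape_before_padding
    problems = []
    for v in val_volume_idx:
        if not 0 <= v < n_volumes:
            problems.append(f"val_volume_idx {v} not in [0, {n_volumes})")
    for s in val_slice_idx:
        if not 0 <= s < n_slices:
            problems.append(f"val_slice_idx {s} not in [0, {n_slices})")
    return problems


# valid_mask = [20, 56] on (118, 118, 25, 56) data keeps 36 volumes
for problem in check_val_indices((118, 118, 25, 36), [40], [40]):
    print(problem)
# val_volume_idx 40 not in [0, 36)
# val_slice_idx 40 not in [0, 25)
```

Under those assumptions, both indices fall outside the data, which would yield empty slices only in the validation phase. Training indices are decoded from range(len(dataset)) and always stay in bounds, which would be consistent with training working while validation fails.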

@anar-rzayev
Author

By the way, for the previous case I even added raw_input = np.reshape(raw_input, (raw_input.shape[0], raw_input.shape[1], 1, -1)) before raw_input = raw_input[:,:,0], but I still get the following error:

raw_input shape before slicing: (118, 118, 1, 3)
raw_input shape after slicing: (118, 118, 3)
Validation
Traceback (most recent call last):
  File "/home/anar/DDM2/train_noise_model.py", line 92, in <module>
    for _,  val_data in enumerate(val_loader):
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anar/miniconda3/envs/ddm2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anar/DDM2/data/mri_dataset.py", line 135, in __getitem__
    ret = dict(X=raw_input[[-1], :, :], condition=raw_input[:-1, :, :])
IndexError: index is out of bounds for dimension with size 0

@anar-rzayev
Author

anar-rzayev commented Jan 17, 2024

This is very strange: how can loading succeed for training but fail for validation? My dataset is simply dwi_combined.nii.gz, whose path I specify in the JSON file. I did not change anything in mri_dataset.py except valid_mask, to restrict it to the interval where b-values are non-zero:

from io import BytesIO
from PIL import Image
from torch.utils.data import Dataset
import random
import os
import numpy as np
import torch
from dipy.io.image import save_nifti, load_nifti
from matplotlib import pyplot as plt
from torchvision import transforms, utils


class MRIDataset(Dataset):
    def __init__(self, dataroot, valid_mask, phase='train', image_size=128, in_channel=1, val_volume_idx=50, val_slice_idx=40,
                 padding=1, lr_flip=0.5, stage2_file=None):
        self.padding = padding // 2
        self.lr_flip = lr_flip
        self.phase = phase
        self.in_channel = in_channel

        # read data
        raw_data, _ = load_nifti(dataroot) # width, height, slices, gradients
        raw_data = raw_data.astype(np.float32)
        
        print('Loaded data of size:', raw_data.shape)
        # normalize data
        raw_data = raw_data.astype(np.float32) / np.max(raw_data, axis=(0,1,2), keepdims=True)

        # parse mask
        # `(list or tuple)` evaluates to just `list`; isinstance accepts both
        assert isinstance(valid_mask, (list, tuple)) and len(valid_mask) == 2
 
        # mask data
        raw_data = raw_data[:,:,:,valid_mask[0]:valid_mask[1]] 
        self.data_size_before_padding = raw_data.shape

        self.raw_data = np.pad(raw_data, ((0,0), (0,0), (in_channel//2, in_channel//2), (self.padding, self.padding)), mode='wrap')

        # running for Stage3?
        if stage2_file is not None:
            print('Parsing Stage2 matched states from the stage2 file...')
            self.matched_state = self.parse_stage2_file(stage2_file)
        else:
            self.matched_state = None

        # transform
        if phase == 'train':
            self.transforms = transforms.Compose([
                transforms.ToTensor(),
                #transforms.Resize(image_size),
                transforms.RandomVerticalFlip(lr_flip),
                transforms.RandomHorizontalFlip(lr_flip),
                transforms.Lambda(lambda t: (t * 2) - 1)
            ])
        else:
            self.transforms = transforms.Compose([
                transforms.ToTensor(),
                #transforms.Resize(image_size),
                transforms.Lambda(lambda t: (t * 2) - 1)
            ])

        # prepare validation data
        if val_volume_idx == 'all':
            self.val_volume_idx = range(raw_data.shape[-1])
        elif type(val_volume_idx) is int:
            self.val_volume_idx = [val_volume_idx]
        elif type(val_volume_idx) is list:
            self.val_volume_idx = val_volume_idx
        else:
            self.val_volume_idx = [int(val_volume_idx)]

        if val_slice_idx == 'all':
            self.val_slice_idx = range(0, raw_data.shape[-2])
        elif type(val_slice_idx) is int:
            self.val_slice_idx = [val_slice_idx]
        elif type(val_slice_idx) is list:
            self.val_slice_idx = val_slice_idx
        else:
            self.val_slice_idx = [int(val_slice_idx)]

    def parse_stage2_file(self, file_path):
        results = dict()
        with open(file_path, 'r') as f:
            lines = f.readlines()
            
            for line in lines:
                info = line.strip().split('_')
                volume_idx, slice_idx, t = int(info[0]), int(info[1]), int(info[2])
                if volume_idx not in results:
                    results[volume_idx] = {}
                results[volume_idx][slice_idx] = t
        return results


    def __len__(self):
        if self.phase == 'train' or self.phase == 'test':
            return self.data_size_before_padding[-2] * self.data_size_before_padding[-1] # num of volumes
        elif self.phase == 'val':
            return len(self.val_volume_idx) * len(self.val_slice_idx)

    def __getitem__(self, index):
        if self.phase == 'train' or self.phase == 'test':
            # decode index to get slice idx and volume idx
            volume_idx = index // self.data_size_before_padding[-2]
            slice_idx = index % self.data_size_before_padding[-2]
        elif self.phase == 'val':
            s_index = index % len(self.val_slice_idx)
            index = index // len(self.val_slice_idx)
            slice_idx = self.val_slice_idx[s_index]
            volume_idx = self.val_volume_idx[index]

        raw_input = self.raw_data
       
        if self.padding > 0:
            raw_input = np.concatenate((
                                    raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,volume_idx:volume_idx+self.padding],
                                    raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,volume_idx+self.padding+1:volume_idx+2*self.padding+1],
                                    raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,[volume_idx+self.padding]]), axis=-1)

        elif self.padding == 0:
            raw_input = np.concatenate((
                                    raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,[volume_idx+self.padding-1]],
                                    raw_input[:,:,slice_idx:slice_idx+2*(self.in_channel//2)+1,[volume_idx+self.padding]]), axis=-1)

        # w, h, c, d = raw_input.shape
        # raw_input = np.reshape(raw_input, (w, h, -1))
        print("raw_input shape before slicing:", raw_input.shape)
        if len(raw_input.shape) == 4:
            raw_input = np.reshape(raw_input, (raw_input.shape[0], raw_input.shape[1], 1, -1))
            raw_input = raw_input[:,:,0]
            print("raw_input shape after slicing:", raw_input.shape)
        raw_input = self.transforms(raw_input) # only support the first channel for now
        # raw_input = raw_input.view(c, d, w, h)

        ret = dict(X=raw_input[[-1], :, :], condition=raw_input[:-1, :, :])

        if self.matched_state is not None:
            ret['matched_state'] = torch.zeros(1,) + self.matched_state[volume_idx][slice_idx]

        return ret


if __name__ == "__main__":

    # hardi
    valid_mask = np.zeros(56,)
    valid_mask[20:] += 1
    valid_mask = valid_mask.astype(bool)
    dataset = MRIDataset('/home/anar/DDM2/data/combined_dwi.nii.gz', valid_mask = [20, 56],
                         phase='train', val_volume_idx=40, padding=3)
    
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    for i, data in enumerate(trainloader):
        if i < 95:
            continue
        if i > 108:
            break
        img = data['X']
        condition = data['condition']
        img = img.numpy()
        condition = condition.numpy()

        vis = np.hstack((img[0].transpose(1,2,0), condition[0,[0]].transpose(1,2,0), condition[0,[1]].transpose(1,2,0)))
        # plt.imshow(img[0].transpose(1,2,0), cmap='gray')
        # plt.show()
        # plt.imshow(condition[0,[0]].transpose(1,2,0), cmap='gray')
        # plt.show()
        # plt.imshow(condition[0,[1]].transpose(1,2,0), cmap='gray')
        # plt.show()

        plt.imshow(vis, cmap='gray')
        plt.show()
        #break
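
For reference, the padding > 0 branch of __getitem__ above concatenates wrap-padded neighbor volumes followed by the target volume; after ToTensor the last channel becomes X and the others become condition. The hypothetical helper below maps one sample back to original volume indices. It assumes padding is self.padding (the config value // 2) and lies between 1 and the number of volumes; the padding == 0 branch is not covered.

```python
def window_volumes(volume_idx, n_volumes, padding):
    """Hypothetical helper: original volume indices concatenated by the
    padding > 0 branch of __getitem__, in order (neighbors..., target).

    padding is self.padding, i.e. the config value // 2, assumed
    1 <= padding <= n_volumes.
    """
    if not 1 <= padding <= n_volumes:
        raise ValueError("this sketch only covers 1 <= padding <= n_volumes")
    if not 0 <= volume_idx < n_volumes:
        raise IndexError(f"volume_idx {volume_idx} not in [0, {n_volumes})")

    # np.pad(..., mode='wrap') gives padded[k] == raw[(k - padding) % n_volumes]
    def orig(k):
        return (k - padding) % n_volumes

    v, p = volume_idx, padding
    before = [orig(k) for k in range(v, v + p)]                 # volume_idx:volume_idx+padding
    after = [orig(k) for k in range(v + p + 1, v + 2 * p + 1)]  # volume_idx+padding+1:volume_idx+2*padding+1
    target = orig(v + p)                                        # [volume_idx+padding], becomes X
    return before + after + [target]


print(window_volumes(0, 36, 1))  # [35, 1, 0]
```

With padding = 3 in the config (self.padding = 1) and the 36 volumes kept by valid_mask = [20, 56], volume 0 would be predicted from volumes 35 and 1, while an index such as 40 has no valid window.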
