Batching with missing data #103

Closed
davidwilby opened this issue Mar 8, 2024 · 7 comments

@davidwilby
Collaborator

@tom-andersson et al. I wonder if you can help clear up some difficulty that @MartinSJRogers and I are having with batched training for gridded data with missing values using deepsensor.

As yet I'm unable to work out whether we're doing something incorrectly or whether there are bugs in deepsensor's implementation.

We're working with gridded data with missing values represented as NaNs as specified in the Data Requirements section of the docs.

When a batch_size is set during training, concat_tasks is called. As the snippet below shows, concat_tasks calls the remove_target_nans() method on each task:

for i, task in enumerate(tasks):
    if "numpy_mask" in task["ops"] or "nps_mask" in task["ops"]:
        raise ValueError(
            "Cannot concatenate tasks that have had NaNs masked. "
            "Masking will be applied automatically after concatenation."
        )
    if "target_nans_removed" not in task["ops"]:
        task = task.remove_target_nans()
    if "batch_dim" not in task["ops"]:
        task = task.add_batch_dim()
    if "float32" not in task["ops"]:
        task = task.cast_to_float32()
    tasks[i] = task

This results in a ValueError being raised later in concat_tasks, because the tasks in the batch have different numbers of targets:

ValueError: All tasks must have the same number of targets to concatenate: got [9460279, 10432117, 8255541, 10345501]. To train with Task batches containing differing numbers of targets, run the model individually over each task and average the losses.

and as a result we don't get to the calls to mask_nans_numpy and mask_nans_nps towards the end of concat_tasks.
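A minimal NumPy sketch of this failure mode, using hypothetical 3x3 target grids and a hypothetical `drop_nans` helper (this illustrates the idea and is not deepsensor's code): once NaNs are removed from each task independently, the tasks keep different numbers of target points, so their target arrays cannot be stacked along a new batch dimension.

```python
import numpy as np


def drop_nans(y_grid):
    """Flatten a gridded target array and drop its NaN entries (hypothetical helper)."""
    y_flat = y_grid.reshape(-1)
    return y_flat[~np.isnan(y_flat)]


# Two hypothetical 3x3 target grids with missing values in different places
y_a = np.array([[1.0, np.nan, 2.0], [3.0, 4.0, 5.0], [np.nan, 6.0, 7.0]])
y_b = np.array([[1.0, 2.0, 3.0], [np.nan, 4.0, 5.0], [6.0, 7.0, 8.0]])

targets = [drop_nans(y_a), drop_nans(y_b)]
n_targets = [t.size for t in targets]  # 7 and 8 non-missing targets

# Stacking along a new batch axis requires equal lengths, so this raises
stack_failed = False
try:
    np.stack(targets)
except ValueError as err:
    stack_failed = True
    print(f"Cannot batch targets with counts {n_targets}: {err}")
```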

I'm confused by this, because the error message above ("Cannot concatenate tasks that have had NaNs masked. Masking will be applied automatically after concatenation.") and the later calls to mask_nans_{numpy,nps} suggest that this situation should be handled by those methods.

When I remove the call to remove_target_nans here for testing, of course the batch sizes are the same and the ValueError above isn't raised, the rest of concat_tasks runs and mask_nans_{numpy,nps} are called successfully.

This, however, results in an error further down the line in which the Masked object from neuralprocesses is found not to have the dtype attribute:

Full stack trace
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[12], line 15
     13 trainer = Trainer(model, lr=5e-5)
     14 for epoch in tqdm(range(1)):
---> 15     batch_losses = trainer(train_tasks, tqdm_notebook=True, batch_size=None) # error here due to filesize. I have attempted using batch_size = n, 
     16                                         # but get seperate error asserting that the number of targets in each batch must be the same. 
     17                                         # Todo- work out how to calculate number of targets in each batch, and ensure the batch size allows me to honour this assertion. 
     18     losses.append(np.mean(batch_losses))

File _/deepsensor/train/train.py:177, in Trainer.__call__(self, tasks, batch_size, progress_bar, tqdm_notebook)
    170 def __call__(
    171     self,
    172     tasks: List[Task],
   (...)
    175     tqdm_notebook=False,
    176 ) -> List[float]:
--> 177     return train_epoch(
    178         model=self.model,
    179         tasks=tasks,
    180         batch_size=batch_size,
    181         opt=self.opt,
    182         progress_bar=progress_bar,
    183         tqdm_notebook=tqdm_notebook,
    184     )

File _/deepsensor/train/train.py:145, in train_epoch(model, tasks, lr, batch_size, opt, progress_bar, tqdm_notebook)
    143     else:
    144         task = tasks[batch_i]
--> 145     batch_loss = train_step(task)
    146     batch_losses.append(batch_loss)
    148 return batch_losses

File _/deepsensor/train/train.py:116, in train_epoch.<locals>.train_step(tasks)
    114 task_losses = []
    115 for task in tasks:
--> 116     task_losses.append(model.loss_fn(task, normalise=True))
    117 mean_batch_loss = B.mean(B.stack(*task_losses))
    118 mean_batch_loss.backward()

File _/deepsensor/model/convnp.py:869, in ConvNP.loss_fn(self, task, fix_noise, num_lv_samples, normalise)
    865 task = ConvNP.modify_task(task)
    867 context_data, xt, yt, model_kwargs = convert_task_to_nps_args(task)
--> 869 logpdfs = backend.nps.loglik(
    870     self.model,
    871     context_data,
    872     xt,
    873     yt,
    874     **model_kwargs,
    875     fix_noise=fix_noise,
    876     num_samples=num_lv_samples,
    877     normalise=normalise,
    878 )
    880 loss = -B.mean(logpdfs)
    882 return loss

File _conda-envs/deepsensor/lib/python3.11/site-packages/plum/function.py:399, in Function.__call__(self, *args, **kw_args)
    397 def __call__(self, *args, **kw_args):
    398     method, return_type = self._resolve_method_with_cache(args=args)
--> 399     return _convert(method(*args, **kw_args), return_type)

File _conda-envs/deepsensor/lib/python3.11/site-packages/neuralprocesses/model/loglik.py:113, in loglik(model, *args, **kw_args)
    110 @_dispatch
    111 def loglik(model: Model, *args, **kw_args):
    112     state = B.global_random_state(B.dtype(args[-2]))
--> 113     state, logpdfs = loglik(state, model, *args, **kw_args)
    114     B.set_global_random_state(state)
    115     return logpdfs

File _conda-envs/deepsensor/lib/python3.11/site-packages/plum/function.py:399, in Function.__call__(self, *args, **kw_args)
    397 def __call__(self, *args, **kw_args):
    398     method, return_type = self._resolve_method_with_cache(args=args)
--> 399     return _convert(method(*args, **kw_args), return_type)

File _conda-envs/deepsensor/lib/python3.11/site-packages/neuralprocesses/model/loglik.py:48, in loglik(state, model, contexts, xt, yt, num_samples, batch_size, normalise, fix_noise, dtype_lik, **kw_args)
     12 @_dispatch
     13 def loglik(
     14     state: B.RandomState,
   (...)
     25     **kw_args,
     26 ):
     27     """Log-likelihood objective.
     28 
     29     Args:
   (...)
     46         tensor: Log-likelihoods.
     47     """
---> 48     float = B.dtype_float(yt)
     49     float64 = B.promote_dtypes(float, np.float64)
     51     # For the likelihood computation, default to using a 64-bit version of the data
     52     # type of `yt`.

File _conda-envs/deepsensor/lib/python3.11/site-packages/plum/function.py:399, in Function.__call__(self, *args, **kw_args)
    397 def __call__(self, *args, **kw_args):
    398     method, return_type = self._resolve_method_with_cache(args=args)
--> 399     return _convert(method(*args, **kw_args), return_type)

File _conda-envs/deepsensor/lib/python3.11/site-packages/lab/types.py:342, in dtype_float(x)
    340 @dispatch
    341 def dtype_float(x):
--> 342     return dtype_float(dtype(x))

File _conda-envs/deepsensor/lib/python3.11/site-packages/plum/function.py:399, in Function.__call__(self, *args, **kw_args)
    397 def __call__(self, *args, **kw_args):
    398     method, return_type = self._resolve_method_with_cache(args=args)
--> 399     return _convert(method(*args, **kw_args), return_type)

File _conda-envs/deepsensor/lib/python3.11/site-packages/lab/types.py:236, in dtype(a)
    226 @dispatch
    227 def dtype(a):
    228     """Determine the data type of an object.
    229 
    230     Args:
   (...)
    234         dtype: Data type of `a`.
    235     """
--> 236     return a.dtype

AttributeError: 'Masked' object has no attribute 'dtype'

Are we doing something incorrectly here, or are there bugs in the implementation? We're happy to add more docs once we've worked out what we're doing, or to contribute bug fixes if required!

Lastly, for non-batched training, are mask_nans_{numpy,nps} run somewhere else? I notice that they're called in modify_task, but I'm not yet sure when that is called.

@nilsleh

nilsleh commented Mar 12, 2024

@davidwilby I ran into something similar a while ago, where target sets had different numbers of targets across tasks when batching. I adapted the concat_tasks function to randomly subsample each target set to the smallest number of targets across the tasks:

import numpy as np

# Fragment adapted from inside concat_tasks: `tasks` and `n_target_sets` are defined
# earlier in that function, after NaNs have been removed from each task's targets.
for target_set_i in range(n_target_sets):
    # Number of target observations in this target set, per task
    n_target_obs = [task["Y_t"][target_set_i].shape[-1] for task in tasks]
    if not all(n == n_target_obs[0] for n in n_target_obs):
        # Instead of raising an error, subsample this target set in every task
        # down to the smallest number of targets across the tasks
        min_n = min(n_target_obs)

        for task in tasks:
            rand_indices = np.random.choice(
                task["Y_t"][target_set_i].shape[-1],
                size=min_n,
                replace=False,
            )
            # Use the same indices for Y_t and X_t so inputs and outputs stay paired
            task["Y_t"][target_set_i] = task["Y_t"][target_set_i][..., rand_indices]
            task["X_t"][target_set_i] = task["X_t"][target_set_i][..., rand_indices]

Not sure if this helps, but I agree that it would be nice to include some functionality that handles this for batched training, as it can happen quite frequently.

@tom-andersson
Collaborator

Hi @davidwilby + @MartinSJRogers, thank you for raising this :) This boils down to a few things:

  1. We can have missing data/NaNs on the context side because ConvNPs represent missing data as zeros in the density channels of the context set encodings. This is handled by neuralprocesses.Masked objects which are constructed by the Task.mask_nans_{numpy,nps} methods you mentioned.
  2. However, we can't have NaNs in target values, because the log-likelihood, and therefore the loss and its gradients, would be NaN, so backpropagation would fail.
  3. We can remove NaNs from target sets, and NP models can happily train on varying-length target arrays when this happens.
  4. However, we can't concatenate varying-length arrays into a single array (and there is no support in deepsensor/neuralprocesses for padding the target arrays and then masking the padded values from the loss).
  5. Therefore we can't run in batch mode if there are missing values in the targets.
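Points 2 and 4 can be illustrated with a small NumPy sketch, assuming a unit-variance Gaussian negative log-likelihood as the loss (the function names and the padding scheme are hypothetical and, as point 4 says, padding plus masking is not supported in deepsensor or neuralprocesses). A single NaN target turns the loss into NaN, whereas padding varying-length targets to a common length and zero-weighting the padded entries would reproduce the loss obtained by running each task separately and averaging.

```python
import numpy as np


def gaussian_nll(y, mean, std):
    """Elementwise Gaussian negative log-likelihood."""
    return 0.5 * np.log(2 * np.pi * std**2) + 0.5 * ((y - mean) / std) ** 2


# Point 2: one NaN target makes the whole loss NaN (and so its gradients too)
y_with_nan = np.array([0.2, np.nan, -0.1])
loss_nan = np.mean(gaussian_nll(y_with_nan, 0.0, 1.0))

# Two hypothetical tasks whose targets have different lengths after NaN removal
task_targets = [np.array([0.2, -0.1]), np.array([0.5, 0.0, 0.3, -0.4])]

# Reference: evaluate each task separately, then average the per-task losses
per_task = [np.mean(gaussian_nll(y, 0.0, 1.0)) for y in task_targets]
loss_looped = np.mean(per_task)

# Hypothetical alternative (point 4): pad to a common length, mask the padding
max_len = max(y.size for y in task_targets)
y_padded = np.zeros((len(task_targets), max_len))
mask = np.zeros((len(task_targets), max_len))
for i, y in enumerate(task_targets):
    y_padded[i, : y.size] = y
    mask[i, : y.size] = 1.0

nll = gaussian_nll(y_padded, 0.0, 1.0) * mask  # padded entries contribute zero
loss_padded = np.mean(nll.sum(axis=1) / mask.sum(axis=1))  # per-task mean, then batch mean
```

Masking alone would not rescue NaN targets, because NaN * 0 is still NaN; the NaNs have to be removed or replaced before any mask is applied.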

The workaround, as suggested in the error message, is to manually run the model multiple times in a for loop over your 'batch', and then average the losses within your model update. This gives you the less noisy gradient estimates of batch training, but unfortunately not the computational efficiency of running on multiple examples in parallel on a GPU.
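A framework-free sketch of this workaround, with a toy linear model and hand-written gradients standing in for the ConvNP (all names and data are hypothetical): each task in the 'batch' is evaluated separately, and the per-task losses (here, equivalently, their gradients) are averaged into a single parameter update. It mirrors the loop over tasks in deepsensor's train_step visible in the first stack trace, which stacks the per-task losses, takes their mean and calls backward() once.

```python
import numpy as np


def task_loss_and_grad(w, x, y):
    """Mean squared error of the toy model y_hat = w * x on one task, and d(loss)/dw."""
    resid = w * x - y
    return np.mean(resid**2), np.mean(2 * resid * x)


def train_step(w, batch, lr):
    """One parameter update over a 'batch' of tasks with differing numbers of targets."""
    losses, grads = [], []
    for x, y in batch:  # run the model on each task individually
        loss, grad = task_loss_and_grad(w, x, y)
        losses.append(loss)
        grads.append(grad)
    # The gradient of the mean loss is the mean of the per-task gradients
    return w - lr * np.mean(grads), np.mean(losses)


# Three hypothetical tasks with 3, 5 and 4 targets, all generated from y = 2x
batch = [(x, 2 * x) for x in [np.linspace(0, 1, n) for n in (3, 5, 4)]]

w = 0.0
for _ in range(50):
    w, mean_loss = train_step(w, batch, lr=0.5)
print(round(w, 4))  # converges towards 2.0
```

The per-task evaluations run one after another, which is why this gets the averaging benefit of batching but not GPU parallelism.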

@nilsleh's workaround of subsampling to the smallest number of targets is a nice idea, although the model will see fewer target points per batch than it would otherwise, so it's a trade-off between computational efficiency and learning efficiency. If the number of non-missing target points is similar across all Tasks, which looks to be the case from [9460279, 10432117, 8255541, 10345501], then it might not be a bad shout.
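To put a number on that trade-off for the target counts in the ValueError earlier in this thread, a quick calculation (the helper name is hypothetical) of the fraction of non-missing targets that subsampling to the smallest count would keep:

```python
def fraction_kept_by_subsampling(n_targets):
    """Fraction of all targets kept if every task is subsampled to the smallest count."""
    return len(n_targets) * min(n_targets) / sum(n_targets)


counts = [9460279, 10432117, 8255541, 10345501]  # from the ValueError in this thread
kept = fraction_kept_by_subsampling(counts)
print(f"{kept:.3f}")  # 0.858, i.e. roughly 14% of targets discarded per batch
```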

On this point from your message: "When I remove the call to remove_target_nans here for testing, of course the batch sizes are the same and the ValueError above isn't raised, the rest of concat_tasks runs and mask_nans_{numpy,nps} are called successfully. This, however, results in an error further down the line in which the Masked object from neuralprocesses is found not to have the dtype attribute."

neuralprocesses stack traces can be confusing and the dtype error isn't clear: it arises because loglik calls B.dtype_float(yt) on the targets, which fails when yt is a Masked object. You can't have neuralprocesses.Masked objects in the targets; the targets need to be vanilla tensors. Only context data can have neuralprocesses.Masked objects, and the missing data will be dealt with under the hood, as mentioned above.
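A minimal NumPy sketch of the distinction, assuming data flattened to (dim, N) arrays (the helper names are hypothetical, and this is not the neuralprocesses Masked implementation): context values keep their full length, with NaNs replaced by zeros and an accompanying density/mask channel that is zero where data are missing, while target values must be plain finite arrays, so missing points are dropped from both Y_t and the matching X_t coordinates. In the task printouts later in this thread, such masked context sets appear as Masked objects holding a `y` tensor and a `mask` tensor of the same shape.

```python
import numpy as np


def mask_context(y_c):
    """Context: zero-fill NaNs and return a density channel (1 = observed, 0 = missing)."""
    observed = ~np.isnan(y_c)
    return np.where(observed, y_c, 0.0), observed.astype(y_c.dtype)


def drop_target_nans(x_t, y_t):
    """Targets: remove every point where any target value is NaN, from both X_t and Y_t."""
    keep = ~np.isnan(y_t).any(axis=0)
    return x_t[:, keep], y_t[:, keep]


# Hypothetical data: 2 input dims (e.g. lat/lon) and 1 output dim at 5 points
x = np.array([[0.0, 0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3, 1.4]])
y = np.array([[0.5, np.nan, 0.7, np.nan, 0.9]])

y_c_filled, density = mask_context(y)  # both keep shape (1, 5)
x_t_clean, y_t_clean = drop_target_nans(x, y)  # shapes (2, 3) and (1, 3)
```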

Hope this clears things up, and please close if so :)

@nilsleh

nilsleh commented Apr 5, 2024

I have a related question about NaNs in target sets, which occur in the data I am working with. If I don't modify anything and use the provided Trainer code as follows:

trainer = Trainer(model, lr=5e-5)
batch_losses = trainer(train_tasks, batch_size=None)

Inside loss_fn, after the modify_task call, a single task looks like this:

time: Timestamp/2013-06-15 12:00:00
ops: ['str/batch_dim', 'str/float32', 'str/numpy_mask', 'str/nps_mask', 'str/tensor']
X_c: ['Tensor/torch.float32/torch.Size([1, 2, 256000])', 'Tensor/torch.float32/torch.Size([1, 2, 256000])', 'Tensor/torch.float32/torch.Size([1, 2, 256000])']
Y_c: ['Masked/(y=torch.float32/torch.Size([1, 1, 256000]))/(mask=torch.float32/torch.Size([1, 1, 256000]))', 'Masked/(y=torch.float32/torch.Size([1, 1, 256000]))/(mask=torch.float32/torch.Size([1, 1, 256000]))', 'Tensor/torch.float32/torch.Size([1, 1, 256000])']
X_t: ['Tensor/torch.float32/torch.Size([1, 2, 224000])']
Y_t: ['Masked/(y=torch.float32/torch.Size([1, 1, 224000]))/(mask=torch.float32/torch.Size([1, 1, 224000]))']

And I get the error: AttributeError: 'Masked' object has no attribute 'dtype'

If I instead remove the target NaNs before the task is modified, either by adding task.remove_target_nans() to the loss function while keeping batch_size=None, or by setting the trainer's batch_size>2 (because remove_target_nans() is then called in concat_tasks), a single task looks like this:

time: Timestamp/2013-06-15 12:00:00
ops: ['str/target_nans_removed', 'str/batch_dim', 'str/float32', 'str/numpy_mask', 'str/nps_mask', 'str/tensor']
X_c: ['Tensor/torch.float32/torch.Size([1, 2, 256000])', 'Tensor/torch.float32/torch.Size([1, 2, 256000])', 'Tensor/torch.float32/torch.Size([1, 2, 256000])']
Y_c: ['Masked/(y=torch.float32/torch.Size([1, 1, 256000]))/(mask=torch.float32/torch.Size([1, 1, 256000]))', 'Masked/(y=torch.float32/torch.Size([1, 1, 256000]))/(mask=torch.float32/torch.Size([1, 1, 256000]))', 'Tensor/torch.float32/torch.Size([1, 1, 256000])']
X_t: ['Tensor/torch.float32/torch.Size([1, 2, 215843])']
Y_t: ['Tensor/torch.float32/torch.Size([1, 1, 215843])']

But then I get an error from the neuralprocesses library: AssertionError: Expected not a parallel of elements, but got inputs and outputs in parallel.

Full Stacktrace
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/deepsensor/train/train.py", line 116, in train_step
    task_losses.append(model.loss_fn(task, normalise=True))
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/deepsensor/model/convnp.py", line 870, in loss_fn
    logpdfs = backend.nps.loglik(
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/loglik.py", line 113, in loglik
    state, logpdfs = loglik(state, model, *args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/loglik.py", line 64, in loglik
    state, pred = model(
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 489, in __call__
    return self._f(self._instance, *args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/model.py", line 101, in __call__
    return self(
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 489, in __call__
    return self._f(self._instance, *args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/model.py", line 72, in __call__
    xz, pz = code(self.encoder, xc, yc, xt, root=True, **enc_kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
    return f(*args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/coders/functional.py", line 39, in code
    return code(coder.coder, xz, z, x, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
    return f(*args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/chain.py", line 56, in code
    xz, z = code(link, xz, z, x, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
    return f(*args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/chain.py", line 56, in code
    xz, z = code(link, xz, z, x, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
    return f(*args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/coders/shaping.py", line 179, in code
    raise AssertionError(                                                                                                                                                                                                                                                                                                                                                                                                           
AssertionError: Expected not a parallel of elements, but got inputs and outputs in parallel.

</details>

And I am not sure what I have done wrong.

@tom-andersson
Collaborator

Hi @nilsleh, the AttributeError: 'Masked' object has no attribute 'dtype' is exactly the problem described above: you can't train a model with NaNs in the targets. Unfortunately, the error message raised when target NaNs are present is confusing. As you say, using batch_size > 1 means target NaNs are removed automatically within the concat_tasks method.
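To illustrate why removing target NaNs task by task gets in the way of batching, here is a minimal pure-Python sketch. It does not use deepsensor: `remove_target_nans`, `concat_targets` and `mse` are hypothetical stand-ins that only mimic the behaviour discussed in this thread. NaN targets are dropped from each task independently, so tasks end up with different numbers of targets and cannot be stacked, while running each task on its own and averaging the per-task losses (the workaround the concat_tasks error message suggests) still works.

```python
import math
from statistics import mean


def remove_target_nans(x_target, y_target):
    """Hypothetical helper: drop target points whose value is NaN.

    Mimics, but is not, deepsensor's Task.remove_target_nans().
    """
    keep = [i for i, y in enumerate(y_target) if not math.isnan(y)]
    return [x_target[i] for i in keep], [y_target[i] for i in keep]


def concat_targets(tasks):
    """Hypothetical batching step: stacking requires equal target counts."""
    counts = [len(y) for _, y in tasks]
    if len(set(counts)) > 1:
        raise ValueError(
            "All tasks must have the same number of targets to "
            f"concatenate: got {counts}"
        )
    return [y for _, y in tasks]


def mse(pred, y):
    """Mean squared error over the (NaN-free) targets of one task."""
    return mean((p - t) ** 2 for p, t in zip(pred, y))


nan = float("nan")
# Two toy tasks on the same 4-point grid, with NaNs in different places
# (e.g. land cells that move with a changing ice or coast mask).
x = [0.0, 1.0, 2.0, 3.0]
task_a = remove_target_nans(x, [1.0, nan, 3.0, 4.0])  # 3 targets remain
task_b = remove_target_nans(x, [nan, nan, 2.0, 5.0])  # 2 targets remain

try:
    concat_targets([task_a, task_b])
except ValueError as err:
    print(err)  # ...got [3, 2]

# Workaround: evaluate each task separately and average the losses.
# The toy "model" here simply predicts zero at every target point.
losses = [mse([0.0] * len(y), y) for _, y in (task_a, task_b)]
print(mean(losses))
```

Note that a plain average weights every task equally, however many valid targets it has left after NaN removal.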

Regarding the AssertionError: Expected not a parallel of elements, but got inputs and outputs in parallel, I have never seen that neuralprocesses error before. The shape of the Task looks fine, but I'm missing context on exactly what code you call before this. Would you be able to produce an MWE (minimal working example) in a Colab using randomly generated data?

@nilsleh

nilsleh commented Apr 9, 2024

Hi @tom-andersson, thanks for the reply. I have created a gist with the accompanying data I am using. The data tar file also contains the normalization parameters for the data processor.

EDIT: I was able to resolve it thanks to Wessel; it was a misconfiguration of the data processor and model.

@tom-andersson
Collaborator

Glad you could solve this, @nilsleh. To copy over your solution from the neuralprocesses GitHub for future reference:

I had forgotten to pass the task_loader as an argument to the ConvNP model, which mattered because I was using multiple context sets. Without it, the model is initialized with default parameters that then result in a mismatch when you pass in your "actual" data.
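The failure mode in this quote (a model built with defaults, then fed data with more context sets) can be shown with a toy example. `ToyModel` and `ToyTaskLoader` below are hypothetical stand-ins, not deepsensor's `ConvNP` or `TaskLoader`, and `from_task_loader` is an illustrative name, not a real API. The sketch only shows why a model whose input layout is fixed at construction should take its configuration from the loader that produces the data, which is what passing the task_loader to ConvNP achieved here.

```python
from dataclasses import dataclass, field


@dataclass
class ToyTaskLoader:
    """Hypothetical loader: one entry per context set (e.g. SST, sea ice)."""
    context_sets: list = field(default_factory=list)


class ToyModel:
    """Hypothetical model whose number of context inputs is fixed when built."""

    def __init__(self, n_context_sets=1):
        # The default assumes a single context set, like an unconfigured model.
        self.n_context_sets = n_context_sets

    @classmethod
    def from_task_loader(cls, loader):
        # Infer the architecture from the data the loader will provide.
        return cls(n_context_sets=len(loader.context_sets))

    def __call__(self, context):
        if len(context) != self.n_context_sets:
            raise ValueError(
                f"model built for {self.n_context_sets} context set(s), "
                f"got {len(context)}"
            )
        # Placeholder "prediction": sum of all context values.
        return sum(sum(c) for c in context)


loader = ToyTaskLoader(context_sets=["sst", "ice"])
context = [[1.0, 2.0], [3.0]]  # two context sets, as the loader describes

try:
    ToyModel()(context)  # default config: one context set expected
except ValueError as err:
    print(err)

print(ToyModel.from_task_loader(loader)(context))  # 6.0
```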

@tom-andersson
Collaborator

There were no complaints about closing this issue, so closing now.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants