Batching with missing data #103
@davidwilby Not sure if this helps, but I ran into something similar a while ago where the target sets have different numbers of targets across tasks while batching, and I adapted the code as follows:

```python
import numpy as np

for target_set_i in range(n_target_sets):
    # Originally this raised an error if target sets had different numbers of
    # targets across tasks; instead, adapt every task to the same size.
    n_target_obs = [task["Y_t"][target_set_i].size for task in tasks]
    if not all([n == n_target_obs[0] for n in n_target_obs]):
        # For this target set, subsample each task's observations down to the
        # smallest number of target observations found across tasks
        shapes = [task["Y_t"][target_set_i].shape[-1] for task in tasks]
        min_n = min(shapes)
        for task in tasks:
            rand_indices = np.random.choice(
                np.arange(task["Y_t"][target_set_i].shape[-1]),
                size=min_n,
                replace=False,
            )
            task["Y_t"][target_set_i] = task["Y_t"][target_set_i][..., rand_indices]
            task["X_t"][target_set_i] = task["X_t"][target_set_i][..., rand_indices]
```

In any case, I would agree that it would be nice to include some functionality that handles this for batched training, as it can happen quite frequently.
Hi @davidwilby + @MartinSJRogers, thank you for raising this :) this boils down to a few things:
The workaround, as suggested in the error message, is to manually run the model multiple times in a for loop over your 'batch', and then average the losses within your model update. This gives you the smoother loss surface of batch training, but unfortunately it doesn't give you the computational efficiency of running on multiple examples in parallel on a GPU. A rough sketch of this loop-and-average approach is included below. @nilsleh's workaround of subsampling to the smallest number of targets is a nice idea, although the model will see fewer target points per batch than it would otherwise, so this is a trade-off between computational efficiency and learning efficiency. If the number of non-missing target points is similar between all tasks, that trade-off should be small.
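To make that concrete, here is a minimal sketch of the loop-and-average workaround, assuming the PyTorch backend. The function name, the `opt` optimiser, and `tasks_batch` are placeholders you would define yourself; the only deepsensor call used is `model.loss_fn(task, normalise=True)`, which is the same call deepsensor's own `train_step` makes:

```python
import torch

def train_step_batched(model, opt, tasks_batch):
    """Sketch only: emulate a batch by looping over tasks and averaging the
    per-task losses before a single optimiser step."""
    opt.zero_grad()
    task_losses = [model.loss_fn(task, normalise=True) for task in tasks_batch]
    mean_loss = torch.mean(torch.stack(task_losses))
    mean_loss.backward()  # gradients of the averaged loss -> smoother updates
    opt.step()
    return float(mean_loss.detach().cpu())
```

This keeps every non-missing target point in play (unlike subsampling), at the cost of running the model once per task rather than once per batch.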
Hope this clears things up, and please close if so :)
I have a related question about NaNs in target sets, which is the case for the data that I am working with. If I don't modify anything and use the provided Trainer code as such:

```python
from deepsensor.train import Trainer  # Trainer module path per the stack trace below

trainer = Trainer(model, lr=5e-5)
batch_losses = trainer(train_tasks, batch_size=None)
```

A single task looks like this in the …
And I get the error: … If I change the loss function to remove the targets before modifying the task by adding …
but then I get the error below.

Full stack trace:

```
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/deepsensor/train/train.py", line 116, in train_step
  task_losses.append(model.loss_fn(task, normalise=True))
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/deepsensor/model/convnp.py", line 870, in loss_fn
  logpdfs = backend.nps.loglik(
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
  return _convert(method(*args, **kw_args), return_type)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/loglik.py", line 113, in loglik
  state, logpdfs = loglik(state, model, *args, **kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
  return _convert(method(*args, **kw_args), return_type)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/loglik.py", line 64, in loglik
  state, pred = model(
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 489, in __call__
  return self._f(self._instance, *args, **kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
  return _convert(method(*args, **kw_args), return_type)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/model.py", line 101, in __call__
  return self(
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 489, in __call__
  return self._f(self._instance, *args, **kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
  return _convert(method(*args, **kw_args), return_type)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/model.py", line 72, in __call__
  xz, pz = code(self.encoder, xc, yc, xt, root=True, **enc_kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
  return _convert(method(*args, **kw_args), return_type)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
  return f(*args, **kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/coders/functional.py", line 39, in code
  return code(coder.coder, xz, z, x, **kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
  return _convert(method(*args, **kw_args), return_type)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
  return f(*args, **kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/chain.py", line 56, in code
  xz, z = code(link, xz, z, x, **kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
  return _convert(method(*args, **kw_args), return_type)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
  return f(*args, **kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/chain.py", line 56, in code
  xz, z = code(link, xz, z, x, **kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
  return _convert(method(*args, **kw_args), return_type)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
  return f(*args, **kw_args)
File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/coders/shaping.py", line 179, in code
  raise AssertionError(
AssertionError: Expected not a parallel of elements, but got inputs and outputs in parallel.
```
Hi @nilsleh, the …

Regarding the …
Hi @tom-andersson, thanks for the reply. I have created a gist with the accompanying data I am using. The data tar file also contains the normalization parameters for the data processor. EDIT: I was able to resolve it thanks to Wessel; it was a misconfiguration of the data processor and model.
Glad you could solve this @nilsleh - to copy over your solution from the
There were no complaints about closing this issue, so closing now.
@tom-andersson et al. I wonder if you can help clear up some difficulty that @MartinSJRogers and I are having with batched training for gridded data with missing values using deepsensor.
As yet I'm unable to work out whether we're doing something incorrectly or whether there are bugs in deepsensor's implementation.
We're working with gridded data with missing values represented as NaNs as specified in the Data Requirements section of the docs.
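For context, our data looks roughly like the following before it enters deepsensor: a gridded xarray field with NaNs marking the missing cells, which is how the Data Requirements docs say gaps should be encoded. The variable and coordinate names here are placeholders from our setup, not anything deepsensor requires:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Illustrative only: a small gridded field with missing values stored as NaNs.
lat = np.linspace(50.0, 60.0, 20)
lon = np.linspace(-10.0, 0.0, 20)
time = pd.date_range("2020-01-01", periods=5)
values = np.random.rand(len(time), len(lat), len(lon))
values[:, 5:8, 5:8] = np.nan  # e.g. a block of missing cells (land/cloud mask)

da = xr.DataArray(
    values,
    dims=("time", "lat", "lon"),
    coords={"time": time, "lat": lat, "lon": lon},
    name="sst",  # placeholder variable name
)
```

This then gets normalised with a `DataProcessor` and sampled into tasks with a `TaskLoader`, following the docs.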
When setting a `batch_size` during training, `concat_tasks` is called; in the snippet below (`deepsensor/deepsensor/data/task.py`, lines 477 to 489 at `aeccc09`), the `remove_target_nans()` method is called:
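The linked snippet isn't reproduced here, but the behaviour we observe is roughly what this illustrative sketch does. It is not the actual deepsensor source (that lives at the permalink above) and it assumes off-grid targets stored as plain arrays; it is only meant to show why per-task target counts end up differing:

```python
import numpy as np

def remove_target_nans_sketch(task):
    """Illustrative stand-in for remove_target_nans(): drop any target
    observation containing a NaN, along with its target location."""
    for i, Y_t in enumerate(task["Y_t"]):
        keep = ~np.isnan(Y_t).any(axis=0)  # keep targets with no NaN in any dim
        task["Y_t"][i] = Y_t[..., keep]
        task["X_t"][i] = task["X_t"][i][..., keep]
    return task

# Each task has a different missingness pattern, so the number of surviving
# targets differs per task, which is what later trips up concat_tasks.
```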
This results in a `ValueError` raised later in `concat_tasks`, since there are different numbers of targets in each batch, and as a result we don't get to the calls to `mask_nans_numpy` and `mask_nans_nps` towards the end of `concat_tasks`.

I'm confused by this, since from the message above,

"Cannot concatenate tasks that have had NaNs masked. Masking will be applied automatically after concatenation."

and the later call to `mask_nans_{numpy,nps}`, it seems like this should be handled by those methods.

When I remove the call to `remove_target_nans` here for testing, of course the batch sizes are the same and the `ValueError` above isn't raised; the rest of `concat_tasks` runs and `mask_nans_{numpy,nps}` are called successfully.

This, however, results in an error further down the line, in which the `Masked` object from `neuralprocesses` is found not to have the `dtype` attribute:

Full stack trace
Are we doing something incorrectly here, or are there bugs in the implementation? Happy to add more docs once we've worked out what we're doing, or to contribute bug fixes if required!

Lastly, for non-batched training, are `mask_nans_{numpy,nps}` run somewhere else? I notice that they're called in `modify_task`, but I'm not yet sure when this is called.