
Can not batch ot.emd2 via torch.vmap #532

Open
oleg-kachan opened this issue Oct 11, 2023 · 1 comment
oleg-kachan commented Oct 11, 2023

Describe the bug

As my data points are empirical distributions, I want to use the Wasserstein distance as a loss function over a batch of shape (n_batch, n_points, dimension). The standard way to make a function accept a batch as input is torch.vmap, yet I get the error described below.

To Reproduce

import torch
import ot

def wasserstein2_loss(X, Y):
    n, m = X.shape[0], Y.shape[0]
    a = torch.ones(n) / n
    b = torch.ones(m) / m
    M = ot.dist(X, Y, metric="sqeuclidean")
    return ot.emd2(a, b, M) ** 0.5

wasserstein2_loss_batched = torch.vmap(wasserstein2_loss)
W2 = wasserstein2_loss_batched(X, Y)  # should be an array of shape (n_batch,)

Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 W2 = wasserstein2_loss_batched(X, Y)

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:434, in vmap.<locals>.wrapped(*args, **kwargs)
    430     return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
    431                          args_spec, out_dims, randomness, **kwargs)
    433 # If chunk_size is not specified.
--> 434 return _flat_vmap(
    435     func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    436 )

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:39, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     36 @functools.wraps(f)
     37 def fn(*args, **kwargs):
     38     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39         return f(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:619, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    617 try:
    618     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 619     batched_outputs = func(*batched_inputs, **kwargs)
    620     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    621 finally:

Cell In[4], line 13, in wasserstein2_loss(X, Y)
     11 b = torch.ones(m) / m
     12 M = ot.dist(X, Y, metric="sqeuclidean")
---> 13 return wasserstein_distance(a, b, M) ** 0.5

File /usr/local/lib/python3.10/dist-packages/ot/lp/__init__.py:488, in emd2(a, b, M, processes, numItermax, log, return_matrix, center_dual, numThreads, check_marginals)
    485 nx = get_backend(M0, a0, b0)
    487 # convert to numpy
--> 488 M, a, b = nx.to_numpy(M, a, b)
    490 a = np.asarray(a, dtype=np.float64)
    491 b = np.asarray(b, dtype=np.float64)

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in Backend.to_numpy(self, *arrays)
    205     return self._to_numpy(arrays[0])
    206 else:
--> 207     return [self._to_numpy(array) for array in arrays]

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in <listcomp>(.0)
    205     return self._to_numpy(arrays[0])
    206 else:
--> 207     return [self._to_numpy(array) for array in arrays]

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:1763, in TorchBackend._to_numpy(self, a)
   1761 if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
   1762     return np.array(a)
-> 1763 return a.cpu().detach().numpy()

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

Expected behavior

POT distance functions should be batchable via torch.vmap; the Sinkhorn distance code seems to have this problem too.

@rflamary
Collaborator

The exact ot.emd2 solver uses a compiled C++ solver, so everything needs to be done on the CPU and converted to NumPy; this is why it cannot be used with vmap, which requires pure PyTorch operations.
We might be able to make Sinkhorn compatible in the future, but emd2 cannot be (it is also highly non-vectorizable, so even if this were possible there would be no gain from batching).
