Removing linear search as a step in GW descent #491

Open
patrick-nicodemus opened this issue Jul 18, 2023 · 2 comments
@patrick-nicodemus

TL;DR: I would like the GW function to provide an option to skip the line search in favor of the naive update step, because in my experiments the line search is unambiguously worse.

The gradient descent algorithm currently implemented for GW works roughly as follows. At time $t$, we have an existing coupling matrix, $C_t$. We apply optimal transport to solve for the $C$ minimizing the expression $\langle AC_tB,C\rangle$ subject to the marginal constraints.

We write $\Delta C = C - C_t$; this is the descent direction for the step. One can show that the squared GW loss associated to the matrix $C_t + \alpha \Delta C$ is a quadratic function of the real parameter $\alpha$. We solve for the optimal value of $\alpha$ in closed form and then coerce it to lie in the interval $[0,1]$. If $\alpha=0$ the algorithm terminates.
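To make the quadratic explicit, here is a minimal NumPy sketch, assuming the objective is taken to be $\langle ACB, C\rangle$ as written above. The function names are hypothetical and this is not POT's implementation; it only illustrates where the coefficients come from and how the step is clipped to $[0,1]$.

```python
import numpy as np

def quad_form(A, B, C):
    """Evaluate <A C B, C> (Frobenius inner product)."""
    return float(np.sum((A @ C @ B) * C))

def line_search_coeffs(A, B, C_t, dC):
    """Coefficients (a, b, c) with <A X B, X> = a*alpha**2 + b*alpha + c
    for X = C_t + alpha * dC. Uses four matrix multiplications."""
    ACB = A @ C_t @ B   # two multiplications
    AdB = A @ dC @ B    # two more multiplications
    a = float(np.sum(AdB * dC))
    b = float(np.sum(AdB * C_t) + np.sum(ACB * dC))
    c = float(np.sum(ACB * C_t))
    return a, b, c

def best_step(a, b):
    """Minimizer of a*alpha**2 + b*alpha over the interval [0, 1]."""
    if a > 0:
        # Convex parabola: clip the vertex -b / (2a) to [0, 1].
        return float(np.clip(-b / (2.0 * a), 0.0, 1.0))
    # Concave or linear: the minimum sits at an endpoint.
    return 1.0 if a + b < 0 else 0.0
```

The two products `A @ C_t @ B` and `A @ dC @ B` account for the four matrix multiplications counted in this issue; the naive update needs only the first of them.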

This cleverness with the line search goes beyond the theory developed in Peyré et al., Gromov-Wasserstein Averaging of Kernel and Distance Matrices, ICML 2016.
I am concerned that this is a somewhat premature optimization that is not worth it. In order to do this line search, one must solve for the coefficients of the quadratic, which involves four matrix multiplications by my count. Compare this to directly computing $\langle ACB,C\rangle$, which only involves two matrix multiplications. There is therefore a tradeoff: two more matrix multiplications for a possibly better next step.

I have experimented with the behavior of this algorithm on a real-world data set, computing a number of Gromov-Wasserstein distances between point clouds equipped with the uniform distribution. In this experiment, the line search never found a solution $\alpha$ which fell strictly within the open interval $(0,1)$, so there was no tradeoff at all; it was just uniformly slower. I have rewritten the algorithm to remove the line search, and it is approximately 23% faster and gives identical results.

I suggest the devs do their own experiments and see if there is a dataset in which the line search performs better; I would be interested in seeing such a dataset, perhaps one not using the uniform distribution on points. I request that there be an option for the user to disable the line search in favor of the more naive update step.

@cedricvincentcuaz
Collaborator

cedricvincentcuaz commented Jul 18, 2023

First, hello.

The exact line-search step can be skipped using the Armijo rule, by setting `armijo=True`. Other rules may be added as features but would most likely be more costly. However, for this solver to be a proper conditional gradient method, the step size $\alpha$ has to be in $[0, 1]$; otherwise the transport plan iterates will not stay in the transport polytope, and the estimated optimal transport plan might fail to be non-negative and/or to satisfy the marginal constraints. Other solvers are based on Bregman projections, including the solver from Peyré et al. and the proximal point algorithm implemented in the latest dev version (see e.g. https://arxiv.org/pdf/2303.06595.pdf for more details).
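As a small illustration of the $[0, 1]$ constraint, the following sketch uses a hand-picked 2x2 example; the helper `in_transport_polytope` is hypothetical and not part of POT. In this particular example the marginals are preserved for any $\alpha$, and it is non-negativity that fails once $\alpha > 1$.

```python
import numpy as np

def in_transport_polytope(T, p, q, tol=1e-12):
    """True if T is non-negative (up to tol) with row sums p and column sums q."""
    return bool(
        np.all(T >= -tol)
        and np.allclose(T.sum(axis=1), p)
        and np.allclose(T.sum(axis=0), q)
    )

p = np.array([0.5, 0.5])
q = np.array([0.5, 0.5])
T_t = np.array([[0.5, 0.0], [0.0, 0.5]])    # current transport plan
T_new = np.array([[0.0, 0.5], [0.5, 0.0]])  # plan returned by the linear OT step
dT = T_new - T_t

# alpha in [0, 1]: a convex combination of two plans, so still a valid plan.
inside = in_transport_polytope(T_t + 0.5 * dT, p, q)
# alpha > 1: the diagonal entries become negative.
outside = in_transport_polytope(T_t + 1.5 * dT, p, q)
```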

@rflamary
Collaborator

Hello @patrick-nicodemus ,

If you find a better line-search (at least on some type of data), we encourage you to implement it and add it as an alternative in the GW solvers with a PR. @cedricvincentcuaz recently did a big revamp of the GW and CG solvers that makes this easier to do (basically, just give your implementation of the line-search to the solver). As @cedricvincentcuaz said, you can use the Armijo line-search or implement your own (we did not implement the traditional CG line-search in 1/sqrt(k), for instance).
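For readers who want to try different rules on toy data, here is a rough sketch of a conditional gradient loop with a pluggable step rule. It is not POT's solver interface: `conditional_gradient`, `lmo_uniform` and the step-rule signature `(a, b, k)` are all hypothetical. It assumes symmetric matrices `A` and `B`, uniform marginals on equally many points, and a brute-force linear step that is only viable for very small `n`. The objective is $-\langle ATB, T\rangle$, which for fixed marginals equals the squared GW loss up to an additive constant and a positive factor.

```python
import itertools
import numpy as np

def objective(A, B, T):
    """-<A T B, T>; squared GW loss up to a constant and a positive factor."""
    return float(-np.sum((A @ T @ B) * T))

def lmo_uniform(G):
    """Minimize <G, T> over couplings of two uniform measures on n points.
    The minimum is attained at a permutation matrix scaled by 1/n, so we
    brute-force over permutations (tiny n only)."""
    n = G.shape[0]
    best = min(itertools.permutations(range(n)),
               key=lambda s: sum(G[i, s[i]] for i in range(n)))
    T = np.zeros((n, n))
    T[np.arange(n), list(best)] = 1.0 / n
    return T

def exact_step(a, b, k):
    """Exact minimizer of a*alpha**2 + b*alpha on [0, 1]."""
    if a > 0:
        return float(np.clip(-b / (2.0 * a), 0.0, 1.0))
    return 1.0 if a + b < 0 else 0.0

def naive_step(a, b, k):
    """Always take the full step."""
    return 1.0

def decaying_step(a, b, k):
    """The 1/sqrt(k) schedule, with k the 1-based iteration count."""
    return 1.0 / np.sqrt(k)

def conditional_gradient(A, B, step_rule, max_iter=50):
    n = A.shape[0]
    T = np.full((n, n), 1.0 / n**2)          # product of the uniform marginals
    for k in range(1, max_iter + 1):
        ATB = A @ T @ B
        dT = lmo_uniform(-2.0 * ATB) - T     # gradient is -2 A T B (A, B symmetric)
        b = -2.0 * float(np.sum(ATB * dT))   # directional derivative along dT
        if b >= -1e-12:                      # no descent direction left
            break
        a = -float(np.sum((A @ dT @ B) * dT))
        alpha = step_rule(a, b, k)
        if alpha == 0.0:
            break
        T = T + alpha * dT
    return T
```

All three rules return step sizes in $[0, 1]$, so every iterate stays a valid coupling; only the exact rule is guaranteed not to increase the objective at each step.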

Still, note that in our experience the kind of claim that one line-search is better is very data-dependent, because the problem remains highly non-convex; that is also why we do a real line-search by default.
