
Bug: Vectorization problem in SRVShapeBundle #1808

Open
ninamiolane opened this issue Jan 30, 2023 · 2 comments

ninamiolane (Collaborator) commented Jan 30, 2023

Describe the bug

SRVShapeBundle.align raises a TypeError: align calls horizontal_path(1.0) with a scalar time, but horizontal_path evaluates n_times = len(t), which only works when t is an array of times.

Steps/Code to Reproduce

from geomstats.geometry.discrete_curves import R2, DiscreteCurves, SRVShapeBundle

# Two random planar curves, each sampled at 10 points
dc = DiscreteCurves(k_sampling_points=10, ambient_manifold=R2)
curve1 = dc.random_point()
curve2 = dc.random_point()
bundle = SRVShapeBundle(k_sampling_points=10, ambient_manifold=R2)
# Raises TypeError
bundle.align(curve1, curve2)

Expected Behaviour

No error is thrown.

Actual Behaviour

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Volumes/GoogleDrive/My Drive/code/geomstats/geomstats/geometry/discrete_curves.py", line 2133, in align
    return horizontal_path(1.0)
  File "/Volumes/GoogleDrive/My Drive/code/geomstats/geomstats/geometry/discrete_curves.py", line 2070, in horizontal_path
    n_times = len(t)
TypeError: object of type 'float' has no len()
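The traceback shows the usual vectorization pitfall: a function written for an array of times receives a single float. A common pattern for accepting both is to promote t with numpy.atleast_1d and return a single result when the input was scalar. The sketch below is hypothetical: toy_path stands in for horizontal_path, uses plain linear interpolation instead of the SRV horizontal path, and is not the fix merged in PR #1823.

```python
import numpy as np


def toy_path(t, start, end):
    """Hypothetical stand-in for horizontal_path (not the geomstats code).

    Linearly interpolates between two discrete curves of shape
    (k_sampling_points, dim). Accepts a scalar time or a 1-D array of times.
    """
    t_arr = np.atleast_1d(np.asarray(t, dtype=float))  # scalar -> shape (1,)
    n_times = len(t_arr)  # works for scalar input too, unlike len(t)
    weights = t_arr.reshape(n_times, 1, 1)  # broadcast against (k, dim)
    path = (1.0 - weights) * start + weights * end  # (n_times, k, dim)
    # Return a single curve when a single (scalar) time was requested
    return path[0] if np.ndim(t) == 0 else path


start = np.zeros((10, 2))
end = np.ones((10, 2))
print(toy_path(1.0, start, end).shape)  # (10, 2)
print(toy_path(np.linspace(0.0, 1.0, 5), start, end).shape)  # (5, 10, 2)
```

Returning one curve for scalar input means a call shaped like horizontal_path(1.0) would yield a single curve, while array input still yields the whole path.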

Your environment

GitHub master version of geomstats (numpy backend).
Python 3.8.10, macOS.
alebrigant (Collaborator) commented Jan 31, 2023

I have fixed the bug in the align method of SRVShapeBundle in PR #1823. However, the way random points are generated on the space of discrete curves should also be changed: it currently generates completely erratic curves, on which the alignment procedure cannot be expected to converge.
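The comment above suggests generating smoother random curves. One possible approach, sketched here as an assumption rather than what geomstats does or what PR #1823 changes, is to build each coordinate from a few random Fourier modes with decaying amplitudes, so that consecutive sampling points stay close. smooth_random_curve is a hypothetical helper; it assumes planar curves (ambient manifold R2) sampled uniformly on [0, 1].

```python
import numpy as np


def smooth_random_curve(k_sampling_points, n_modes=3, rng=None):
    """Hypothetical helper: sample a smooth random planar curve.

    Each coordinate is a low-order random Fourier series evaluated at
    k_sampling_points equally spaced parameters in [0, 1], so neighbouring
    sampling points stay close, unlike independent random draws.
    """
    rng = np.random.default_rng(rng)  # accepts None, a seed, or a Generator
    s = np.linspace(0.0, 1.0, k_sampling_points)
    modes = np.arange(1, n_modes + 1)
    # Coefficients indexed (coordinate, sin/cos, mode); 1/m damps high modes
    coefs = rng.normal(size=(2, 2, n_modes)) / modes
    angles = np.pi * np.outer(s, modes)  # (k_sampling_points, n_modes)
    curve = np.sin(angles) @ coefs[:, 0].T + np.cos(angles) @ coefs[:, 1].T
    return curve  # shape (k_sampling_points, 2)


curve1 = smooth_random_curve(10, rng=0)
curve2 = smooth_random_curve(10, rng=1)
```

With the numpy backend these arrays have the same (10, 2) shape as the curves in the reproduction, so they could presumably stand in for dc.random_point() there.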

alebrigant (Collaborator) commented:

Working example:

import matplotlib.pyplot as plt
import geomstats.backend as gs
from geomstats.geometry.discrete_curves import R2, SRVShapeBundle


# Half circle of radius 2 centred at (1, 3), traversed from (1, 5) to (1, 1)
parametrized_curve_1 = lambda x: gs.transpose(
    gs.array([1 + 2 * gs.sin(gs.pi * x), 3 + 2 * gs.cos(gs.pi * x)])
)
# Vertical segment at x = 5, traversed from (5, 5) down to (5, 1)
parametrized_curve_2 = lambda x: gs.transpose(
    gs.array([5 * gs.ones(len(x)), 4 * (1 - x) + 1])
)
# Sample both curves at 10 equally spaced parameter values in [0, 1]
sampling_points = gs.linspace(0., 1., 10)
curve1 = parametrized_curve_1(sampling_points)  # shape (10, 2)
curve2 = parametrized_curve_2(sampling_points)  # shape (10, 2)

bundle = SRVShapeBundle(k_sampling_points=10, ambient_manifold=R2)
# Align curve2 with respect to curve1
curve2_aligned = bundle.align(curve2, curve1)


# Plot the reference curve and the aligned curve
plt.figure()
plt.plot(curve1[:, 0], curve1[:, 1], 'o-')
plt.plot(curve2_aligned[:, 0], curve2_aligned[:, 1], 'o-')
plt.show()
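To compare the curves numerically rather than only visually, one could evaluate a discrete square root velocity (SRV) distance before and after alignment. The sketch below is an approximation under stated assumptions (forward finite differences on a uniform grid in [0, 1]; srv_transform and srv_distance are hypothetical helpers). It is not geomstats's implementation, and its values need not match the metric SRVShapeBundle uses internally.

```python
import numpy as np


def srv_transform(curve):
    """Hypothetical helper: discrete SRV representation q = c' / sqrt(|c'|).

    curve has shape (k_sampling_points, dim), sampled uniformly on [0, 1].
    Returns one vector per segment, shape (k_sampling_points - 1, dim).
    """
    k = curve.shape[0]
    velocity = np.diff(curve, axis=0) * (k - 1)  # forward differences, step 1/(k-1)
    speed = np.linalg.norm(velocity, axis=1, keepdims=True)
    # Guard against zero-speed segments (repeated sampling points)
    return velocity / np.sqrt(np.maximum(speed, 1e-12))


def srv_distance(curve_a, curve_b):
    """Approximate L2 distance on [0, 1] between the SRV representations."""
    diff = srv_transform(curve_a) - srv_transform(curve_b)
    return np.sqrt(np.mean(np.sum(diff**2, axis=1)))
```

With the numpy backend, printing srv_distance(curve1, curve2) next to srv_distance(curve1, curve2_aligned) gives a rough indication of how much the alignment changed curve2. Because the SRV representation only uses velocities, this toy distance ignores translations, which is why a circle and a segment placed apart can still be compared.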
