New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Benedict: Visualization of manifolds of categorical distributions #1813
base: main
Are you sure you want to change the base?
Conversation
…l_distributions.py
…_distributions.py
…ricalDistributionsManifold.py
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes to be made before we can merge:
- no CamelCase in python file names + use simpler file name: CategoricalDistributionsManifold.py --> categorical.py (and likewise for the notebook)
- test_visualization_manifold_of_categorical_distributions.py --> test_visualization_categorical.py
- the folder
image
containing the images needs to be removed, the images contained within it need to go to the existingnotebooks/figures/
, and the text in the notebook should be adapted accordingly. - add test functions for functions that are in categorical.py but not tests in the test file.
- Copy-paste the "setup" section of this example so that the notebook runs.
@@ -0,0 +1,484 @@ | |||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing docstring at the top of file
import numpy as np | ||
from geomstats.information_geometry.categorical import ( | ||
CategoricalDistributions, CategoricalMetric) | ||
from matplotlib import pyplot as plt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection | ||
|
||
|
||
class CategoricalDistributionsManifold: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename it CategoricalSimplex
""" | ||
|
||
def __init__(self, dim): | ||
"""Construct a CategoricalDistributionsManifold object. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change docstring following my comment on renaming to CategoricalSimplex
def plot(self): | ||
"""Plot the 2D or 3D Manifold. | ||
|
||
Plot the 2D Manifold as a regular 2-simplex(triangle) or |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: add space before (
"""Plot the 2D or 3D Manifold. | ||
|
||
Plot the 2D Manifold as a regular 2-simplex(triangle) or | ||
the 3D Manifold as a regular 3-simplex(tetrahedral). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto
|
||
Construct a CategoricalDistributionsManifold with a given dimension. | ||
|
||
Parameters: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove : and corresponding underline
operation="Exp", | ||
) | ||
|
||
def plot_helper(self, end_point, base_point, tangent_vec, operation): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rename to _plot_two_points_and_a_vector
Manifold" | ||
) | ||
if self.dim == 3: | ||
# Plot in Matplotlib |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rm unnecessary comment
@@ -0,0 +1,48 @@ | |||
"""Unit tests for visualization.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for visualization of the manifold of categorical distributions
matplotlib.use("Agg") # NOQA | ||
|
||
|
||
class TestVisualizationManifoldOfCategoricalDistributions( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename to TestVisualizationCategorical
self.CD2.scatter(self.n_samples) | ||
self.CD3.scatter(self.n_samples) | ||
|
||
def test_plot_geodesic(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing test_plot_log, test_plot_exp, test__plot_two_points_and_a_vector
Checklist
Description
Issue
Additional context