Skip to content

Commit

Permalink
Fix some style issues
Browse files Browse the repository at this point in the history
  • Loading branch information
aleixalcacer committed Oct 16, 2023
1 parent 93ceac3 commit 13367d8
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 31 deletions.
13 changes: 10 additions & 3 deletions archetypes/datasets/make_archetypal_dataset.py
Expand Up @@ -57,10 +57,17 @@ def make_archetypal_dataset(
A = [np.zeros((s_i, a_i)) for s_i, a_i in zip(shape, n_archetypes)]

for A_i, labels_i in zip(A, labels):
l_i_prev = -1
for i, l_i in enumerate(labels_i):
alpha_i = [alpha] * A_i.shape[1]
alpha_i[l_i] = 1
A_i[i, :] = generator.dirichlet(alpha_i)
if l_i_prev != l_i:
alpha_i = [0] * A_i.shape[1]
alpha_i[l_i] = 1
A_i[i, :] = alpha_i
l_i_prev = l_i
else:
alpha_i = [alpha] * A_i.shape[1]
alpha_i[l_i] = 1
A_i[i, :] = generator.dirichlet(alpha_i)

X = einsum(A, archetypes)

Expand Down
4 changes: 4 additions & 0 deletions archetypes/datasets/permutations.py
Expand Up @@ -88,9 +88,13 @@ def sort_by_archetype_similarity(data, alphas):
data, info = permute_dataset(data, perms)

labels = [np.argmax(a, axis=1) for a in alphas]
scores = [np.max(a, axis=1) for a in alphas]
labels = [labels[i][perms[i]] for i in range(data.ndim)]
scores = [scores[i][perms[i]] for i in range(data.ndim)]

info["labels"] = labels
info["scores"] = scores
info["n_archetypes"] = [ai.shape[1] for ai in alphas]

return data, info

Expand Down
19 changes: 4 additions & 15 deletions archetypes/visualization/bisimplex.py
Expand Up @@ -2,19 +2,8 @@
import matplotlib.pyplot as plt
import numpy as np

from archetypes.visualization import simplex


def _create_palette(saturation, value, n_colors, int_colors=3):
    """Build a discrete colormap from HSV colors with interleaved hues.

    Parameters
    ----------
    saturation: float
        The HSV saturation shared by every color.
    value: float
        The HSV value (brightness) shared by every color.
    n_colors: int
        The number of colors in the palette.
    int_colors: int
        The interleaving step. Hues are emitted as ``int_colors`` strided
        groups (offsets ``0, 1, ..., int_colors - 1``), so ``1`` keeps the
        hues in increasing order.

    Returns
    -------
    matplotlib.colors.ListedColormap
        The colormap holding the ``n_colors`` RGB colors.
    """
    # Hues evenly spaced on [0, 1); the endpoint is excluded.
    evenly_spaced = np.linspace(0, 1, n_colors, endpoint=False)

    # Collect every ``int_colors``-th hue for each starting offset, then
    # lay the groups end to end.
    hue_groups = []
    for start in range(int_colors):
        hue_groups.append(evenly_spaced[start::int_colors])
    hue_column = np.hstack(hue_groups)

    saturation_column = np.full(n_colors, saturation)
    value_column = np.full(n_colors, value)

    # One (h, s, v) row per color, converted to RGB.
    hsv_rows = np.vstack([hue_column, saturation_column, value_column]).T
    rgb_rows = mpl.colors.hsv_to_rgb(hsv_rows)

    return mpl.colors.ListedColormap(rgb_rows)
from .simplex import simplex
from .utils import create_palette


def bisimplex(alphas, archetypes, ax=None, **kwargs):
Expand Down Expand Up @@ -44,7 +33,7 @@ def bisimplex(alphas, archetypes, ax=None, **kwargs):
n_archetypes = archetypes.shape

# Get the colors for the vertices of the polytopes
palette = _create_palette(
palette = create_palette(
saturation=0.35, value=0.9, n_colors=n_archetypes[0] + n_archetypes[1], int_colors=1
)

Expand Down Expand Up @@ -85,7 +74,7 @@ def bisimplex(alphas, archetypes, ax=None, **kwargs):

# Use grayscale palette from matplotlib between 0 and 1
# archetypes[:] = .6
palette = mpl.colormaps["Grays"]
palette = mpl.colormaps["Greys"]

for x_i, y_i, c_i, a_i in zip(
xx.flatten(), yy.flatten(), archetypes.flatten(), archetypes_scaled.flatten()
Expand Down
99 changes: 87 additions & 12 deletions archetypes/visualization/heatmap.py
@@ -1,17 +1,26 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap, to_rgb
from matplotlib.patches import Polygon

from .utils import create_palette

def heatmap(data, labels=None, ax=None, **kwargs):

def heatmap(data, labels=None, n_archetypes=None, scores=None, ax=None, **kwargs):
"""Plot a heatmap of the data. If labels are provided, the heatmap is divided into cells.
Parameters
----------
data: np.ndarray
The data to plot.
labels: list of np.ndarray or None
The labels to use to divide the heatmap into cells. If None, no labels are used.
The labels values to use for the plot.
If None, the labels values are computed from the labels.
n_archetypes: list of int or None
The number of archetypes for each dimension.
If None, the number of archetypes is computed from the labels.
scores: list of np.ndarray or None
The scores values to use for the plot.
ax: matplotlib.pyplot.axes or None
The axes to plot on. If None, a new figure and axes is created.
kwargs: dict
Expand All @@ -33,6 +42,8 @@ def heatmap(data, labels=None, ax=None, **kwargs):
if "cmap" not in kwargs:
kwargs["cmap"] = "Greys"

data_size = max(data.shape)

# Plot line if labels[i] != labels[i+1]
if labels is not None:
# check labels is a list of 2 arrays
Expand All @@ -47,22 +58,87 @@ def heatmap(data, labels=None, ax=None, **kwargs):
f"labels must be a list of 2 arrays, got {type(labels[0])} and {type(labels[1])}"
)

labels_h = np.concatenate([[-1], labels[1].flatten(), [-1]])
labels_v = np.concatenate([[-1], labels[0].flatten(), [-1]])
labels_h = np.concatenate([labels[1].flatten()])
labels_v = np.concatenate([labels[0].flatten()])

polygon_kwargs = {"color": "r", "lw": 1}
polygon_kwargs = {"color": "k", "lw": 1}

for i in range(len(labels_h) - 1):
if labels_h[i] != labels_h[i + 1]:
line = Polygon(np.array([[i, 0], [i, data.shape[0]]]) - 0.5, **polygon_kwargs)
line = Polygon(
np.array([[i + 1, 0], [i + 1, data.shape[0]]]) - 0.5, **polygon_kwargs
)
ax.add_patch(line)

for i in range(len(labels_v) - 1):
if labels_v[i] != labels_v[i + 1]:
line = Polygon(np.array([[0, i], [data.shape[1], i]]) - 0.5, **polygon_kwargs)
line = Polygon(
np.array([[0, i + 1], [data.shape[1], i + 1]]) - 0.5, **polygon_kwargs
)
ax.add_patch(line)

ax.matshow(data, rasterized=True, **kwargs)
if n_archetypes is None:
n_archetypes = [len(np.unique(labels[0])), len(np.unique(labels[1]))]

# Add a rectangle to frame the data
rect = Polygon(
np.array(
[[0, 0], [data.shape[1], 0], [data.shape[1], data.shape[0]], [0, data.shape[0]]]
)
- 0.5,
fill=False,
**polygon_kwargs,
)
ax.add_patch(rect)

palette = create_palette(
saturation=0.35, value=0.9, n_colors=n_archetypes[0] + n_archetypes[1], int_colors=1
)
colors = palette(np.arange(n_archetypes[0] + n_archetypes[1]))
colors_1 = colors[: n_archetypes[0]]
colors_2 = colors[n_archetypes[0] : n_archetypes[0] + n_archetypes[1]]

# Plot archetypes
counts = [np.count_nonzero(labels[0] == i) for i in range(n_archetypes[0])]
counts = np.cumsum(counts)
counts = np.concatenate([[0], counts]) - 0.5

arch_factor = 0.05 * data_size

if scores is None:
scores = [np.ones_like(labels[0]), np.ones_like(labels[1])]

for c, (i0, i1) in zip(colors_1, zip(counts, counts[1:])):
c1 = np.array(to_rgb(c))
c2 = np.array([1, 1, 1])

ax.imshow(
scores[0][int(i0 + 0.5) : int(i1 + 0.5)][::-1].reshape(-1, 1),
extent=[-arch_factor, -2 * arch_factor, i0, i1],
cmap=LinearSegmentedColormap.from_list("c", [c2, c1]),
interpolation="none",
vmax=1,
vmin=0,
)

counts = [np.count_nonzero(labels[1] == i) for i in range(n_archetypes[1])]
counts = np.cumsum(counts)
counts = np.concatenate([[0], counts]) - 0.5

for c, (i0, i1) in zip(colors_2, zip(counts, counts[1:])):
c1 = np.array(to_rgb(c))
c2 = np.array([1, 1, 1])

ax.imshow(
scores[1][int(i0 + 0.5) : int(i1 + 0.5)].reshape(1, -1),
extent=[i0, i1, -arch_factor, -2 * arch_factor],
cmap=LinearSegmentedColormap.from_list("c", [c2, c1]),
interpolation="none",
vmax=1,
vmin=0,
)

ax.matshow(data, interpolation="none", **kwargs)

# set aspect ratio to equal
ax.set_aspect("equal")
Expand All @@ -75,9 +151,8 @@ def heatmap(data, labels=None, ax=None, **kwargs):
xlim = ax.get_xlim()
ylim = ax.get_ylim()

exp_factor = 0.01 * max(np.abs(np.diff(xlim)), np.abs(np.diff(ylim)))

ax.set_xlim(xlim[0] - exp_factor, xlim[1] + exp_factor)
ax.set_ylim(ylim[0] + exp_factor, ylim[1] - exp_factor)
lim_factor = 0.01 * data_size
ax.set_xlim(xlim[0] - lim_factor, xlim[1] + lim_factor)
ax.set_ylim(ylim[0] + lim_factor, ylim[1] - lim_factor)

return ax
2 changes: 1 addition & 1 deletion archetypes/visualization/simplex.py
Expand Up @@ -88,7 +88,7 @@ def simplex(
for p1, p2 in edges:
x1, y1 = p1
x2, y2 = p2
ax.plot([x1, x2], [y1, y2], "-", linewidth=0.75, color="lightgray", zorder=0)
ax.plot([x1, x2], [y1, y2], "-", linewidth=1, color="lightgray", zorder=0)

# ax.plot(vertices[:, 0], vertices[:, 1], "o", color="black", alpha=1)
if show_vertices:
Expand Down
14 changes: 14 additions & 0 deletions archetypes/visualization/utils.py
@@ -0,0 +1,14 @@
import matplotlib as mpl
import numpy as np


def create_palette(saturation, value, n_colors, int_colors=3):
    """Create a discrete palette of HSV-derived colors.

    The hues are evenly spaced on [0, 1) and then reordered by taking every
    ``int_colors``-th hue starting at offsets ``0, 1, ..., int_colors - 1``.
    All colors share the same saturation and value.

    Parameters
    ----------
    saturation: float
        The HSV saturation applied to every color.
    value: float
        The HSV value (brightness) applied to every color.
    n_colors: int
        The number of colors in the palette.
    int_colors: int
        The stride used to interleave the hues. With ``1`` the hues keep
        their increasing order.

    Returns
    -------
    matplotlib.colors.ListedColormap
        The palette as a listed colormap with ``n_colors`` entries.
    """
    base_hues = np.linspace(0, 1, n_colors, endpoint=False)
    strided_hues = [base_hues[offset::int_colors] for offset in range(int_colors)]
    hues = np.hstack(strided_hues)

    # Stack the three channels as rows, then transpose to one HSV triple per row.
    hsv = np.vstack(
        [
            hues,
            np.full(n_colors, saturation),
            np.full(n_colors, value),
        ]
    ).T

    rgb = mpl.colors.hsv_to_rgb(hsv)
    return mpl.colors.ListedColormap(rgb)

0 comments on commit 13367d8

Please sign in to comment.