Commit 575e907
Some improvements. Change narchetypes signature
aleixalcacer committed Nov 21, 2023
1 parent c5c609f commit 575e907
Showing 4 changed files with 194 additions and 36 deletions.
3 changes: 2 additions & 1 deletion archetypes/algorithms/torch/__init__.py
@@ -1,5 +1,6 @@
from .archetypes import AA
from .biarchetypes import BiAA
from .narchetypes import NAA
from .narchetypes2 import NAA2

__all__ = ["AA", "BiAA", "NAA"]
__all__ = ["AA", "BiAA", "NAA", "NAA2"]
141 changes: 115 additions & 26 deletions archetypes/algorithms/torch/narchetypes.py
@@ -1,9 +1,13 @@
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from tqdm.auto import tqdm

from .utils import einsum
from .utils import einsum, einsum_dc, loss_fun, softmax


def map_params(params, relations):
return [params[r] for r in relations]


class NAA(nn.Module):
@@ -22,27 +26,81 @@ class NAA(nn.Module):
The device to use for training the model. Defaults to "cpu".
"""

def __init__(self, k, s, device="cpu"):
def __init__(
self,
n_archetypes,
shape,
relations=None,
degree_correction=False,
membership="soft",
loss="normal",
device="cpu",
):
super().__init__()

# Check that k and s have the same length
assert len(k) == len(s), "k and s must have the same length"
if len(n_archetypes) > len(shape):
raise ValueError(
"The number of archetype dimensions must be less than or equal to the number of data dimensions."
)

self.k = n_archetypes
self.s = shape
# number of free dimensions (modes of the data without archetypes)
self.n_free_dim = len(self.s) - len(self.k)

self.k = k
self.s = s
self.n = len(k)
self.n = len(n_archetypes)
self.device = device

self._A = [
torch.nn.Parameter(torch.randn(s_i, k_i, device=self.device), requires_grad=True)
for s_i, k_i in zip(self.s, self.k)
if membership not in ["soft"]:
raise ValueError("membership must be 'soft'")
self.membership = membership

if loss not in ["normal", "bernoulli", "poisson"]:
raise ValueError("loss must be one of 'normal', 'bernoulli', 'poisson'")
self.loss = loss

# relations
if relations is None:
relations = list(np.arange(self.n))
self.relations = relations

relations_s = dict(zip(self.relations, self.s))
relations_k = dict(zip(self.relations, self.k))

# unique sorted relations
relations_unique = sorted(set(relations))

self._DC = None
if degree_correction:
self._DC_params = [
torch.nn.Parameter(
torch.randn(relations_s[r], device=self.device), requires_grad=True
)
for r in relations_unique
]
self._DC = map_params(self._DC_params, relations)

# data-membership matrices
self._A_params = [
torch.nn.Parameter(
torch.randn(relations_s[r], relations_k[r], device=self.device), requires_grad=True
)
for r in relations_unique
]

self._B = [
torch.nn.Parameter(torch.randn(k_i, s_i, device=self.device), requires_grad=True)
for s_i, k_i in zip(self.s, self.k)
self._A = map_params(self._A_params, relations)

# archetype-membership matrices
self._B_params = [
torch.nn.Parameter(
torch.randn(relations_k[r], relations_s[r], device=self.device), requires_grad=True
)
for r in relations_unique
]

self._B = map_params(self._B_params, relations)

# archetypes
self._Z = None

self.losses = []
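
For orientation, here is a minimal sketch (plain Python, hypothetical labels) of what the new relations argument does: dimensions that share a relation label reuse the same parameter set, and map_params expands the per-label parameters back to one entry per dimension. The labels are assumed to be contiguous integers starting at 0, since map_params indexes the unique-parameter list by label.

```python
def map_params(params, relations):
    return [params[r] for r in relations]

# Hypothetical 3-way setting: dimensions 0 and 1 share label 0,
# so they reuse the same parameter set.
relations = [0, 0, 1]
params = ["P0", "P1"]  # stand-ins for the per-unique-label nn.Parameters

print(map_params(params, relations))  # ['P0', 'P0', 'P1']
```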
@@ -55,14 +113,14 @@ def _loss(self):
X1 = self._X
Z = einsum(self.B, X1)
X2 = einsum(self.A, Z)
if self._DC:
X2 = einsum_dc(self.DC, X2)

loss = torch.pow(X1 - X2, 2).sum()

return loss
return loss_fun(X1, X2, self.loss).sum()

def train(self, data, n_epochs, learning_rate=0.01):
def fit(self, data, n_epochs, learning_rate=0.01):
"""
Train the model.
Fit the model.
Parameters
----------
@@ -80,7 +138,11 @@ def train(self, data, n_epochs, learning_rate=0.01):
self._X = data.to(self.device)
self._Z = einsum(self.B, self._X)

optimizer = torch.optim.Adam(params=[*self._A, *self._B], lr=learning_rate)
params = [*self._A_params, *self._B_params]
if self._DC:
params += self._DC_params

optimizer = torch.optim.Adam(params=params, lr=learning_rate)

pbar_epoch = tqdm(range(n_epochs), leave=True)

@@ -90,11 +152,24 @@
loss.backward()
optimizer.step()

self.losses.append(loss.item())
pbar_epoch.set_description(f"Epoch {epoch}/{n_epochs} | loss {loss.item():.4f}")
loss_item = loss.item()
self.losses.append(loss_item)
pbar_epoch.set_description(f"Epoch {epoch}/{n_epochs} | loss {loss_item:.4f}")

self._Z = einsum(self.B, self._X)
plt.close()

@property
def DC(self):
"""
The degree correction matrices.
Returns
-------
list of torch.Tensor
"""
if self._DC is None:
return None
return [torch.sigmoid(dc_i) for dc_i in self._DC]

@property
def A(self):
@@ -105,7 +180,7 @@ def A(self):
-------
list of torch.Tensor
"""
return [torch.softmax(a_i, dim=1) for a_i in self._A]
return [softmax(a_i, dim=1) for a_i in self._A]

@property
def B(self):
@@ -121,10 +196,24 @@
@property
def Z(self):
"""
The archetypes matrix.
The archetype matrix.
Returns
-------
torch.Tensor
"""
return self._Z

@property
def estimated_data(self):
"""
The estimated data matrix.
Returns
-------
torch.Tensor
"""
data = einsum(self.A, self.Z)
if self._DC:
data = einsum_dc(self.DC, data)
return data
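
Putting the new interface together, a usage sketch with random toy data and hypothetical sizes (fit replaces the old train method):

```python
import torch
from archetypes.algorithms.torch import NAA

X = torch.rand(100, 80)  # toy data in [0, 1], suitable for the bernoulli loss

model = NAA(
    n_archetypes=[4, 3],
    shape=X.shape,
    relations=None,          # default: an independent parameter set per dimension
    degree_correction=True,  # adds sigmoid-activated per-element scalings
    loss="bernoulli",
)
model.fit(X, n_epochs=200, learning_rate=0.01)

Z = model.Z                    # 4 x 3 archetype matrix
X_hat = model.estimated_data   # reconstruction, with degree correction applied
```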
73 changes: 66 additions & 7 deletions archetypes/algorithms/torch/utils.py
@@ -1,14 +1,56 @@
import torch
from opt_einsum import contract

# from torch.nn.functional import gumbel_softmax

def einsum(param_tensors, tensor):

def einsum(param_tensors, tensor: torch.Tensor):
n = len(param_tensors)
letters = [chr(i) for i in range(97, 97 + 2 * n)]
inner_symbols = letters[:n]
outer_symbols = letters[-n:]
equation = [f"{o}{i}," for o, i in zip(outer_symbols, inner_symbols)]
diff = tensor.ndim - n

letters = [chr(i) for i in range(97, 97 + 2 * n + diff)]
inner_symbols = letters[:n] + letters[2 * n :]
outer_symbols = letters[n : 2 * n] + letters[2 * n :]

equation = [f"{o}{i}," for o, i in zip(outer_symbols[:n], inner_symbols[:n])]
equation = "".join(equation) + "".join(inner_symbols) + "->" + "".join(outer_symbols)
return torch.einsum(equation, *param_tensors, tensor)

return contract(equation, *param_tensors, tensor)
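
To make the free-dimension handling concrete, this is the equation einsum builds for two parameter matrices and a 3-way tensor (n = 2, diff = 1), with assumed shapes:

```python
import torch
from opt_einsum import contract

# The mixing matrices act on the first two modes; the trailing mode is
# "free" and passes through unchanged.
A0 = torch.randn(5, 10)
A1 = torch.randn(4, 8)
T = torch.randn(10, 8, 3)

# equation built internally: "ca,db,abe->cde"
out = contract("ca,db,abe->cde", A0, A1, T)
print(out.shape)  # torch.Size([5, 4, 3])
```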


def partial_einsum(param_tensors, tensor: torch.Tensor, index: list):
n = len(param_tensors) + len(index)
diff = tensor.ndim - n

letters = [chr(i) for i in range(97, 97 + 2 * n + diff)]
inner_symbols = letters[:n] + letters[2 * n :]
outer_symbols = letters[n : 2 * n] + letters[2 * n :]
res_equation = [
f"{o}" if ind not in index else f"{i}"
for ind, (o, i) in enumerate(zip(outer_symbols[:n], inner_symbols[:n]))
] + letters[2 * n :]

equation = [
f"{o}{i},"
for ind, (o, i) in enumerate(zip(outer_symbols[:n], inner_symbols[:n]))
if ind not in index
]
equation = "".join(equation) + "".join(inner_symbols) + "->" + "".join(res_equation)

return contract(equation, *param_tensors, tensor, optimize="auto")
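
Similarly, a reading of partial_einsum for one skipped mode: with index=[1], dimension 1 keeps its inner symbol on both sides of the equation, so only dimension 0 is contracted (shapes are assumptions):

```python
import torch
from opt_einsum import contract

A0 = torch.randn(5, 10)
T = torch.randn(10, 8)

# equation partial_einsum([A0], T, index=[1]) would build: "ca,ab->cb"
out = contract("ca,ab->cb", A0, T)
print(out.shape)  # torch.Size([5, 8]); mode 1 is untouched
```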


def einsum_dc(param_tensors, tensor):
n = len(param_tensors)
diff = tensor.ndim - n

letters = [chr(i) for i in range(97, 97 + 2 * n + diff)]
inner_symbols = letters[:n] + letters[2 * n :]

equation = [f"{i}," for i in inner_symbols[:n]]
equation = "".join(equation) + "".join(inner_symbols) + "->" + "".join(inner_symbols)

return contract(equation, *param_tensors, tensor)
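
einsum_dc applies one vector per corrected mode as an elementwise rescaling; for a matrix this is a Hadamard product with the outer product of the two degree vectors, as this small check (toy shapes) illustrates:

```python
import torch
from opt_einsum import contract

d0, d1 = torch.rand(10), torch.rand(8)
T = torch.randn(10, 8)

# equation built internally: "a,b,ab->ab"
out = contract("a,b,ab->ab", d0, d1, T)
assert torch.allclose(out, d0[:, None] * d1[None, :] * T)
```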


def hardmax(tensor, dim):
@@ -20,7 +62,9 @@ def hardmax(tensor, dim):


def softmax(tensor, dim):
return torch.softmax(tensor, dim=dim)
# return gumbel_softmax(tensor, tau=2, dim=dim, hard=False)
y_soft = torch.softmax(tensor, dim)
return y_soft


def normal_distance(X, maxoids):
@@ -54,3 +98,18 @@ def update_clusters_A(X, clusters, centroids, likelihood="normal"):

def update_clusters_D(X, clusters, maxoids, likelihood="normal"):
return update_clusters_A(X.T, clusters.T, maxoids.T, likelihood).T


def loss_fun(X1, X2, loss="normal"):
if loss == "normal":
loss_i = torch.pow(X1 - X2, 2)
elif loss == "bernoulli":
# clamp the reconstruction away from 0 and 1 so the logs stay finite
e = 1e-8
X2 = X2.clamp(e, 1 - e)
loss_i = -(X1 * X2.log() + (1 - X1) * (1 - X2).log())
elif loss == "poisson":
loss_i = -(X1 * torch.log(X2) - X2)
else:
raise ValueError("loss must be one of 'normal', 'bernoulli', 'poisson'")
return loss_i
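
A quick numerical check of the three branches (toy values; NAA._loss sums the elementwise result, and the import path is assumed from this repository's layout):

```python
import torch
from archetypes.algorithms.torch.utils import loss_fun

X1 = torch.tensor([0.0, 1.0, 0.5])  # observed
X2 = torch.tensor([0.1, 0.9, 0.5])  # reconstructed

for name in ("normal", "bernoulli", "poisson"):
    print(name, loss_fun(X1, X2, loss=name).sum().item())
```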
13 changes: 11 additions & 2 deletions archetypes/datasets/permutations.py
@@ -83,8 +83,17 @@ def sort_by_archetype_similarity(data, alphas, archetypes):
"""

# reorder data and archetypes by the number of elements in each 'archetypal group'
perms = [np.argsort(-np.unique(np.argmax(a, axis=1), return_counts=True)[1]) for a in alphas]
alphas = [a[:, perms_i] for a, perms_i in zip(alphas, perms)]

perms = []
for i, alpha_i in enumerate(alphas):
values, counts = np.unique(np.argmax(alpha_i, axis=1), return_counts=True)
# sort values by counts
perms_i = values[np.argsort(-counts)]
# add missing indexes to perms
perms_i = np.concatenate([perms_i, np.setdiff1d(np.arange(alpha_i.shape[1]), perms_i)])
perms.append(perms_i)

alphas = [a_i[:, perms_i] for a_i, perms_i in zip(alphas, perms)]

archetypes, _ = permute_dataset(archetypes, perms)

