Skip to content

Commit

Permalink
Improve sorting algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
aleixalcacer committed Oct 24, 2023
1 parent ba49788 commit b90d636
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions archetypes/datasets/permutations.py
Expand Up @@ -49,8 +49,8 @@ def shuffle_dataset(data, generator=None):
-------
data: array-like
The shuffled dataset.
perms: list of array-like
The permutations used to shuffle the dataset.
info: dict
The information about the shuffling.
"""

generator = check_generator(generator)
Expand All @@ -63,7 +63,7 @@ def shuffle_dataset(data, generator=None):
return data, info


def sort_by_archetype_similarity(data, alphas):
def sort_by_archetype_similarity(data, alphas, archetypes):
"""Sort a dataset using the archetypal spaces previously computed.
Parameters
Expand All @@ -72,15 +72,22 @@ def sort_by_archetype_similarity(data, alphas):
The dataset to sort.
alphas: list of array-like
The dataset in the archetypal spaces.
archetypes: list of array-like
The archetypes.
Returns
-------
data: array-like
The sorted dataset.
perms: list of array-like
The permutations used to sort the dataset.
info: dict
The information about the sorting.
"""

# reorder data and archetypes by the number of elements in each 'archetypal group'
perms = [np.argsort(-np.unique(np.argmax(a, axis=1), return_counts=True)[1]) for a in alphas]
alphas = [a[:, perms_i] for a, perms_i in zip(alphas, perms)]

archetypes, _ = permute_dataset(archetypes, perms)

values_to_sort = [(-np.max(a, axis=1), np.argmax(a, axis=1)) for a in alphas]
# get index of ordered values
perms = [np.lexsort(values_to_sort_i) for values_to_sort_i in values_to_sort]
Expand All @@ -95,6 +102,8 @@ def sort_by_archetype_similarity(data, alphas):
info["labels"] = labels
info["scores"] = scores
info["n_archetypes"] = [ai.shape[1] for ai in alphas]
info["alphas"] = alphas
info["archetypes"] = archetypes

return data, info

Expand All @@ -113,8 +122,8 @@ def sort_by_labels(data, labels):
-------
data: array-like
The sorted dataset.
perms: list of array-like
The permutations used to sort the dataset.
info: dict
The information about the sorting.
"""

perms = [np.lexsort([labels_i]) for labels_i in labels]
Expand Down

0 comments on commit b90d636

Please sign in to comment.