
Sort Markov matrix #218

Open
cap-jmk opened this issue Mar 7, 2022 · 21 comments


cap-jmk commented Mar 7, 2022

Is your feature request related to a problem? Please describe.
I am doing Markov modelling for SAR/QSAR analysis of chemical compounds and need sorted Markov matrices.

I suggest sorting the Markov matrix according to the most stable state. Something like the following, ideally with better memory management:

import numpy as np

def sort_markov_matrix(markov_matrix):
    """Sort a Markov matrix so that the most stable state comes first.

    Args:
        markov_matrix (np.ndarray): unsorted transition matrix

    Returns:
        np.ndarray: transition matrix with rows and columns reordered
        by decreasing diagonal entry
    """
    b = markov_matrix.copy()
    for i in range(len(markov_matrix)):
        for j in range(i + 1, len(markov_matrix)):
            # Compare the current diagonal entries (read live, so earlier
            # swaps in this pass are taken into account).
            if markov_matrix[j, j] > markov_matrix[i, i]:
                # Swap rows i and j.
                markov_matrix[i, :] = b[j, :]
                markov_matrix[j, :] = b[i, :]
                b = markov_matrix.copy()
                # Apply the same swap to the columns to keep the matrix consistent.
                for k in range(len(markov_matrix)):
                    markov_matrix[k, i] = b[k, j]
                    markov_matrix[k, j] = b[k, i]
                b = markov_matrix.copy()
    return markov_matrix

Test with

def test_sort():
    a = np.array([[0.8, 0.1, 0.05, 0.05],
                  [0.005, 0.9, 0.03, 0.015],
                  [0.1, 0.2, 0.4, 0.3],
                  [0.01, 0.02, 0.03, 0.94]])
    sorted_a = sort_markov_matrix(a)
    assert np.array_equal(sorted_a[0, :], np.array([0.94, 0.02, 0.01, 0.03])), str(sorted_a[0, :])
    assert np.array_equal(sorted_a[1, :], np.array([0.015, 0.9, 0.005, 0.03])), str(sorted_a[1, :])
    assert np.array_equal(sorted_a[2, :], np.array([0.05, 0.1, 0.8, 0.05])), str(sorted_a[2, :])
    assert np.array_equal(sorted_a[3, :], np.array([0.3, 0.2, 0.1, 0.4])), str(sorted_a[3, :])

What do you think?


clonker commented Mar 9, 2022

Hi, I think this may be a bit too specific to implement as a default. There are different ways of understanding the stability of a Markov state; for instance, you could

  • look at the main diagonal of the transition matrix, as you did,
  • look at the probability distribution of the stationary process, or
  • think about stable groups of states (as in PCCA+).

In that sense it might be better for each user to implement their own version of such a relabeling. A more efficient variant of yours could, for instance, be implemented with permutation matrices. 🙂

so...

msm = estimate_msm(data)
msm_sorted = deeptime.markov.msm.MarkovStateModel(sort_msm(msm.transition_matrix))
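To illustrate the stationary-distribution criterion mentioned above, here is a minimal plain-NumPy sketch; the eigenvector computation stands in for the stationary distribution a deeptime model would expose, and `stationary_order` is a hypothetical helper, not deeptime API:

```python
import numpy as np

def stationary_order(P):
    """Return state indices ordered by decreasing stationary probability.

    The stationary distribution pi satisfies pi @ P = pi, i.e. it is the
    left eigenvector of P for eigenvalue 1, normalized to sum to one.
    """
    evals, evecs = np.linalg.eig(P.T)
    pi = np.real(evecs[:, np.argmax(np.real(evals))])
    pi /= pi.sum()
    return np.argsort(pi)[::-1]

# Example: state 2 carries the largest stationary weight.
P = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.6, 0.2],
              [0.1, 0.1, 0.8]])
order = stationary_order(P)          # array([2, 1, 0])
P_sorted = P[np.ix_(order, order)]   # rows and columns permuted together
```

Note that this ordering can differ from the diagonal-based one whenever a state with a modest self-transition probability is visited often.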


cap-jmk commented Mar 10, 2022

Maybe you are right; however, it felt like this belonged to the overall Markov modelling that is part of the deeptime package. I can certainly implement it on my side, but it feels strange. An implementation in deeptime would also improve readability of dependent code in general, i.e.

msm = estimate_msm(data, sorted=True)

I think the TSM gives a good estimate of the states and serves as a fingerprint for my case. PCCA+ seems like a good idea and could be helpful in some cases.

How would you do it with permutation matrices? The best we could do would be O(n), right? What would the memory consumption of permutation matrices look like? I remember some applications in solving linear systems with these matrices. Would your solution be similar?


clonker commented Mar 11, 2022

To my knowledge there is no canonical way of sorting Markov states, so I do not think it is a good idea to make this a True/False decision. What could be done is to offer a relabeling function such as

msm_sorted = msm.relabel(np.argsort(np.diag(msm.transition_matrix)), inplace=False)

in your case. I do not have the capacity to implement this right now, but I am happy to give pointers and work on pull requests with you. There are multiple layers to this, though. In particular, we will have to be very careful that this does not break other parts of the library where there are assumptions about the Markov states staying the same over the course of taking submodels (for example, when restricting yourself to the largest connected component of the jump-probability connectivity graph). Also, there are the following cases to keep in mind:

  • Markov model without statistics: this should be relatively straightforward
  • Markov model with statistics: here we have to be careful to also relabel the statistics to keep everything consistent
  • Effect on Markov state model collections (in particular MEMMs)
  • Effect on hidden Markov models
  • Implementation on sparse transition matrices / count matrices

There are probably more things to keep in mind here. In any case I think the easiest for you is to really sort the matrix on your own and create a new MSM instance.

Regarding permutation matrices: Yes, we cannot get better than O(n), but we can achieve vectorization.
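For doing exactly that on your own while keeping per-state statistics consistent, a standalone helper could look like the following sketch (`relabel` and `state_counts` are hypothetical names, not deeptime API):

```python
import numpy as np

def relabel(transition_matrix, order, state_counts=None):
    """Apply a state permutation to a dense transition matrix and,
    optionally, to a vector of per-state statistics (hypothetical helper).

    order[k] is the original index of the state that becomes state k.
    """
    order = np.asarray(order)
    # np.ix_ permutes rows and columns with the same ordering.
    P = transition_matrix[np.ix_(order, order)]
    if state_counts is not None:
        return P, np.asarray(state_counts)[order]
    return P

P = np.array([[0.7, 0.3],
              [0.1, 0.9]])
counts = np.array([10, 90])
order = np.argsort(np.diag(P))[::-1]  # most metastable state first
P2, counts2 = relabel(P, order, counts)
```

Applying the same `order` to every per-state array is what keeps statistics consistent with the permuted matrix.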


cap-jmk commented Mar 13, 2022

Yes, I know what you mean. I will do my best to support you.


clonker commented Mar 14, 2022

Cool, thanks! 🚀 I think a good first step would be reordering count matrices (in TransitionCountModel). Do you want to have a stab at that? I am still not entirely sure what such a method should be called, as it is not really a sorting but rather a relabeling, in general at least. Perhaps permute? Or reorder? At first I thought transpose might be a good fit, but that term is really used in the context of axes.


cap-jmk commented Mar 28, 2022

@clonker, yes, I could give it a try. Where do you want to change something? I would sort it in deeptime/markov/_transition_counting.py


clonker commented Mar 28, 2022

Yes, that would be a good start!


cap-jmk commented Apr 13, 2022

Nice. Okay, I got a working sorting algorithm implemented. However, I would love for you to review it before I start implementing it in deeptime. I don't know why, but I could only make it work with bubble sort on the diagonal.

import numpy as np

def sort_markov_matrix(markov_matrix):
    """Sort a Markov matrix in place by decreasing diagonal entry
    using bubble sort on the diagonal.

    Args:
        markov_matrix (np.ndarray): unsorted transition matrix

    Returns:
        np.ndarray: the reordered transition matrix
    """
    # np.diag returns a read-only view for 2D input, so it tracks the swaps below.
    diag = np.diag(markov_matrix)
    for i in range(len(diag)):
        for j in range(len(diag) - 1):
            if diag[j + 1] > diag[j]:
                # Swap rows j and j+1, then the matching columns.
                markov_matrix[[j, j + 1]] = markov_matrix[[j + 1, j]]
                markov_matrix[:, [j, j + 1]] = markov_matrix[:, [j + 1, j]]
    return markov_matrix
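For comparison, the same reordering can be done without an explicit double loop: `np.argsort` on the diagonal plus one symmetric fancy-indexing step yields the same descending order. A sketch, not deeptime code:

```python
import numpy as np

def sort_markov_matrix_vectorized(markov_matrix):
    """Reorder states by decreasing diagonal entry in one vectorized step."""
    order = np.argsort(np.diag(markov_matrix))[::-1]
    # np.ix_ builds an open mesh so rows and columns are permuted together.
    return markov_matrix[np.ix_(order, order)]

P = np.array([[0.4, 0.6],
              [0.2, 0.8]])
P_sorted = sort_markov_matrix_vectorized(P)  # [[0.8, 0.2], [0.6, 0.4]]
```

Unlike the in-place bubble sort, this returns a new array and leaves the input untouched.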


clonker commented Apr 14, 2022

So here is a version for dense matrices; ideally, we would support both dense and sparse, though:

import numpy as np
from deeptime.markov.msm import MarkovStateModel

# Random row-stochastic matrix as a stand-in for an estimated MSM.
P = np.random.uniform(0, 1, size=(5, 5))
P /= P.sum(1)[:, None]
msm = MarkovStateModel(P)

# Order states by decreasing diagonal entry.
diag = np.diag(msm.transition_matrix)
sorting = np.argsort(diag)[::-1]

# Permutation matrix: row k selects original state sorting[k].
perm = np.eye(len(sorting), dtype=msm.transition_matrix.dtype)[sorting]
msm_reordered = MarkovStateModel(np.linalg.multi_dot((perm, msm.transition_matrix, perm.T)))
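As a sanity check, the permutation-matrix product `perm @ P @ perm.T` agrees with symmetric fancy indexing; a pure-NumPy sketch with the `MarkovStateModel` wrapper left out:

```python
import numpy as np

rng = np.random.default_rng(42)
P = rng.uniform(0, 1, size=(5, 5))
P /= P.sum(1)[:, None]  # row-normalize into a transition matrix

sorting = np.argsort(np.diag(P))[::-1]
perm = np.eye(len(sorting))[sorting]

# Both permute rows and columns by `sorting`.
via_product = np.linalg.multi_dot((perm, P, perm.T))
via_indexing = P[np.ix_(sorting, sorting)]
```

Row and column sums are preserved by the permutation, so the result is still row-stochastic.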


cap-jmk commented Apr 14, 2022

I see what you meant. With the multi-dot, you would always do Θ(2n) operations, whereas if you implement the sorting manually, you would do O(sqrt(n)) operations. Or am I overlooking something?


clonker commented Apr 14, 2022

Nope, not overlooking anything. While it probably warrants a benchmark, I would imagine that multi-dot outperforms manual sorting in Python, though. Things are different if you implement the sorting in an extension.

Edit: Actually matrix multiplications are (naively) Θ(n^3). In any case, here we can see that complexity =/= efficiency. 🙂


clonker commented Apr 14, 2022

[benchmark plot]


cap-jmk commented Apr 15, 2022

Totally enlightening. Just, as you have the benchmark already written, I would be interested how it goes when we go beyond 10k samples. It's where things get messy usually.


clonker commented Apr 19, 2022

To satisfy your curiosity:
[benchmark plot]

Now what would be interesting is the scaling behavior against a C/C++ sorting extension and against sparse matrices. Estimating a dense transition matrix with 10k Markov states is a tough task anyway because of the massive amount of data you would need.
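For the sparse side of this, the same symmetric permutation can be sketched with SciPy's sparse fancy indexing (assuming `scipy` is available; this is not deeptime code):

```python
import numpy as np
import scipy.sparse as sp

# Small sparse transition matrix; in practice this would be much larger.
P = sp.csr_matrix(np.array([[0.5, 0.5, 0.0],
                            [0.0, 0.9, 0.1],
                            [0.2, 0.0, 0.8]]))
order = np.argsort(P.diagonal())[::-1]

# Row permutation followed by column permutation; the result stays sparse.
P_sorted = P[order][:, order]
```

Indexing a CSR matrix with an integer array keeps the sparsity structure, so no dense intermediate of size n x n is materialized.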


cap-jmk commented Apr 19, 2022

Okay, I see. Maybe putting the loop into @njit() could help? I don't see why it should be slower; re-indexing should be faster than multiplying loads of elements, I guess.
Code:

import numpy as np
from numba import njit

@njit(parallel=True)
def sort_markov_matrix(markov_matrix):
    """Sort a Markov matrix in place by decreasing diagonal entry
    using bubble sort on the diagonal.

    Args:
        markov_matrix (np.ndarray): unsorted transition matrix

    Returns:
        np.ndarray: the reordered transition matrix
    """
    diag = np.diag(markov_matrix)
    for i in range(len(diag)):
        for j in range(len(diag) - 1):
            if diag[j + 1] > diag[j]:
                # Swap rows j and j+1, then the matching columns.
                markov_matrix[[j, j + 1]] = markov_matrix[[j + 1, j]]
                markov_matrix[:, [j, j + 1]] = markov_matrix[:, [j + 1, j]]
    return markov_matrix


clonker commented Apr 20, 2022

njit doesn't work for this function on my machine, and I don't really want to pull another dependency into deeptime. If you want to put together a Python-bound C++ implementation, then I'm happy to benchmark it, though. The jit performance is comparable to the Python loop.
In any case, I think the vectorized permutation-matrix implementation is a good middle ground between a lot of implementation work with harder-to-maintain code (C++ extension) and easy-to-write-and-maintain but poorly performing code (Python loop).


cap-jmk commented Apr 20, 2022

Good point. I will be happy to provide a C++ implementation. However, I am not sure how to link it; do you have any resources on that? In the meantime, shall we go with your implementation?


clonker commented Apr 20, 2022

I think my implementation would be a good way to move forward, yes. If you want to have a look at C++ extensions in general: we are using pybind11. The extensions are compiled, linked, and installed using CMake. Here is an example of that.


clonker commented Aug 15, 2022

Hi, any progress on this?


cap-jmk commented Aug 19, 2022

Yes, I just learned some more C++ (and some uni politics) and will have more time from now on to work on it :)


clonker commented Aug 19, 2022

Cool, let me know if you need pointers / help!
