Skip to content

Commit

Permalink
fix some weird reversions to csr_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
bdpedigo committed Mar 24, 2023
1 parent d654e7a commit b219e5f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
4 changes: 2 additions & 2 deletions graspologic/embed/omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _get_omnibus_matrix_sparse(matrices: List[csr_array]) -> csr_array:
# row
rows.append(hstack(current_row))

return vstack(rows, format="csr")
return csr_array(vstack(rows, format="csr"))


def _get_laplacian_matrices(
Expand Down Expand Up @@ -97,7 +97,7 @@ def _get_omni_matrix(
out : 2d-array
Array of shape (n_vertices * n_graphs, n_vertices * n_graphs)
"""
if isspmatrix_csr(graphs[0]):
if isinstance(graphs[0], csr_array):
return _get_omnibus_matrix_sparse(graphs) # type: ignore

shape = graphs[0].shape
Expand Down
30 changes: 23 additions & 7 deletions graspologic/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def average_matrices(
if isinstance(matrices[0], np.ndarray):
return np.mean(matrices, axis=0) # type: ignore
elif isspmatrix_csr(matrices[0]):
return sum(matrices) / len(matrices)
return np.sum(matrices) / len(matrices)

raise TypeError(f"Unexpected type {matrices}")

Expand Down Expand Up @@ -325,7 +325,9 @@ def symmetrize(
[1, 1, 1]])
"""
# graph = import_graph(graph)
sparse = isspmatrix_csr(graph)

sparse = isinstance(graph, csr_array)

pac = scipy.sparse if sparse else np

if method == "triu":
Expand All @@ -340,6 +342,11 @@ def symmetrize(

dia = diags(graph.diagonal()) if sparse else np.diag(np.diag(graph))
graph = graph + graph.T - dia

if sparse:
# some scipy funcs still output a sparse matrix
graph = csr_array(graph)

return graph


Expand All @@ -360,7 +367,11 @@ def remove_loops(graph: GraphRepresentation) -> Union[np.ndarray, csr_array]:
"""
graph = import_graph(graph)

dia = diags(graph.diagonal()) if isspmatrix_csr(graph) else np.diag(np.diag(graph))
dia = (
csr_array(diags(graph.diagonal()))
if isinstance(graph, csr_array)
else np.diag(np.diag(graph))
)

graph = graph - dia

Expand Down Expand Up @@ -460,23 +471,28 @@ def to_laplacian(
in_root = 1 / np.sqrt(in_degree) # this is 10x faster than ** -0.5
out_root = 1 / np.sqrt(out_degree)

diag = diags if isspmatrix_csr(graph) else np.diag
sparse = isinstance(graph, csr_array)

diag = diags if sparse else np.diag

in_root[np.isinf(in_root)] = 0
out_root[np.isinf(out_root)] = 0

in_root = diag(in_root) # just change to sparse diag for sparse support
out_root = diag(out_root)

if sparse:
# for some reason scipy is still returning in csr_matrix form, not what we want
in_root = csr_array(in_root)
out_root = csr_array(out_root)

if form == "I-DAD":
L = diag(in_degree) - A
L = in_root @ L @ in_root
elif form == "DAD" or form == "R-DAD":
L = out_root @ A @ in_root
if is_symmetric(A):
return symmetrize(
L, method="avg"
) # sometimes machine prec. makes this necessary
L = symmetrize(L, method="avg") # sometimes machine prec. makes this necessary
return L


Expand Down

0 comments on commit b219e5f

Please sign in to comment.