Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
bdpedigo committed Apr 30, 2024
1 parent bed96ab commit 978e8dc
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 5 deletions.
50 changes: 46 additions & 4 deletions graspologic/match/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from beartype import beartype
from ot import sinkhorn
from scipy.optimize import linear_sum_assignment
from scipy.sparse import csr_array
from scipy.sparse import csr_array, lil_array
from scipy.sparse import linalg as sparse_linalg
from sklearn.utils import check_scalar

Expand Down Expand Up @@ -91,6 +91,9 @@ def __init__(
transport_max_iter: Int = 1000,
fast: bool = True,
sparse_position: bool = False,
damping_factor: Optional[Union[Scalar, Callable]] = None,
gradient_mask: Optional[AdjacencyMatrix] = None,
labels: Optional[np.ndarray] = None,
):
# TODO check if init is doubly stochastic
self.init = init
Expand Down Expand Up @@ -123,7 +126,11 @@ def __init__(
self.transport_max_iter = transport_max_iter

self.fast = fast

self.sparse_position = sparse_position
self.damping_factor = damping_factor
self.gradient_mask = gradient_mask
self.labels = labels

if maximize:
self.obj_func_scalar = -1
Expand Down Expand Up @@ -258,6 +265,13 @@ def __init__(

self.S_ss, self.S_sn, self.S_ns, self.S_nn = _split_matrix(S, n_seeds)

@property
def damping_factor_at_iter(self):
if callable(self.damping_factor):
return self.damping_factor(self.n_iter_)
else:
return self.damping_factor

def solve(self, rng: RngType = None) -> None:
rng = np.random.default_rng(rng)

Expand All @@ -277,11 +291,20 @@ def solve(self, rng: RngType = None) -> None:
assert isinstance(gradient, csr_array)

Q = self.compute_step_direction(gradient, rng)

Q = self.mask_gradient(Q)

if self.sparse_position:
assert isinstance(Q, csr_array)

alpha = self.compute_step_size(P, Q)

beta = 1 - alpha
damping_factor = self.damping_factor_at_iter
if damping_factor is not None:
beta *= damping_factor
alpha = 1 - beta

# take a step in this direction
P_new = alpha * P + (1 - alpha) * Q
if self.sparse_position:
Expand Down Expand Up @@ -355,6 +378,25 @@ def compute_gradient(self, P: np.ndarray) -> np.ndarray:
)
return gradient

@write_status("Masking gradient", 2)
def mask_gradient(self, gradient: np.ndarray) -> np.ndarray:
mask = self.gradient_mask
labels = self.labels
if mask is not None:
raise NotImplementedError(
"Arbitrary masking of gradient not implemented yet"
)
if labels is None:
return gradient
if isinstance(gradient, csr_array):
gradient = lil_array(gradient)
gradient[labels[:, None] != labels[None, :]] = 0
gradient.eliminate_zeros()
gradient = csr_array(gradient)
else:
gradient[labels[:, None] != labels[None, :]] = 0
return gradient

@write_status("Solving assignment problem", 2)
def compute_step_direction(
self, gradient: np.ndarray, rng: np.random.Generator
Expand Down Expand Up @@ -422,7 +464,7 @@ def linear_sum_transport(
)
return P_eps

@write_status("Computing step size", 2)
@write_status("Computing step size", 2, print_out=True)
def compute_step_size(self, P: np.ndarray, Q: np.ndarray) -> float:
a, b = _compute_coefficients(
P,
Expand Down Expand Up @@ -562,7 +604,7 @@ def _check_input_matrix(
raise ValueError(msg)
return A


# TODO implement gradient masking here?
def _compute_gradient(
P: np.ndarray,
A: MultilayerAdjacency,
Expand Down Expand Up @@ -817,7 +859,7 @@ def _check_init_input(init: np.ndarray, n: int) -> None:
msg = ""
if init.shape != (n, n):
msg = "`init` matrix must be n x n, where n is the number of non-seeded nodes"
elif ( # TODO make this work for sparse case
elif ( # TODO make this work for sparse case
(~np.isclose(row_sum, 1, atol=tol)).any()
or (~np.isclose(col_sum, 1, atol=tol)).any()
or (init < 0).any()
Expand Down
8 changes: 7 additions & 1 deletion graspologic/match/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation and contributors.
# Licensed under the MIT License.

from typing import Any, NamedTuple, Optional, Union
from typing import Any, NamedTuple, Optional, Union, Callable

import numpy as np
from beartype import beartype
Expand Down Expand Up @@ -72,6 +72,9 @@ def graph_match(
transport_max_iter: Int = 1000,
fast: bool = True,
sparse_position: bool = False,
damping_factor: Optional[Union[Callable, Scalar]] = None,
gradient_mask: Optional[np.ndarray] = None,
labels: Optional[np.ndarray] = None,
) -> MatchResult:
"""
Attempts to solve the Graph Matching Problem or the Quadratic Assignment Problem
Expand Down Expand Up @@ -302,6 +305,9 @@ def graph_match(
fast=fast,
verbose=solver_verbose,
sparse_position=sparse_position,
damping_factor=damping_factor,
gradient_mask=gradient_mask,
labels=labels,
)

def run_single_graph_matching(seed: RngType) -> MatchResult:
Expand Down

0 comments on commit 978e8dc

Please sign in to comment.