Skip to content

Commit

Permalink
small gm tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
bdpedigo committed Feb 29, 2024
1 parent 50f0842 commit 27f3552
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
21 changes: 17 additions & 4 deletions graspologic/match/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def repl(f: Callable) -> Callable:


@parameterized
def write_status(f: Callable, msg: str, level: int) -> Callable:
def write_status(
f: Callable, msg: str, level: int, print_out: bool = False
) -> Callable:
@wraps(f)
def wrap(*args, **kw): # type: ignore
obj = args[0]
Expand All @@ -52,6 +54,10 @@ def wrap(*args, **kw): # type: ignore
sec = te - ts
output = total_msg + f" took {sec:.3f} seconds."
print(output)
if print_out:
total_msg = (level - 1) * " "
total_msg += obj.status() + " Result:" + str(result)
print(total_msg)
else:
result = f(*args, **kw)
return result
Expand Down Expand Up @@ -251,6 +257,7 @@ def __init__(
def solve(self, rng: RngType = None) -> None:
rng = np.random.default_rng(rng)

self.changes_ = []
self.n_iter_ = 0
if self.n_seeds == self.n: # all seeded, break
P = np.empty((0, 0))
Expand All @@ -268,7 +275,9 @@ def solve(self, rng: RngType = None) -> None:
# take a step in this direction
P_new = alpha * P + (1 - alpha) * Q

if self.check_converged(P, P_new):
change = self.compute_change(P, P_new)
self.changes_.append(change)
if self.check_converged(change):
self.converged_ = True
P = P_new
break
Expand Down Expand Up @@ -416,8 +425,12 @@ def compute_step_size(self, P: np.ndarray, Q: np.ndarray) -> float:
alpha = float(np.argmin([0, (b + a) * self.obj_func_scalar]))
return alpha

def check_converged(self, P: np.ndarray, P_new: np.ndarray) -> bool:
return np.linalg.norm(P - P_new) / np.sqrt(self.n_unseed) < self.tol
@write_status("Computing relative change from previous", 2, print_out=True)
def compute_change(self, P: np.ndarray, P_new: np.ndarray) -> float:
return np.linalg.norm(P - P_new) / np.sqrt(self.n_unseed)

def check_converged(self, change: float) -> bool:
return change < self.tol

@write_status("Finalizing assignment", 1)
def finalize(self, P: np.ndarray, rng: np.random.Generator) -> None:
Expand Down
2 changes: 2 additions & 0 deletions graspologic/match/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def graph_match(
transport_tol=transport_tol,
transport_max_iter=transport_max_iter,
fast=fast,
verbose=solver_verbose,
)

def run_single_graph_matching(seed: RngType) -> MatchResult:
Expand All @@ -311,6 +312,7 @@ def run_single_graph_matching(seed: RngType) -> MatchResult:
misc["n_iter"] = solver.n_iter_
misc["convex_solution"] = solver.convex_solution_
misc["converged"] = solver.converged_
misc['changes'] = solver.changes_
return MatchResult(indices_A, indices_B, score, [misc])

seeds = rng.integers(max_seed, size=n_init)
Expand Down

0 comments on commit 27f3552

Please sign in to comment.