Skip to content

Commit

Permalink
add consitancy for svd
Browse files Browse the repository at this point in the history
  • Loading branch information
jdtuck committed Apr 9, 2024
1 parent 7dfca79 commit e5ae269
Showing 1 changed file with 66 additions and 6 deletions.
72 changes: 66 additions & 6 deletions fdasrsf/utility_functions.py
Expand Up @@ -15,7 +15,7 @@
from numpy import ones, real, pi, cumsum, fabs, cov, diagflat, inner
from numpy import gradient, column_stack, append, mean, hstack, median
from numpy import insert, vectorize, ceil, mod, array, percentile, dot
from numpy import exp
from numpy import exp, argmax, abs, sign, take, reshape, newaxis
from joblib import Parallel, delayed
import numpy.random as rn
import optimum_reparamN2 as orN2
Expand Down Expand Up @@ -43,8 +43,8 @@ def smooth_data(f, sparam=1):
fo = f.copy()
for k in range(0, sparam):
for r in range(0, N):
fo[1 : (M - 2), r] = (
fo[0 : (M - 3), r] + 2 * fo[1 : (M - 2), r] + fo[2 : (M - 1), r]
fo[1: (M - 2), r] = (
fo[0: (M - 3), r] + 2 * fo[1: (M - 2), r] + fo[2: (M - 1), r]
) / 4
return fo

Expand Down Expand Up @@ -210,6 +210,7 @@ def optimum_reparam(
gam[:, i] = obj.gammaOpt
elif method == "cRBFGS":
import crbfgs as cr

if penalty == "roughness":
pen = 0
elif penalty == "l2gam":
Expand Down Expand Up @@ -340,7 +341,8 @@ def elastic_depth(f, time, method="DP2", lam=0.0, parallel=True):
phs_dist[i, :] = out[i][1]
else:
for i in range(0, fns):
amp_dist[i, :], phs_dist[i, :] = distmat(f, f[:, i], time, i, method)
amp_dist[i, :], phs_dist[i, :] = distmat(f, f[:, i], time, i,
method)

amp_dist = amp_dist + amp_dist.T
phs_dist = phs_dist + phs_dist.T
Expand Down Expand Up @@ -450,7 +452,6 @@ def SqrtMeanInverse(gam):
min_ind = dqq.argmin()
mu = psi[:, min_ind]
maxiter = 501
tt = 1
lvm = zeros(maxiter)
vec = zeros((T, n))
stp = 0.3
Expand Down Expand Up @@ -657,7 +658,9 @@ def cumtrapzmid(x, y, c, mid):
fa[0:mid] = tmp[::-1]

# case >= mid
fa[mid:a] = c + cumulative_trapezoid(y[mid - 1 : a - 1], x[mid - 1 : a - 1], initial=0)
fa[mid:a] = c + cumulative_trapezoid(
y[mid - 1: a - 1], x[mid - 1: a - 1], initial=0
)

return fa

Expand Down Expand Up @@ -1164,7 +1167,64 @@ def mrdivide(a, b):

def rlbfgs_dist(q1, q2):
import crbfgs as cr

q1 = q1.copy(order="C")
q2 = q2.copy(order="C")
d = cr.rlbfgs_dist(q1, q2)
return d


def svd_flip(u, v, u_based_decision=True):
"""Sign correction to ensure deterministic output from SVD.
Adjusts the columns of u and the rows of v such that the loadings in the
columns in u that are largest in absolute value are always positive.
If u_based_decision is False, then the same sign correction is applied to
so that the rows in v that are largest in absolute value are always
positive.
Parameters
----------
u : ndarray
Parameters u and v are the output of `linalg.svd` or
:func:`~sklearn.utils.extmath.randomized_svd`, with matching inner
dimensions so one can compute `np.dot(u * s, v)`.
v : ndarray
Parameters u and v are the output of `linalg.svd` or
:func:`~sklearn.utils.extmath.randomized_svd`, with matching inner
dimensions so one can compute `np.dot(u * s, v)`. The input v should
really be called vt to be consistent with scipy's output.
u_based_decision : bool, default=True
If True, use the columns of u as the basis for sign flipping.
Otherwise, use the rows of v. The choice of which variable to base the
decision on is generally algorithm dependent.
Returns
-------
u_adjusted : ndarray
Array u with adjusted columns and the same dimensions as u.
v_adjusted : ndarray
Array v with adjusted rows and the same dimensions as v.
"""

if u_based_decision:
# columns of u, rows of v, or equivalently rows of u.T and v
max_abs_u_cols = argmax(abs(u.T), axis=1)
shift = arange(u.T.shape[0])
indices = max_abs_u_cols + shift * u.T.shape[1]
signs = sign(take(reshape(u.T, (-1,)), indices, axis=0))
u *= signs[newaxis, :]
v *= signs[:, newaxis]
else:
# rows of v, columns of u
max_abs_v_rows = argmax(abs(v), axis=1)
shift = arange(v.shape[0])
indices = max_abs_v_rows + shift * v.shape[1]
signs = sign(take(reshape(v, (-1,)), indices))
u *= signs[newaxis, :]
v *= signs[:, newaxis]
return u, v

0 comments on commit e5ae269

Please sign in to comment.