Directional sparsity cost functional (#259)
* implement directional sparsity

* Update cost_functions.py

* Update oc.py

* Update oc.py
lenasal committed Feb 13, 2024
1 parent b797770 · commit ec2434e
Showing 3 changed files with 84 additions and 11 deletions.
neurolib/control/optimal_control/cost_functions.py — 64 changes: 56 additions & 8 deletions
@@ -19,7 +19,7 @@ def accuracy_cost(x, target_timeseries, weights, cost_matrix, dt, interval=(0, N
:param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
:type interval: tuple, optional
:return: Accuracy cost.
:rtype: float
"""
@@ -56,7 +56,7 @@ def derivative_accuracy_cost(x, target_timeseries, weights, cost_matrix, interva
:param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
:type interval: tuple, optional
:return: Accuracy cost derivative.
:rtype: ndarray
"""
@@ -84,7 +84,7 @@ def precision_cost(x_sim, x_target, cost_matrix, interval=(0, None)):
:param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
:type interval: tuple
:return: Precision cost for time interval.
:rtype: float
"""
@@ -114,7 +114,7 @@ def derivative_precision_cost(x_sim, x_target, cost_matrix, interval):
:param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
:type interval: tuple
:return: Control-dimensions x T array of precision cost gradients.
:rtype: np.ndarray
"""
@@ -140,7 +140,7 @@ def control_strength_cost(u, weights, dt):
:type weights: dictionary
:param dt: Time step.
:type dt: float
:return: control strength cost of the control.
:rtype: float
"""
@@ -159,17 +159,22 @@ def control_strength_cost(u, weights, dt):
for t in range(u.shape[2]):
cost += cost_timeseries[n, v, t] * dt

if weights["w_1D"] != 0.0:
cost += weights["w_1D"] * L1D_cost_integral(u, dt)

return cost


@numba.njit
- def derivative_control_strength_cost(u, weights):
+ def derivative_control_strength_cost(u, weights, dt):
"""Derivative of the 'control_strength_cost' wrt. the control 'u'.
:param u: Control-dimensions x T array. Control signals.
:type u: np.ndarray
:param weights: Dictionary of weights.
:type weights: dictionary
:param dt: Time step.
:type dt: float
:return: Control-dimensions x T array of L2-cost gradients.
:rtype: np.ndarray
@@ -179,6 +184,8 @@ def derivative_control_strength_cost(u, weights):

if weights["w_2"] != 0.0:
der += weights["w_2"] * derivative_L2_cost(u)
if weights["w_1D"] != 0.0:
der += weights["w_1D"] * derivative_L1D_cost(u, dt)

return der

@@ -189,7 +196,7 @@ def L2_cost(u):
:param u: Control-dimensions x T array. Control signals.
:type u: np.ndarray
:return: L2 cost of the control.
:rtype: float
"""
@@ -203,8 +210,49 @@ def derivative_L2_cost(u):
:param u: Control-dimensions x T array. Control signals.
:type u: np.ndarray
:return: Control-dimensions x T array of L2-cost gradients.
:rtype: np.ndarray
"""
return u


@numba.njit
def L1D_cost_integral(
u,
dt,
):
"""'Directional sparsity' or 'L1D' cost integrated over time. Penalizes for control strength.
:param u: Control-dimensions x T array. Control signals.
:type u: np.ndarray
:param dt: Time step.
:type dt: float
:return: L1D cost of the control.
:rtype: float
"""

return np.sum(np.sum(np.sqrt(np.sum(u**2, axis=2) * dt), axis=1), axis=0)


@numba.njit
def derivative_L1D_cost(
u,
dt,
):
"""
:param u: Control-dimensions x T array. Control signals.
:type u: np.ndarray
:param dt: Time step.
:type dt: float
:return: Control-dimensions x T array of L1D-cost gradients.
:rtype: np.ndarray
"""

denominator = np.sqrt(np.sum(u**2, axis=2) * dt)
der = np.zeros((u.shape))
for n in range(der.shape[0]):
for v in range(der.shape[1]):
if denominator[n, v] != 0.0:
der[n, v, :] = u[n, v, :] / denominator[n, v]

return der
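
For reference, the quantity implemented by `L1D_cost_integral` and `derivative_L1D_cost` can be read as a group penalty: an L2 norm over time nested inside an L1 sum over nodes and control channels. A sketch in the notation of the code (N nodes, V input channels, T time steps of length Δt):

```latex
% Directional sparsity (L1D) cost: an L2 norm over time inside an L1 sum
% over nodes n and control channels v.
\[
  C_{1D}(u) \;=\; \sum_{n=1}^{N} \sum_{v=1}^{V}
  \sqrt{\,\sum_{t=1}^{T} u_{n,v,t}^{2}\,\Delta t\,}
\]

% Time-local derivative returned by derivative_L1D_cost. There is no extra
% \Delta t factor here, matching derivative_L2_cost above, which likewise
% returns the time-local derivative of its integrand.
\[
  \frac{\delta C_{1D}}{\delta u_{n,v}(t)} \;=\;
  \frac{u_{n,v}(t)}{\sqrt{\int_{0}^{T} u_{n,v}(\tau)^{2}\, d\tau}}
\]
```

Because time enters through an L2 norm while nodes and channels enter through a plain sum, minimizing this term pushes whole (node, channel) control directions to exactly zero rather than silencing isolated time points, which is the "directional" sparsity the commit title refers to. The zero-denominator guard in `derivative_L1D_cost` handles directions where the control already vanishes, since the penalty is not differentiable there.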
neurolib/control/optimal_control/oc.py — 7 changes: 4 additions & 3 deletions
@@ -17,6 +17,7 @@ def getdefaultweights():
)
weights["w_p"] = 1.0
weights["w_2"] = 0.0
weights["w_1D"] = 0.0

return weights

@@ -471,14 +472,14 @@ def __init__(
for v, iv in enumerate(self.model.input_vars):
control[:, v, :] = self.model.params[iv]

- self.control = control.copy()
+ self.control = control.copy()
self.check_params()

self.control = update_control_with_limit(
self.N, self.dim_in, self.T, control, 0.0, np.zeros(control.shape), self.maximum_control_strength
)

- self.model_params = self.get_model_params()
+ self.model_params = self.get_model_params()

def check_params(self):
"""Checks a subset of parameters and throws an error if a wrong dimension is found."""
@@ -624,7 +625,7 @@ def compute_gradient(self):
:rtype: np.ndarray of shape N x V x T
"""
self.solve_adjoint()
- df_du = cost_functions.derivative_control_strength_cost(self.control, self.weights)
+ df_du = cost_functions.derivative_control_strength_cost(self.control, self.weights, self.dt)
d_du = self.Duh()

return compute_gradient(
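
A minimal usage sketch for the new weight, using only names that appear in this diff (`getdefaultweights`, `control_strength_cost`, `derivative_control_strength_cost`) and the module paths shown in the file headers; the control array and time step are made-up illustration values:

```python
import numpy as np

from neurolib.control.optimal_control import cost_functions
from neurolib.control.optimal_control.oc import getdefaultweights

dt = 0.1
u = np.random.randn(2, 2, 100)  # hypothetical N x V x T control array

weights = getdefaultweights()
weights["w_1D"] = 1.0  # directional sparsity is off by default (0.0)

# The total control-strength cost now includes w_1D * L1D_cost_integral(u, dt),
cost = cost_functions.control_strength_cost(u, weights, dt)
# and the gradient picks up the matching w_1D * derivative_L1D_cost(u, dt) term.
grad = cost_functions.derivative_control_strength_cost(u, weights, dt)
assert grad.shape == u.shape
```

In a full optimization run the same effect is achieved by setting `weights["w_1D"]` on the OC object before optimizing, since `compute_gradient` above now forwards `self.dt` to the cost derivative.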
tests/control/optimal_control/test_oc_cost_functions.py — 24 changes: 24 additions & 0 deletions
@@ -160,6 +160,30 @@ def test_derivative_L2_cost(self):
desired_output = u
self.assertTrue(np.all(cost_functions.derivative_L2_cost(u) == desired_output))

def test_L1D_cost(self):
print(" Test L1D cost")
dt = 0.1
reference_result = 2.0 * np.sum(np.sqrt(np.sum(p.TEST_INPUT_1N_6**2 * dt, axis=1)))
weights = getdefaultweights()
weights["w_1D"] = 1.0
u = np.concatenate([p.TEST_INPUT_1N_6[:, np.newaxis, :], p.TEST_INPUT_1N_6[:, np.newaxis, :]], axis=1)
L1D_cost = cost_functions.control_strength_cost(u, weights, dt)

self.assertAlmostEqual(L1D_cost, reference_result, places=8)

def test_derivative_L1D_cost(self):
print(" Test L1D cost derivative")
dt = 0.1
denominator = np.sqrt(np.sum(p.TEST_INPUT_1N_6**2 * dt, axis=1))

u = np.concatenate([p.TEST_INPUT_1N_6[:, np.newaxis, :], p.TEST_INPUT_1N_6[:, np.newaxis, :]], axis=1)
reference_result = np.zeros((u.shape))
for n in range(u.shape[0]):
for v in range(u.shape[1]):
reference_result[n, v, :] = u[n, v, :] / denominator[n]

self.assertTrue(np.all(cost_functions.derivative_L1D_cost(u, dt) == reference_result))

def test_weights_dictionary(self):
print("Test dictionary of cost weights")
model = FHNModel()
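
The tests above re-derive the closed-form cost and gradient; an independent way to convince yourself the analytic derivative is right is a finite-difference check on the discretized cost. A standalone NumPy sketch (no neurolib imports; array shapes chosen arbitrarily):

```python
import numpy as np

def l1d_cost(u, dt):
    # Same quantity as L1D_cost_integral in cost_functions.py.
    return np.sum(np.sqrt(np.sum(u**2, axis=2) * dt))

rng = np.random.default_rng(0)
u = rng.normal(size=(2, 2, 50))
dt, eps = 0.1, 1e-5

# Central finite difference at a single (node, channel, time) entry.
n, v, t = 0, 1, 10
u_plus, u_minus = u.copy(), u.copy()
u_plus[n, v, t] += eps
u_minus[n, v, t] -= eps
numerical = (l1d_cost(u_plus, dt) - l1d_cost(u_minus, dt)) / (2 * eps)

# derivative_L1D_cost returns u / sqrt(sum(u**2, axis=2) * dt); the discrete
# gradient of the summed cost carries one extra factor of dt.
analytic = u[n, v, t] * dt / np.sqrt(np.sum(u[n, v] ** 2) * dt)
assert abs(numerical - analytic) < 1e-8
```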
