Skip to content

Commit

Permalink
Added control interval (#256)
Browse files Browse the repository at this point in the history
* added control interval parameter and test

* fix control interval test

* fix control interval test
  • Loading branch information
lenasal committed Feb 13, 2024
1 parent 9071ca8 commit 0776d7a
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 18 deletions.
52 changes: 42 additions & 10 deletions neurolib/control/optimal_control/oc_aln/oc_aln.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
import numpy as np

from neurolib.control.optimal_control.oc import OC
from neurolib.models.aln.timeIntegration import compute_hx, compute_hx_nw, Duh, Dxdoth, compute_hx_de, compute_hx_di
from neurolib.models.aln.timeIntegration import (
compute_hx,
compute_hx_nw,
Duh,
Dxdoth,
compute_hx_de,
compute_hx_di,
)


class OcAln(OC):
Expand All @@ -21,6 +28,7 @@ def __init__(
weights=None,
print_array=[],
cost_interval=(None, None),
control_interval=(None, None),
cost_matrix=None,
control_matrix=None,
M=1,
Expand All @@ -34,6 +42,7 @@ def __init__(
print_array=print_array,
cost_interval=cost_interval,
cost_matrix=cost_matrix,
control_interval=control_interval,
control_matrix=control_matrix,
M=M,
M_validation=M_validation,
Expand Down Expand Up @@ -197,7 +206,9 @@ def compute_hx_list(self):
hx_de = self.compute_hx_de()
hx_di = self.compute_hx_di()

return numba.typed.List([hx, hx_de, hx_di]), numba.typed.List([0, self.ndt_de, self.ndt_di])
return numba.typed.List([hx, hx_de, hx_di]), numba.typed.List(
[0, self.ndt_de, self.ndt_di]
)

def compute_hx(self):
"""Jacobians of ALNModel wrt. the 'e'- and 'i'-variable for each time step.
Expand Down Expand Up @@ -317,7 +328,9 @@ def get_fullstate(self):
if t <= T - 2:
self.model.params[iv] = control[:, iv_ind, t : t + 2]
elif t == T - 1:
self.model.params[iv] = np.concatenate((control[:, iv_ind, t:], np.zeros((N, 1))), axis=1)
self.model.params[iv] = np.concatenate(
(control[:, iv_ind, t:], np.zeros((N, 1))), axis=1
)
else:
self.model.params[iv] = 0.0
self.model.run()
Expand Down Expand Up @@ -349,11 +362,19 @@ def setasinit(self, fullstate, t):

for n in range(N):
for v in range(V):
if "rates" in self.model.init_vars[v] or "IA" in self.model.init_vars[v]:
if (
"rates" in self.model.init_vars[v]
or "IA" in self.model.init_vars[v]
):
if t >= T:
self.model.params[self.model.init_vars[v]] = fullstate[:, v, t - T : t + 1]
self.model.params[self.model.init_vars[v]] = fullstate[
:, v, t - T : t + 1
]
else:
init = np.concatenate((fullstate[:, v, -T + t + 1 :], fullstate[:, v, : t + 1]), axis=1)
init = np.concatenate(
(fullstate[:, v, -T + t + 1 :], fullstate[:, v, : t + 1]),
axis=1,
)
self.model.params[self.model.init_vars[v]] = init
else:
self.model.params[self.model.init_vars[v]] = fullstate[:, v, t]
Expand All @@ -371,8 +392,13 @@ def getinitstate(self):

for n in range(N):
for v in range(V):
if "rates" in self.model.init_vars[v] or "IA" in self.model.init_vars[v]:
initstate[n, v, :] = self.model.params[self.model.init_vars[v]][n, -T:]
if (
"rates" in self.model.init_vars[v]
or "IA" in self.model.init_vars[v]
):
initstate[n, v, :] = self.model.params[self.model.init_vars[v]][
n, -T:
]

else:
initstate[n, v, :] = self.model.params[self.model.init_vars[v]][n]
Expand All @@ -389,7 +415,10 @@ def getfinalstate(self):
state = np.zeros((N, V))
for n in range(N):
for v in range(V):
if "rates" in self.model.state_vars[v] or "IA" in self.model.state_vars[v]:
if (
"rates" in self.model.state_vars[v]
or "IA" in self.model.state_vars[v]
):
state[n, v] = self.model.state[self.model.state_vars[v]][n, -1]

else:
Expand All @@ -408,7 +437,10 @@ def setinitstate(self, state):

for n in range(N):
for v in range(V):
if "rates" in self.model.init_vars[v] or "IA" in self.model.init_vars[v]:
if (
"rates" in self.model.init_vars[v]
or "IA" in self.model.init_vars[v]
):
self.model.params[self.model.init_vars[v]] = state[:, v, -T:]
else:
self.model.params[self.model.init_vars[v]] = state[:, v, -1]
Expand Down
2 changes: 2 additions & 0 deletions neurolib/control/optimal_control/oc_fhn/oc_fhn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
maximum_control_strength=None,
print_array=[],
cost_interval=(None, None),
control_interval=(None, None),
cost_matrix=None,
control_matrix=None,
M=1,
Expand All @@ -32,6 +33,7 @@ def __init__(
maximum_control_strength=maximum_control_strength,
print_array=print_array,
cost_interval=cost_interval,
control_interval=control_interval,
cost_matrix=cost_matrix,
control_matrix=control_matrix,
M=M,
Expand Down
2 changes: 2 additions & 0 deletions neurolib/control/optimal_control/oc_hopf/oc_hopf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
maximum_control_strength=None,
print_array=[],
cost_interval=(None, None),
control_interval=(None, None),
cost_matrix=None,
control_matrix=None,
M=1,
Expand All @@ -34,6 +35,7 @@ def __init__(
maximum_control_strength=maximum_control_strength,
print_array=print_array,
cost_interval=cost_interval,
control_interval=control_interval,
cost_matrix=cost_matrix,
control_matrix=control_matrix,
M=M,
Expand Down
2 changes: 2 additions & 0 deletions neurolib/control/optimal_control/oc_wc/oc_wc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
weights=None,
print_array=[],
cost_interval=(None, None),
control_interval=(None, None),
cost_matrix=None,
control_matrix=None,
M=1,
Expand All @@ -32,6 +33,7 @@ def __init__(
weights=weights,
print_array=print_array,
cost_interval=cost_interval,
control_interval=control_interval,
cost_matrix=cost_matrix,
control_matrix=control_matrix,
M=M,
Expand Down
53 changes: 45 additions & 8 deletions tests/control/optimal_control/test_oc_fhn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def test_1n(self):
cost_mat = np.zeros((model.params.N, len(model.output_vars)))
control_mat = np.zeros((model.params.N, len(model.state_vars)))
control_mat[0, input_channel] = 1.0 # only allow inputs to input_channel
cost_mat[0, np.abs(input_channel - 1).astype(int)] = 1.0 # only measure other channel
cost_mat[
0, np.abs(input_channel - 1).astype(int)
] = 1.0 # only measure other channel

test_oc_utils.set_input(model, p.ZERO_INPUT_1N_6)
model.params[model.input_vars[input_channel]] = p.TEST_INPUT_1N_6
Expand All @@ -52,7 +54,9 @@ def test_1n(self):
model_controlled.optimize(p.ITERATIONS)
control = model_controlled.control

c_diff = (np.abs(control[0, input_channel, :] - p.TEST_INPUT_1N_6[0, :]),)
c_diff = (
np.abs(control[0, input_channel, :] - p.TEST_INPUT_1N_6[0, :]),
)

if np.amax(c_diff) < p.LIMIT_DIFF:
control_coincide = True
Expand Down Expand Up @@ -99,7 +103,11 @@ def test_2n(self):
)

model_controlled.control = np.concatenate(
[p.INIT_INPUT_2N_8[:, np.newaxis, :], p.ZERO_INPUT_2N_8[:, np.newaxis, :]], axis=1
[
p.INIT_INPUT_2N_8[:, np.newaxis, :],
p.ZERO_INPUT_2N_8[:, np.newaxis, :],
],
axis=1,
)
model_controlled.update_input()

Expand Down Expand Up @@ -261,7 +269,9 @@ def test_u_max_no_optimizations(self):
control_matrix=control_mat,
)

self.assertTrue(np.max(np.abs(model_controlled.control) <= maximum_control_strength))
self.assertTrue(
np.max(np.abs(model_controlled.control) <= maximum_control_strength)
)

# Arbitrary network and control setting, initial control violates the maximum absolute criterion.
def test_u_max_after_optimizations(self):
Expand Down Expand Up @@ -289,7 +299,9 @@ def test_u_max_after_optimizations(self):
)

model_controlled.optimize(1)
self.assertTrue(np.max(np.abs(model_controlled.control) <= maximum_control_strength))
self.assertTrue(
np.max(np.abs(model_controlled.control) <= maximum_control_strength)
)

def test_adjust_init(self):
print("Test adjust_init function of OC class")
Expand Down Expand Up @@ -327,7 +339,10 @@ def test_adjust_init(self):
for init_var0 in model.init_vars:
if "ou" in init_var0:
continue
self.assertTrue(model_controlled.model.params[init_var0].shape == targetinitshape)
self.assertTrue(
model_controlled.model.params[init_var0].shape
== targetinitshape
)

def test_adjust_input(self):
print("Test test_adjust_input function of OC class")
Expand All @@ -337,7 +352,9 @@ def test_adjust_input(self):
model = FHNModel(Cmat=cmat, Dmat=dmat)
model.params.duration = p.TEST_DURATION_6

target = np.zeros((model.params.N, len(model.state_vars), p.TEST_INPUT_2N_6.shape[1]))
target = np.zeros(
(model.params.N, len(model.state_vars), p.TEST_INPUT_2N_6.shape[1])
)
targetinputshape = (target.shape[0], target.shape[2])

for test_input in [
Expand All @@ -358,7 +375,27 @@ def test_adjust_input(self):
)

for input_var0 in model.input_vars:
self.assertTrue(model_controlled.model.params[input_var0].shape == targetinputshape)
self.assertTrue(
model_controlled.model.params[input_var0].shape
== targetinputshape
)

# tests if the control is only active in the control interval
# single-node case
def test_onenode_control_interval(self):
print("Test OC for control_interval = [0,0] in single-node model")
model = FHNModel()

model.params["duration"] = p.TEST_DURATION_8
test_oc_utils.setinitzero_1n(model)

test_oc_utils.set_input(model, p.TEST_INPUT_1N_8)
model.run()
target = test_oc_utils.gettarget_1n(model)

model_controlled = oc_fhn.OcFhn(model, target, control_interval=(0, 1))
model_controlled.optimize(1)
self.assertEqual(np.amax(np.abs(model_controlled.control[:, :, 1:])), 0.0)

# tests if the cost is independent of the integration time step
def test_cost_dt(self):
Expand Down

0 comments on commit 0776d7a

Please sign in to comment.