Skip to content

Commit

Permalink
fixes to multimodel, testexploration, and continue_run
Browse files Browse the repository at this point in the history
  • Loading branch information
1b15 committed Mar 5, 2024
1 parent 73f7b9f commit c6872e0
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 79 deletions.
16 changes: 10 additions & 6 deletions neurolib/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,19 @@ def initializeBold(self):
self.boldInitialized = True
# logging.info(f"{self.name}: BOLD model initialized.")

def simulateBold(self, variables, append=False):
def get_bold_variable(self, variables):
default_index = self.state_vars.index(self.default_output)
return variables[default_index]

def simulateBold(self, bold_variable, append=False):
"""Gets the default output of the model and simulates the BOLD model.
Adds the simulated BOLD signal to outputs.
"""
if not self.boldInitialized:
logging.warn("BOLD model not initialized, not simulating BOLD. Use `run(bold=True)`")
return

default_index = self.state_vars.index(self.default_output)
sv = variables[default_index]

bold_input = sv[:, self.startindt :]
bold_input = bold_variable[:, self.startindt :]
# logging.debug(f"BOLD input `{svn}` of shape {bold_input.shape}")
if not bold_input.shape[1] >= self.boldModel.samplingRate_NDt:
logging.warn(
Expand Down Expand Up @@ -265,7 +266,8 @@ def integrate(self, append_outputs=False, simulate_bold=False):

# bold simulation after integration
if simulate_bold and self.boldInitialized:
self.simulateBold(variables, append=append_outputs)
bold_variable = self.get_bold_variable(variables)
self.simulateBold(bold_variable, append=append_outputs)

def integrateChunkwise(self, chunksize, bold=False, append_outputs=False):
"""Repeatedly calls the chunkwise integration for the whole duration of the simulation.
Expand Down Expand Up @@ -327,6 +329,8 @@ def storeOutputsAndStates(self, t, variables, append=False):

def setInitialValuesToLastState(self):
"""Reads the last state of the model and sets the initial conditions to that state for continuing a simulation."""
if not hasattr(self, "t"):
raise ValueError("You tried using continue_run=True on the first run.")
for iv, sv in zip(self.init_vars, self.state_vars):
# if state variables are one-dimensional (in space only)
if (self.state[sv].ndim == 0) or (self.state[sv].ndim == 1):
Expand Down
48 changes: 8 additions & 40 deletions neurolib/models/multimodel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,15 @@ def run(
append = append_outputs

# if a previous run is not to be continued clear the model's state
if continue_run is False:
if continue_run:
self.setInitialValuesToLastState()
else:
self.clearModelState()

self.initializeRun(initializeBold=bold)

if chunkwise is False:
self.integrate(append_outputs=append, simulate_bold=bold, noise_input=noise_input)
if continue_run:
self.setInitialValuesToLastState()

else:
if chunksize is None:
Expand Down Expand Up @@ -206,9 +206,11 @@ def integrate(self, append_outputs=False, simulate_bold=False, noise_input=None)

# bold simulation after integration
if simulate_bold and self.boldInitialized:
self.simulateBold(result[self.default_output].values.T, append=True)
self.simulateBold(result[self.default_output].values.T, append=append_outputs)

def setInitialValuesToLastState(self):
if not hasattr(self, "t"):
raise ValueError("You tried using continue_run=True on the first run.")
# set start t for next run for the last value now
self.start_t = self.t[-1]
new_initial_state = np.zeros((self.model_instance.initial_state.shape[0], self.maxDelay + 1))
Expand All @@ -231,44 +233,10 @@ def integrateChunkwise(self, chunksize, bold, append_outputs):

def storeOutputsAndStates(self, results, append):
# save time array
self.setOutput("t", results.time.values + self.start_t, append=append, removeICs=False)
self.setStateVariables("t", results.time.values + self.start_t)
self.setOutput("t", results.time.values, append=append, removeICs=False)
self.setStateVariables("t", results.time.values)
# save outputs
for variable in results:
if variable in self.output_vars:
self.setOutput(variable, results[variable].values.T, append=append, removeICs=False)
self.setStateVariables(variable, results[variable].values.T)

def simulateBold(self, bold_variable, append):
if self.boldInitialized:
bold_input = bold_variable[:, self.startindt :]
if bold_input.shape[1] >= self.boldModel.samplingRate_NDt:
# only if the length of the output has a zero mod to the sampling rate,
# the downsampled output from the boldModel can correctly appended to previous data
# so: we are lazy here and simply disable appending in that case ...
if not bold_input.shape[1] % self.boldModel.samplingRate_NDt == 0:
append = False
logging.warn(
f"Output size {bold_input.shape[1]} is not a multiple of BOLD sample length "
f"{ self.boldModel.samplingRate_NDt}, will not append data."
)
logging.debug(f"Simulating BOLD: boldModel.run(append={append})")

# transform bold input according to self.boldInputTransform
if self.boldInputTransform:
bold_input = self.boldInputTransform(bold_input)

# simulate bold model
self.boldModel.run(bold_input, append=append)

t_BOLD = self.boldModel.t_BOLD
BOLD = self.boldModel.BOLD
self.setOutput("BOLD.t_BOLD", t_BOLD)
self.setOutput("BOLD.BOLD", BOLD)
else:
logging.warn(
f"Will not simulate BOLD if output {bold_input.shape[1]*self.params['dt']} not at least of duration"
f" {self.boldModel.samplingRate_NDt*self.params['dt']}"
)
else:
logging.warn("BOLD model not initialized, not simulating BOLD. Use `run(bold=True)`")
71 changes: 56 additions & 15 deletions neurolib/optimize/exploration/explorationUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def plotExplorationResults(
savename=None,
**kwargs,
):
"""
"""
""" """
# PREPARE DATA
# ------------------
# copy here, because we add another column that we do not want to keep later
Expand Down Expand Up @@ -127,9 +126,22 @@ def plotExplorationResults(
else:
mask = None

image = alphaMask(image, mask_threshold, mask_alpha, mask=mask, invert=mask_invert, style=mask_style,)

im = ax.imshow(image, extent=image_extent, origin="lower", aspect="auto", clim=plot_clim,)
image = alphaMask(
image,
mask_threshold,
mask_alpha,
mask=mask,
invert=mask_invert,
style=mask_style,
)

im = ax.imshow(
image,
extent=image_extent,
origin="lower",
aspect="auto",
clim=plot_clim,
)

# ANNOTATIONs
# ------------------
Expand Down Expand Up @@ -190,7 +202,12 @@ def plot_contour(contour, contour_color, contour_levels, contour_alpha, contour_

# tick marks
ax.tick_params(
axis="both", direction="out", length=3, width=1, bottom=True, left=True,
axis="both",
direction="out",
length=3,
width=1,
bottom=True,
left=True,
)

# multiply / rescale axis
Expand Down Expand Up @@ -243,7 +260,16 @@ def contourPlotDf(
# unpack, why necessary??
contour_kwargs = contour_kwargs["contour_kwargs"]

contours = ax.contour(Xi, Yi, dataframe, colors=color, levels=levels, zorder=1, alpha=alpha, **contour_kwargs,)
contours = ax.contour(
Xi,
Yi,
dataframe,
colors=color,
levels=levels,
zorder=1,
alpha=alpha,
**contour_kwargs,
)

clabel = contour_kwargs["clabel"] if "clabel" in contour_kwargs else False
if clabel:
Expand Down Expand Up @@ -298,11 +324,19 @@ def plotResult(search, runId, z_bold=False, **kwargs):

bold = result.BOLD[:, bold_transient:]
bold_z = stats.zscore(bold, axis=1)
t_bold = np.linspace(2, len(bold.T) * 2, len(bold.T),)
t_bold = np.linspace(
2,
len(bold.T) * 2,
len(bold.T),
)

output = result[search.model.default_output]
output_dt = search.model.params.dt
t_output = np.linspace(output_dt, len(output.T) * output_dt, len(output.T),)
t_output = np.linspace(
output_dt,
len(output.T) * output_dt,
len(output.T),
)

axs[0].set_title(f"FC (run {runId})")
axs[0].imshow(func.fc(bold))
Expand All @@ -329,8 +363,7 @@ def plotResult(search, runId, z_bold=False, **kwargs):


def processExplorationResults(search, **kwargs):
"""Process results from the exploration.
"""
"""Process results from the exploration."""

dfResults = search.dfResults

Expand Down Expand Up @@ -405,7 +438,15 @@ def processExplorationResults(search, **kwargs):

# calculate mean correlation of functional connectivity
# of the simulation and the empirical data
dfResults.loc[i, "fc"] = np.mean([func.matrix_correlation(func.fc(bold), fc,) for fc in ds.FCs])
dfResults.loc[i, "fc"] = np.mean(
[
func.matrix_correlation(
func.fc(bold),
fc,
)
for fc in ds.FCs
]
)
# if BOLD simulation is longer than 5 minutes, calculate kolmogorov of FCD
skip_fcd = kwargs["skip_fcd"] if "skip_fcd" in kwargs else False
if t_bold[-1] > 5 * 1000 * 60 and not skip_fcd:
Expand Down Expand Up @@ -443,10 +484,10 @@ def findCloseResults(dfResults, dist=None, relative=False, **kwargs):
Use the parameters to filter for as kwargs:
Usage: findCloseResults(search.dfResults, mue_ext_mean=2.0, mui_ext_mean=2.5)
Alternatively, use ranges a la [min, max] for each parameter.
Alternatively, use ranges a la [min, max] for each parameter.
Usage: findCloseResults(search.dfResults, mue_ext_mean=[2.0, 3.0], mui_ext_mean=2.5)
:param dfResults: Pandas dataframe to filter
:type dfResults: pandas.DataFrame
:param dist: Distance to specified points in kwargs, defaults to None
Expand Down
4 changes: 2 additions & 2 deletions tests/test_explorationUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def randomString(stringLength=10):
"""Generate a random string of fixed length """
"""Generate a random string of fixed length"""
letters = string.ascii_lowercase
return "".join(random.choice(letters) for i in range(stringLength))

Expand Down Expand Up @@ -41,7 +41,7 @@ def setUpClass(cls):
model=model, parameterSpace=parameters, filename=f"test_exploration_utils_{randomString(20)}.hdf"
)

search.run(chunkwise=True, bold=True)
search.run(chunkwise=True, bold=True, append_outputs=True)

search.loadResults()

Expand Down
21 changes: 5 additions & 16 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from neurolib.utils.stimulus import ZeroInput
from neurolib.models.kuramoto import KuramotoModel


class TestAln(unittest.TestCase):
"""
Basic test for ALN model.
Expand Down Expand Up @@ -222,8 +223,7 @@ def test_network(self):
model.params["sigma_ou"] = 0.1
model.params["k"] = 0.6


# local node input parameter
# local node input parameter
model.params["theta_ext"] = 0.72

model.run(chunkwise=True, append_outputs=True)
Expand Down Expand Up @@ -343,14 +343,8 @@ def test_continue_run_node(self):
model.params["sampling_dt"] = 10.0
model.params["backend"] = "numba"
# run MultiModel with continuation
model.run(continue_run=True)
model.run()
last_t = model.t[-1]
last_x = model.state["x"][:, -model.maxDelay - 1 :]
last_y = model.state["y"][:, -model.maxDelay - 1 :]
# assert last state is initial state now
self.assertEqual(model.start_t, last_t)
np.testing.assert_equal(last_x.squeeze(), model.model_instance.initial_state[0, :])
np.testing.assert_equal(last_y.squeeze(), model.model_instance.initial_state[1, :])
# change noise - just to make things more interesting
model.noise_input = [ZeroInput()] * model.model_instance.num_noise_variables
model.run(continue_run=True)
Expand All @@ -367,16 +361,11 @@ def test_continue_run_network(self):
model.params["sampling_dt"] = 10.0
model.params["backend"] = "numba"
# run MultiModel with continuation
model.run(continue_run=True)
model.run()
last_t = model.t[-1]
last_x = model.state["x"][:, -model.maxDelay - 1 :]
last_y = model.state["y"][:, -model.maxDelay - 1 :]
# assert last state is initial state now
self.assertEqual(model.start_t, last_t)
np.testing.assert_equal(last_x, model.model_instance.initial_state[[0, 2], :])
np.testing.assert_equal(last_y, model.model_instance.initial_state[[1, 3], :])
# change noise - just to make things more interesting
model.noise_input = [ZeroInput()] * model.model_instance.num_noise_variables

model.run(continue_run=True)
# assert continuous time
self.assertAlmostEqual(model.t[0] - last_t, model.params["dt"] / 1000.0)
Expand Down

0 comments on commit c6872e0

Please sign in to comment.