Skip to content

Commit

Permalink
untested jax aln
Browse files Browse the repository at this point in the history
  • Loading branch information
1b15 committed Mar 18, 2024
1 parent 7774761 commit 0a682c0
Show file tree
Hide file tree
Showing 5 changed files with 1,106 additions and 0 deletions.
308 changes: 308 additions & 0 deletions neurolib/models/aln/timeIntegration.py
Expand Up @@ -311,6 +311,314 @@ def timeIntegration(params):
noise_inh,
)

def timeIntegration_args(params):
"""Sets up the parameters for time integration
Return:
rates_exc: N*L array : containing the exc. neuron rates in kHz time series of the N nodes
rates_inh: N*L array : containing the inh. neuron rates in kHz time series of the N nodes
t: L array : time in ms
mufe: N vector : final value of mufe for each node
mufi: N vector : final value of mufi for each node
IA: N vector : final value of IA for each node
seem : N vector : final value of seem for each node
seim : N vector : final value of seim for each node
siem : N vector : final value of siem for each node
siim : N vector : final value of siim for each node
seev : N vector : final value of seev for each node
seiv : N vector : final value of seiv for each node
siev : N vector : final value of siev for each node
siiv : N vector : final value of siiv for each node
:param params: Parameter dictionary of the model
:type params: dict
:return: Integrated activity variables of the model
:rtype: (numpy.ndarray,)
"""

dt = params["dt"] # Time step for the Euler intergration (ms)
duration = params["duration"] # imulation duration (ms)
RNGseed = params["seed"] # seed for RNG
# set to 0 for faster computation

# ------------------------------------------------------------------------
# global coupling parameters

# Connectivity matric
# Interareal relative coupling strengths (values between 0 and 1), Cmat(i,j) connnection from jth to ith
Cmat = params["Cmat"]
c_gl = params["c_gl"] # EPSP amplitude between areas
Ke_gl = params["Ke_gl"] # number of incoming E connections (to E population) from each area

N = len(Cmat) # Number of areas

# Interareal connection delay
lengthMat = params["lengthMat"]
signalV = params["signalV"]

if N == 1:
Dmat = np.ones((N, N)) * params["de"]
else:
Dmat = mu.computeDelayMatrix(
lengthMat, signalV
) # Interareal connection delays, Dmat(i,j) Connnection from jth node to ith (ms)
Dmat[np.eye(len(Dmat)) == 1] = np.ones(len(Dmat)) * params["de"]

Dmat_ndt = np.around(Dmat / dt).astype(int) # delay matrix in multiples of dt

# ------------------------------------------------------------------------

# local network (area) parameters [identical for all areas for now]

### model parameters
filter_sigma = params["filter_sigma"]

# distributed delay between areas, not tested, but should work
# distributed delay is implemented by a convolution with the delay kernel
# the convolution is represented as a linear ODE with the timescale that
# corresponds to the width of the delay distribution
distr_delay = params["distr_delay"]

# external input parameters:
tau_ou = params["tau_ou"] # Parameter of the Ornstein-Uhlenbeck process for the external input(ms)
# Parameter of the Ornstein-Uhlenbeck (OU) process for the external input ( mV/ms/sqrt(ms) )
sigma_ou = params["sigma_ou"]
mue_ext_mean = params["mue_ext_mean"] # Mean external excitatory input (OU process) (mV/ms)
mui_ext_mean = params["mui_ext_mean"] # Mean external inhibitory input (OU process) (mV/ms)
sigmae_ext = params["sigmae_ext"] # External exc input standard deviation ( mV/sqrt(ms) )
sigmai_ext = params["sigmai_ext"] # External inh input standard deviation ( mV/sqrt(ms) )

# recurrent coupling parameters
Ke = params["Ke"] # Recurrent Exc coupling. "EE = IE" assumed for act_dep_coupling in current implementation
Ki = params["Ki"] # Recurrent Exc coupling. "EI = II" assumed for act_dep_coupling in current implementation

# Recurrent connection delays
de = params["de"] # Local constant delay "EE = IE" (ms)
di = params["di"] # Local constant delay "EI = II" (ms)

tau_se = params["tau_se"] # Synaptic decay time constant for exc. connections "EE = IE" (ms)
tau_si = params["tau_si"] # Synaptic decay time constant for inh. connections "EI = II" (ms)
tau_de = params["tau_de"]
tau_di = params["tau_di"]

cee = params["cee"] # strength of exc. connection
# -> determines ePSP magnitude in state-dependent way (in the original model)
cie = params["cie"] # strength of inh. connection
# -> determines iPSP magnitude in state-dependent way (in the original model)
cei = params["cei"]
cii = params["cii"]

# Recurrent connections coupling strength
Jee_max = params["Jee_max"] # ( mV/ms )
Jei_max = params["Jei_max"] # ( mV/ms )
Jie_max = params["Jie_max"] # ( mV/ms )
Jii_max = params["Jii_max"] # ( mV/ms )

# rescales c's here: multiplication with tau_se makes
# the increase of s subject to a single input spike invariant to tau_se
# division by J ensures that mu = J*s will result in a PSP of exactly c
# for a single spike!

cee = cee * tau_se / Jee_max # ms
cie = cie * tau_se / Jie_max # ms
cei = cei * tau_si / abs(Jei_max) # ms
cii = cii * tau_si / abs(Jii_max) # ms
c_gl = c_gl * tau_se / Jee_max # ms

# neuron model parameters
a = params["a"] # Adaptation coupling term ( nS )
b = params["b"] # Spike triggered adaptation ( pA )
EA = params["EA"] # Adaptation reversal potential ( mV )
tauA = params["tauA"] # Adaptation time constant ( ms )
# if params below are changed, preprocessing required
C = params["C"] # membrane capacitance ( pF )
gL = params["gL"] # Membrane conductance ( nS )
EL = params["EL"] # Leak reversal potential ( mV )
DeltaT = params["DeltaT"] # Slope factor ( EIF neuron ) ( mV )
VT = params["VT"] # Effective threshold (in exp term of the aEIF model)(mV)
Vr = params["Vr"] # Membrane potential reset value (mV)
Vs = params["Vs"] # Cutoff or spike voltage value, determines the time of spike (mV)
Tref = params["Tref"] # Refractory time (ms)
taum = C / gL # membrane time constant

# ------------------------------------------------------------------------

# Lookup tables for the transfer functions
precalc_r, precalc_V, precalc_tau_mu, precalc_tau_sigma = (
params["precalc_r"],
params["precalc_V"],
params["precalc_tau_mu"],
params["precalc_tau_sigma"],
)

# parameter for the lookup tables
dI = params["dI"]
ds = params["ds"]
sigmarange = params["sigmarange"]
Irange = params["Irange"]

# Initialization
# Floating point issue in np.arange() workaraound: use integers in np.arange()
t = np.arange(1, round(duration, 6) / dt + 1) * dt # Time variable (ms)
sqrt_dt = np.sqrt(dt)

ndt_de = np.around(de / dt).astype(int)
ndt_di = np.around(di / dt).astype(int)

rd_exc = np.zeros((N, N)) # kHz rd_exc(i,j): Connection from jth node to ith
rd_inh = np.zeros(N)

# Already done above when Dmat_ndt is built
# for l in range(N):
# Dmat_ndt[l, l] = ndt_de # if no distributed, this is a fixed value (E-E coupling)

max_global_delay = max(np.max(Dmat_ndt), ndt_de, ndt_di)
startind = int(max_global_delay + 1)

# state variable arrays, have length of t + startind
# they store initial conditions AND simulated data
rates_exc = np.zeros((N, startind + len(t)))
rates_inh = np.zeros((N, startind + len(t)))
IA = np.zeros((N, startind + len(t)))

# ------------------------------------------------------------------------
# Set initial values
mufe = params["mufe_init"].copy() # Filtered mean input (mu) for exc. population
mufi = params["mufi_init"].copy() # Filtered mean input (mu) for inh. population
IA_init = params["IA_init"].copy() # Adaptation current (pA)
seem = params["seem_init"].copy() # Mean exc synaptic input
seim = params["seim_init"].copy()
seev = params["seev_init"].copy() # Exc synaptic input variance
seiv = params["seiv_init"].copy()
siim = params["siim_init"].copy() # Mean inh synaptic input
siem = params["siem_init"].copy()
siiv = params["siiv_init"].copy() # Inh synaptic input variance
siev = params["siev_init"].copy()

mue_ou = params["mue_ou"].copy() # Mean of external exc OU input (mV/ms)
mui_ou = params["mui_ou"].copy() # Mean of external inh ON inout (mV/ms)

# Set the initial firing rates.
# if initial values are just a Nx1 array
if np.shape(params["rates_exc_init"])[1] == 1:
# repeat the 1-dim value stardind times
rates_exc_init = np.dot(params["rates_exc_init"], np.ones((1, startind))) # kHz
rates_inh_init = np.dot(params["rates_inh_init"], np.ones((1, startind))) # kHz
# set initial adaptation current
IA_init = np.dot(params["IA_init"], np.ones((1, startind)))
# if initial values are a Nxt array
else:
rates_exc_init = params["rates_exc_init"][:, -startind:]
rates_inh_init = params["rates_inh_init"][:, -startind:]
IA_init = params["IA_init"][:, -startind:]

np.random.seed(RNGseed)

# Save the noise in the rates array to save memory
rates_exc[:, startind:] = np.random.standard_normal((N, len(t)))
rates_inh[:, startind:] = np.random.standard_normal((N, len(t)))

# Set the initial conditions
rates_exc[:, :startind] = rates_exc_init
rates_inh[:, :startind] = rates_inh_init
IA[:, :startind] = IA_init

noise_exc = np.zeros((N,))
noise_inh = np.zeros((N,))

# tile external inputs to appropriate shape
ext_exc_current = mu.adjustArrayShape(params["ext_exc_current"], rates_exc)
ext_inh_current = mu.adjustArrayShape(params["ext_inh_current"], rates_exc)
ext_exc_rate = mu.adjustArrayShape(params["ext_exc_rate"], rates_exc)
ext_inh_rate = mu.adjustArrayShape(params["ext_inh_rate"], rates_exc)

# ------------------------------------------------------------------------

return (
dt,
duration,
distr_delay,
filter_sigma,
Cmat,
Dmat,
c_gl,
Ke_gl,
tau_ou,
sigma_ou,
mue_ext_mean,
mui_ext_mean,
sigmae_ext,
sigmai_ext,
Ke,
Ki,
de,
di,
tau_se,
tau_si,
tau_de,
tau_di,
cee,
cie,
cii,
cei,
Jee_max,
Jei_max,
Jie_max,
Jii_max,
a,
b,
EA,
tauA,
C,
gL,
EL,
DeltaT,
VT,
Vr,
Vs,
Tref,
taum,
mufe,
mufi,
IA,
seem,
seim,
seev,
seiv,
siim,
siem,
siiv,
siev,
precalc_r,
precalc_V,
precalc_tau_mu,
precalc_tau_sigma,
dI,
ds,
sigmarange,
Irange,
N,
Dmat_ndt,
t,
rates_exc,
rates_inh,
rd_exc,
rd_inh,
sqrt_dt,
startind,
ndt_de,
ndt_di,
mue_ou,
mui_ou,
ext_exc_rate,
ext_inh_rate,
ext_exc_current,
ext_inh_current,
noise_exc,
noise_inh,
)



@numba.njit(locals={"idxX": numba.int64, "idxY": numba.int64, "idx1": numba.int64, "idy1": numba.int64})
def timeIntegration_njit_elementwise(
Expand Down
1 change: 1 addition & 0 deletions neurolib/models/jax/aln/__init__.py
@@ -0,0 +1 @@
from .model import ALNModel
55 changes: 55 additions & 0 deletions neurolib/models/jax/aln/loadDefaultParams.py
@@ -0,0 +1,55 @@
import jax.numpy as jnp
from jax import random

from ...aln.loadDefaultParams import loadDefaultParams as loadDefaultParams_numpy


def loadDefaultParams(Cmat=None, Dmat=None, lookupTableFileName=None, seed=None):
"""Load default parameters for a network of aLN nodes.
:param Cmat: Structural connectivity matrix (adjacency matrix) of coupling strengths, will be normalized to 1. If not given, then a single node simulation will be assumed, defaults to None
:type Cmat: numpy.ndarray, optional
:param Dmat: Fiber length matrix, will be used for computing the delay matrix together with the signal transmission speed parameter `signalV`, defaults to None
:type Dmat: numpy.ndarray, optional
:param lookUpTableFileName: Filename of lookup table with aln non-linear transfer functions and other precomputed quantities., defaults to aln-precalc/quantities_cascade.h
:type lookUpTableFileName: str, optional
:param seed: Seed for the random number generator, defaults to None
:type seed: int, optional
:return: A dictionary with the default parameters of the model
:rtype: dict
"""

params = loadDefaultParams_numpy(Cmat, Dmat, lookupTableFileName, seed)

# Use JAX's PRNGKey for RNG
key = random.PRNGKey(seed) if seed is not None else random.PRNGKey(0)
params.key = key

params.Cmat = jnp.array(params.Cmat)
params.lengthMat = jnp.array(params.lengthMat)

params.mue_ou = jnp.array(params.mue_ou)
params.mui_ou = jnp.array(params.mui_ou)

params.mufe_init = jnp.array(params.mufe_init)
params.mufi_init = jnp.array(params.mufi_init)
params.IA_init = jnp.array(params.IA_init)
params.seem_init = jnp.array(params.seem_init)
params.seim_init = jnp.array(params.seim_init)
params.seev_init = jnp.array(params.seev_init)
params.seiv_init = jnp.array(params.seiv_init)
params.siim_init = jnp.array(params.siim_init)
params.siem_init = jnp.array(params.siem_init)
params.siiv_init = jnp.array(params.siiv_init)
params.siev_init = jnp.array(params.siev_init)
params.rates_exc_init = jnp.array(params.rates_exc_init)
params.rates_inh_init = jnp.array(params.rates_inh_init)

params.Irange = jnp.array(params.Irange)
params.sigmarange = jnp.array(params.sigmarange)
params.precalc_r = jnp.array(params.precalc_r)
params.precalc_V = jnp.array(params.precalc_V)
params.precalc_tau_mu = jnp.array(params.precalc_tau_mu)
params.precalc_tau_sigma = jnp.array(params.precalc_tau_sigma)

return params

0 comments on commit 0a682c0

Please sign in to comment.