Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plastic synapse weight sign test #973

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 20 additions & 14 deletions models/synapses/stdp_synapse.nestml
Expand Up @@ -33,22 +33,26 @@ References
Stable Hebbian learning from spike timing-dependent
plasticity, Journal of Neuroscience, 20:23,8812--8821
"""
synapse stdp:
synapse stdp_synapse:
state:
w real = 1. @nest::weight # Synaptic weight
w real = 1. @nest::weight # Synaptic weight (> 0 for excitatory and < 0 for inhibitory synapses)
pre_trace real = 0.
post_trace real = 0.

parameters:
d ms = 1 ms @nest::delay # Synaptic transmission delay
lambda real = .01
tau_tr_pre ms = 20 ms
tau_tr_post ms = 20 ms
alpha real = 1
mu_plus real = 1
mu_minus real = 1
Wmax real = 100.
Wmin real = 0.
lambda real = 0.01 # (dimensionless) learning rate for causal updates
alpha real = 1 # relative learning rate for acausal firing
tau_tr_pre ms = 20 ms # time constant of presynaptic trace
tau_tr_post ms = 20 ms # time constant of postsynaptic trace
mu_plus real = 1 # weight dependence exponent for causal updates
mu_minus real = 1 # weight dependence exponent for acausal updates

Wmax real = 100. # maximum absolute value of synaptic weight
Wmin real = 0. # minimum absolute value of synaptic weight

internals:
w_sign real = w / abs(w) # sign of synaptic weight

equations:
pre_trace' = -pre_trace / tau_tr_pre
Expand All @@ -64,16 +68,18 @@ synapse stdp:
onReceive(post_spikes):
post_trace += 1

println("post spike, w_sign = {w_sign}")
# potentiate synapse
w_ real = Wmax * ( w / Wmax + (lambda * ( 1. - ( w / Wmax ) )**mu_plus * pre_trace ))
w = min(Wmax, w_)
w_ real = Wmax * ( abs(w) / Wmax + (lambda * ( 1. - ( abs(w) / Wmax ) )**mu_plus * pre_trace ))
w = w_sign * min(Wmax, w_)

onReceive(pre_spikes):
pre_trace += 1
println("pre spike, w_sign = {w_sign}")

# depress synapse
w_ real = Wmax * ( w / Wmax - ( alpha * lambda * ( w / Wmax )**mu_minus * post_trace ))
w = max(Wmin, w_)
w_ real = Wmax * ( abs(w) / Wmax - ( alpha * lambda * ( abs(w) / Wmax )**mu_minus * post_trace ))
w = w_sign * max(Wmin, w_)

# deliver spike to postsynaptic partner
deliver_spike(w, d)
106 changes: 106 additions & 0 deletions tests/nest_tests/test_plastic_synapse_weight_sign.py
@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
#
# test_plastic_synapse_weight_sign.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from typing import Sequence

import numpy as np
import os
import pytest

import nest

from pynestml.codegeneration.nest_tools import NESTTools
from pynestml.frontend.pynestml_frontend import generate_nest_target

try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.ticker
import matplotlib.pyplot as plt
TEST_PLOTS = True
except Exception:
TEST_PLOTS = False


synapse_model_names = ["stdp_synapse"]#, "triplet_stdp_synapse", "stdp_nn_symm", "stdp_nn_restr_symm", "stdp_nn_pre_centered"]

class TestPlasticSynapseWeightSign:
r"""Test that the sign of the weight of plastic synapses never changes (negative stays negative, positive stays positive)"""

neuron_model_name = "iaf_psc_exp"

@pytest.fixture(autouse=True,
scope="module")
def generate_model_code(self):
"""Generate the model code"""

codegen_opts = {"neuron_synapse_pairs": []}

files = [os.path.join("models", "neurons", self.neuron_model_name + ".nestml")]
for synapse_model_name in synapse_model_names:
files.append(os.path.join("models", "synapses", synapse_model_name + ".nestml"))
codegen_opts["neuron_synapse_pairs"].append({"neuron": self.neuron_model_name,
"synapse": synapse_model_name,
"post_ports": ["post_spikes"]})

input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(os.pardir, os.pardir, s))) for s in files]
generate_nest_target(input_path=input_path,
logging_level="DEBUG",
suffix="_nestml",
codegen_opts=codegen_opts)

nest.Install("nestmlmodule")

@pytest.mark.parametrize("synapse_model_name", synapse_model_names)
@pytest.mark.parametrize("test", ["potentiation", "depression"])
def test_nest_stdp_synapse(self, synapse_model_name: str, test: str):
pre_spike_times = np.linspace(100., 1000., 10)

if test == "potentiation":
init_weight = -1.
post_spike_times = pre_spike_times + 10.
else:
init_weight = 1.
post_spike_times = pre_spike_times - 10.

nest.ResetKernel()

# create spike_generators with these times
pre_sg = nest.Create("spike_generator",
params={"spike_times": pre_spike_times,
"allow_offgrid_times": True})
post_sg = nest.Create("spike_generator",
params={"spike_times": post_spike_times,
"allow_offgrid_times": True})

pre_neuron = nest.Create("parrot_neuron")
post_neuron = nest.Create(self.neuron_model_name)

nest.Connect(pre_sg, pre_neuron, syn_spec={"weight": 9999.})
nest.Connect(post_sg, post_neuron, syn_spec={"weight": 9999.})
nest.Connect(pre_neuron, post_neuron, syn_spec={"synapse_model": synapse_model_name,
"weight": init_weight})

syn = nest.GetConnections(source=pre_neuron, synapse_model=synapse_model_name)

nest.Simulate(100. + max(np.amax(pre_spike_times), np.amax(post_spike_times)))

assert np.sign(syn.weight) == 0. # should not pass through zero