-
Notifications
You must be signed in to change notification settings - Fork 7
/
hip_net.go
277 lines (232 loc) · 10.1 KB
/
hip_net.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package axon
import (
"cogentcore.org/core/math32/vecint"
"cogentcore.org/core/tensor/stats/norm"
"github.com/emer/emergent/v2/emer"
"github.com/emer/emergent/v2/etime"
"github.com/emer/emergent/v2/looper"
"github.com/emer/emergent/v2/paths"
)
// HipConfig has the hippocampus size and connectivity parameters.
type HipConfig struct {
	// size of EC2
	EC2Size vecint.Vector2i `nest:"+"`

	// number of EC3 pools (outer dimension)
	EC3NPool vecint.Vector2i `nest:"+"`

	// number of neurons in one EC3 pool
	EC3NNrn vecint.Vector2i `nest:"+"`

	// number of neurons in one CA1 pool
	CA1NNrn vecint.Vector2i `nest:"+"`

	// size of CA3
	CA3Size vecint.Vector2i `nest:"+"`

	// ratio of DG size to CA3 size, applied to each dimension of CA3Size
	DGRatio float32 `default:"2.236"`

	// proportion of connectivity (0-1) from EC3 to EC2
	EC3ToEC2PCon float32 `default:"0.1"`

	// proportion of connectivity (0-1) from EC2 to DG
	EC2ToDGPCon float32 `default:"0.25"`

	// proportion of connectivity (0-1) from EC2 to CA3
	EC2ToCA3PCon float32 `default:"0.25"`

	// proportion of connectivity (0-1) from CA3 to CA1
	CA3ToCA1PCon float32 `default:"0.25"`

	// proportion of connectivity (0-1) into CA3 from DG
	DGToCA3PCon float32 `default:"0.02"`

	// lateral radius of connectivity in EC2
	EC2LatRadius int `default:"2"`

	// lateral gaussian sigma in EC2 for how quickly weights fall off with distance
	EC2LatSigma float32 `default:"2"`

	// proportion of full mossy fiber strength (PathScale.Rel) for CA3 EDL in training, applied at the start of a trial to reduce DG -> CA3 strength. 1 = fully reduce strength, .5 = 50% reduction, etc
	MossyDelta float32 `default:"1"`

	// proportion of full mossy fiber strength (PathScale.Rel) for CA3 EDL in testing, applied during 2nd-3rd quarters to reduce DG -> CA3 strength. 1 = fully reduce strength, .5 = 50% reduction, etc
	MossyDeltaTest float32 `default:"0.75"`

	// low theta modulation value for temporal difference EDL -- sets PathScale.Rel on CA1 <-> EC paths consistent with Theta phase model
	ThetaLow float32 `default:"0.9"`

	// high theta modulation value for temporal difference EDL -- sets PathScale.Rel on CA1 <-> EC paths consistent with Theta phase model
	ThetaHigh float32 `default:"1"`

	// flag for clamping the EC5 from EC5ClampSrc; when set, EC5 is also created as a TargetLayer
	EC5Clamp bool `default:"true"`

	// source layer for EC5 clamping activations in the plus phase -- biologically it is EC3 but can use an Input layer if available
	EC5ClampSrc string `default:"EC3"`

	// clamp the EC5 from EC5ClampSrc during testing as well as training -- this will overwrite any target values that might be used in stats (e.g., in the basic hip example), so it must be turned off there
	EC5ClampTest bool `default:"true"`

	// threshold for binarizing EC5 clamp values -- any value above this is clamped to 1, else 0 -- helps produce a cleaner learning signal. Set to 0 to not perform any binarization.
	EC5ClampThr float32 `default:"0.1"`
}
// Defaults sets every HipConfig field to its standard value: layer sizes,
// random-path connectivity, EC2 lateral inhibition, mossy-fiber / theta
// modulation strengths, and the EC5 clamping options.
func (hip *HipConfig) Defaults() {
	// layer sizes
	hip.EC2Size.Set(21, 21)
	hip.EC3NPool.Set(2, 3)
	hip.EC3NNrn.Set(7, 7)
	hip.CA1NNrn.Set(10, 10) // using MedHip now
	hip.CA3Size.Set(20, 20) // using MedHip now
	hip.DGRatio = 2.236     // c.f. Ketz et al., 2013

	// proportion of connectivity in the random paths
	hip.EC2ToDGPCon, hip.EC2ToCA3PCon, hip.CA3ToCA1PCon = 0.25, 0.25, 0.25
	hip.DGToCA3PCon = 0.02
	hip.EC3ToEC2PCon = 0.1 // 0.1 for EC3-EC2 in WintererMaierWoznyEtAl17, not sure about Input-EC2

	// EC2 lateral inhibition
	hip.EC2LatRadius, hip.EC2LatSigma = 2, 2

	// mossy fiber and theta modulation of path scales
	hip.MossyDelta, hip.MossyDeltaTest = 1, 0.75
	hip.ThetaLow, hip.ThetaHigh = 0.9, 1

	// EC5 plus-phase clamping
	hip.EC5Clamp, hip.EC5ClampTest = true, true
	hip.EC5ClampSrc = "EC3"
	hip.EC5ClampThr = 0.1
}
// AddHip adds a new Hippocampal network for episodic memory.
// It builds the trisynaptic pathway (EC2 -> DG -> CA3 -> CA1) and the
// monosynaptic pathway (EC3 -> CA1 <-> EC5) from the sizes and connectivity
// in hip, then positions the layers, using space as the horizontal gap
// between side-by-side layers. When hip.EC5Clamp is set, EC5 is created as a
// TargetLayer (clamped in the plus phase); otherwise it is a SuperLayer.
// The ctx argument is not used here.
// Returns layers most likely to be used for remaining connections and positions.
func (net *Network) AddHip(ctx *Context, hip *HipConfig, space float32) (ec2, ec3, dg, ca3, ca1, ec5 *Layer) {
	// --- Trisynaptic Pathway (TSP) layers ---
	ec2 = net.AddLayer2D("EC2", hip.EC2Size.Y, hip.EC2Size.X, SuperLayer)
	ec2.SetRepIndexesShape(emer.Layer2DRepIndexes(ec2, 10))

	// DG is CA3 scaled by DGRatio along each dimension (truncated to int).
	dgRows := int(float32(hip.CA3Size.Y) * hip.DGRatio)
	dgCols := int(float32(hip.CA3Size.X) * hip.DGRatio)
	dg = net.AddLayer2D("DG", dgRows, dgCols, SuperLayer)
	dg.SetRepIndexesShape(emer.Layer2DRepIndexes(dg, 10))

	ca3 = net.AddLayer2D("CA3", hip.CA3Size.Y, hip.CA3Size.X, SuperLayer)
	ca3.SetRepIndexesShape(emer.Layer2DRepIndexes(ca3, 10))

	// --- Monosynaptic Pathway (MSP) layers, all sharing EC3's pool grid ---
	poolsY, poolsX := hip.EC3NPool.Y, hip.EC3NPool.X

	ec3 = net.AddLayer4D("EC3", poolsY, poolsX, hip.EC3NNrn.Y, hip.EC3NNrn.X, SuperLayer)
	ec3.AddClass("EC")
	ec3.SetRepIndexesShape(emer.CenterPoolIndexes(ec3, 2), emer.CenterPoolShape(ec3, 2))

	ca1 = net.AddLayer4D("CA1", poolsY, poolsX, hip.CA1NNrn.Y, hip.CA1NNrn.X, SuperLayer)
	ca1.SetRepIndexesShape(emer.CenterPoolIndexes(ca1, 2), emer.CenterPoolShape(ca1, 2))

	ec5Type := SuperLayer
	if hip.EC5Clamp {
		ec5Type = TargetLayer // clamped in plus phase
	}
	ec5 = net.AddLayer4D("EC5", poolsY, poolsX, hip.EC3NNrn.Y, hip.EC3NNrn.X, ec5Type)
	ec5.AddClass("EC")
	ec5.SetRepIndexesShape(emer.CenterPoolIndexes(ec5, 2), emer.CenterPoolShape(ec5, 2))

	// --- Input and EC connections ---
	// Path constructors are invoked in the same order as the connections
	// below, in case they draw from a shared random source.
	oneToOne := paths.NewOneToOne()
	ec3ToEC2Rand := paths.NewUniformRand()
	ec3ToEC2Rand.PCon = hip.EC3ToEC2PCon
	mossyRand := paths.NewUniformRand()
	mossyRand.PCon = hip.DGToCA3PCon
	net.ConnectLayers(ec3, ec2, ec3ToEC2Rand, ForwardPath)
	net.ConnectLayers(ec5, ec3, oneToOne, BackPath)

	// recurrent inhibition in EC2, with topographic (distance-dependent) weights
	circle := paths.NewCircle()
	circle.TopoWts = true
	circle.Radius = hip.EC2LatRadius
	circle.Sigma = hip.EC2LatSigma
	net.ConnectLayers(ec2, ec2, circle, InhibPath).AddClass("InhibLateral")

	// --- TSP connections ---
	ec2ToDGRand := paths.NewUniformRand()
	ec2ToDGRand.PCon = hip.EC2ToDGPCon
	ec2ToCA3Rand := paths.NewUniformRand()
	ec2ToCA3Rand.PCon = hip.EC2ToCA3PCon
	ca3ToCA1Rand := paths.NewUniformRand()
	ca3ToCA1Rand.PCon = hip.CA3ToCA1PCon
	allToAll := paths.NewFull()
	net.ConnectLayers(ec2, dg, ec2ToDGRand, HipPath).AddClass("HippoCHL")
	net.ConnectLayers(ec2, ca3, ec2ToCA3Rand, HipPath).AddClass("PPath")
	net.ConnectLayers(ca3, ca3, allToAll, HipPath).AddClass("PPath")
	net.ConnectLayers(dg, ca3, mossyRand, ForwardPath).AddClass("HippoCHL")
	net.ConnectLayers(ca3, ca1, ca3ToCA1Rand, HipPath).AddClass("HippoCHL")

	// --- MSP connections ---
	poolToPool := paths.NewPoolOneToOne()
	net.ConnectLayers(ec3, ca1, poolToPool, HipPath).AddClass("EcCA1Path")     // HipPath makes wt linear
	net.ConnectLayers(ca1, ec5, poolToPool, ForwardPath).AddClass("EcCA1Path") // doesn't work w/ HipPath
	net.ConnectLayers(ec5, ca1, poolToPool, HipPath).AddClass("EcCA1Path")     // HipPath makes wt linear

	// --- layout: EC3 and EC5 to the right of EC2; DG and CA3 stacked above EC2; CA1 right of CA3 ---
	ec3.PlaceRightOf(ec2, space)
	ec5.PlaceRightOf(ec3, space)
	dg.PlaceAbove(ec2)
	ca3.PlaceAbove(dg)
	ca1.PlaceRightOf(ca3, space)

	return ec2, ec3, dg, ca3, ca1, ec5
}
// ConfigLoopsHip configures the hippocampal looper and should be included in ConfigLoops
// in model to make sure hip loops is configured correctly.
// See hip.go for an instance of implementation of this function.
//
// Handlers are registered on the Train stack's MinusPhase, Beta1 and PlusPhase
// cycle events and on its Trial loop end. The cycle events are shared between
// Train and Test, so the handlers read the current mode from man.Mode at run
// time instead of capturing it when configured.
//
// In the plus phase, EC5 is clamped (when hip.EC5Clamp is set, and in Test
// mode only if hip.EC5ClampTest is also set) from the "Act" values of the layer
// named by hip.EC5ClampSrc. EC3 is the biological source, but an Input layer can
// be used for a simple testing net. Values are binarized at hip.EC5ClampThr when
// it is > 0.
//
// If *pretrain is true at the start of the minus phase, learning is disabled on
// the EC2 -> DG, EC2 -> CA3, CA3 -> CA3 and CA3 -> CA1 paths, and CA3 -> CA1
// PathScale.Abs is set to 0. Otherwise learning is enabled and Abs is restored.
func (net *Network) ConfigLoopsHip(ctx *Context, man *looper.Manager, hip *HipConfig, pretrain *bool) {
	var tmpValues []float32

	clampSrc := net.AxonLayerByName(hip.EC5ClampSrc)
	ec5 := net.AxonLayerByName("EC5")
	ca1 := net.AxonLayerByName("CA1")
	ca3 := net.AxonLayerByName("CA3")
	dg := net.AxonLayerByName("DG")

	dgFromEc2 := dg.SendName("EC2")
	ca1FromEc3 := ca1.SendName("EC3")
	ca1FromCa3 := ca1.SendName("CA3")
	ca3FromDg := ca3.SendName("DG")
	ca3FromEc2 := ca3.SendName("EC2")
	ca3FromCa3 := ca3.SendName("CA3")

	// configured baselines, restored/scaled by the handlers below
	dgPjScale := ca3FromDg.Params.PathScale.Rel
	ca1FromCa3Abs := ca1FromCa3.Params.PathScale.Abs

	// syncScales recomputes the derived scaling factors after a PathScale
	// change and pushes the updated params to the GPU.
	syncScales := func() {
		net.InitGScale(ctx) // update computed scaling factors
		net.GPU.SyncParamsToGPU()
	}

	// configure events -- note that events are shared between Train, Test
	// so only need to do it once on Train. Because of that sharing, the
	// handlers must consult man.Mode, never this configuration-time mode.
	mode := etime.Train
	stack := man.Stacks[mode]
	cyc := stack.Loops[etime.Cycle]

	minusStart, _ := cyc.EventByName("MinusPhase")
	minusStart.OnEvent.Add("HipMinusPhase:Start", func() {
		if *pretrain {
			dgFromEc2.Params.Learn.Learn = 0
			ca3FromEc2.Params.Learn.Learn = 0
			ca3FromCa3.Params.Learn.Learn = 0
			ca1FromCa3.Params.Learn.Learn = 0
			ca1FromCa3.Params.PathScale.Abs = 0
		} else {
			dgFromEc2.Params.Learn.Learn = 1
			ca3FromEc2.Params.Learn.Learn = 1
			ca3FromCa3.Params.Learn.Learn = 1
			ca1FromCa3.Params.Learn.Learn = 1
			ca1FromCa3.Params.PathScale.Abs = ca1FromCa3Abs
		}
		ca1FromEc3.Params.PathScale.Rel = hip.ThetaHigh
		ca1FromCa3.Params.PathScale.Rel = hip.ThetaLow
		ca3FromDg.Params.PathScale.Rel = dgPjScale * (1 - hip.MossyDelta) // turn off DG input to CA3 in first quarter
		syncScales()
	})

	beta1, _ := cyc.EventByName("Beta1")
	beta1.OnEvent.Add("Hip:Beta1", func() {
		ca1FromEc3.Params.PathScale.Rel = hip.ThetaLow
		ca1FromCa3.Params.PathScale.Rel = hip.ThetaHigh
		if man.Mode == etime.Test {
			ca3FromDg.Params.PathScale.Rel = dgPjScale * (1 - hip.MossyDeltaTest)
		}
		syncScales()
	})

	plus, _ := cyc.EventByName("PlusPhase")
	// note: critical for this to come before std start
	plus.OnEvent.InsertBefore("PlusPhase:Start", "HipPlusPhase:Start", func() {
		ca3FromDg.Params.PathScale.Rel = dgPjScale // restore at the beginning of plus phase for CA3 EDL
		ca1FromEc3.Params.PathScale.Rel = hip.ThetaHigh
		ca1FromCa3.Params.PathScale.Rel = hip.ThetaLow
		// clamp EC5 from clamp source (EC3 typically). Use man.Mode: the
		// captured `mode` is always Train, which made EC5ClampTest ineffective.
		if hip.EC5Clamp && (man.Mode != etime.Test || hip.EC5ClampTest) {
			for di := uint32(0); di < ctx.NetIndexes.NData; di++ {
				clampSrc.UnitValues(&tmpValues, "Act", int(di))
				if hip.EC5ClampThr > 0 {
					norm.Binarize32(tmpValues, hip.EC5ClampThr, 1, 0)
				}
				ec5.ApplyExt1D32(ctx, di, tmpValues)
			}
		}
		syncScales()
		net.ApplyExts(ctx) // essential for GPU
	})

	trl := stack.Loops[etime.Trial]
	trl.OnEnd.Prepend("HipPlusPhase:End", func() {
		ca1FromCa3.Params.PathScale.Rel = hip.ThetaHigh
		syncScales()
	})
}