-
Notifications
You must be signed in to change notification settings - Fork 24
/
prjn.go
299 lines (245 loc) · 10.5 KB
/
prjn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package emer
import (
"fmt"
"io"
"cogentcore.org/core/laser"
"github.com/emer/emergent/v2/params"
"github.com/emer/emergent/v2/prjn"
"github.com/emer/emergent/v2/weights"
)
// Prjn defines the basic interface for a projection which connects two layers.
// Name is set automatically to: SendLay().Name() + "To" + RecvLay().Name()
//
// Several setters (SetPattern, SetType, AddClass) return the Prjn itself
// so that configuration calls can be chained.
type Prjn interface {
params.Styler // TypeName, Name, and Class methods for parameter styling
// Init MUST be called to initialize the prjn's pointer to itself as an emer.Prjn
// which enables the proper interface methods to be called.
Init(prjn Prjn)
// SendLay returns the sending layer for this projection.
SendLay() Layer
// RecvLay returns the receiving layer for this projection.
RecvLay() Layer
// Pattern returns the pattern of connectivity for interconnecting the layers.
Pattern() prjn.Pattern
// SetPattern sets the pattern of connectivity for interconnecting the layers.
// Returns Prjn so it can be chained to set other properties too.
SetPattern(pat prjn.Pattern) Prjn
// Type returns the functional type of projection according to PrjnType (extensible in
// more specialized algorithms).
Type() PrjnType
// SetType sets the functional type of projection according to PrjnType.
// Returns Prjn so it can be chained to set other properties too.
SetType(typ PrjnType) Prjn
// PrjnTypeName returns the string rep of functional type of projection
// according to PrjnType (extensible in more specialized algorithms, by
// redefining this method as needed).
PrjnTypeName() string

// NOTE(review): Connect is commented out and is therefore NOT part of this
// interface; its original description and signature are retained here for reference:
// Connect sets the basic connection parameters for this projection (send, recv, pattern, and type)
// Connect(send, recv Layer, pat prjn.Pattern, typ PrjnType)

// AddClass adds a CSS-style class name(s) for this prjn,
// ensuring that it is not a duplicate, and properly space separated.
// Returns Prjn so it can be chained to set other properties too.
AddClass(cls ...string) Prjn
// Label satisfies the gi.Labeler interface for getting the name of objects generically.
Label() string
// IsOff returns true if projection or either send or recv layer has been turned Off.
// Useful for experimentation.
IsOff() bool
// SetOff sets the projection Off status (i.e., lesioned). Careful: Layer.SetOff(true) will
// reactivate that layer's projections, so projection-level lesioning should always be called
// after layer-level lesioning.
SetOff(off bool)
// SynVarNames returns the names of all the variables on the synapse.
// This is typically a global list so do not modify!
SynVarNames() []string
// SynVarProps returns a map of synapse variable properties, with the key being the
// name of the variable, and the value gives a space-separated list of
// go-tag-style properties for that variable.
// The NetView recognizes the following properties:
// range:"##" = +- range around 0 for default display scaling
// min:"##" max:"##" = min, max display range
// auto-scale:"+" or "-" = use automatic scaling instead of fixed range or not.
// zeroctr:"+" or "-" = control whether zero-centering is used
// Note: this is a global list so do not modify!
SynVarProps() map[string]string
// SynIndex returns the index of the synapse between given send, recv unit indexes
// (1D, flat indexes). Returns -1 if synapse not found between these two neurons.
// This requires searching within connections for receiving unit (a bit slow).
SynIndex(sidx, ridx int) int
// SynVarIndex returns the index of given variable within the synapse,
// according to *this prjn's* SynVarNames() list (using a map to lookup index),
// or -1 and error message if not found.
SynVarIndex(varNm string) (int, error)
// SynVarNum returns the number of synapse-level variables
// for this prjn. This is needed for extending indexes in derived types.
SynVarNum() int
// Syn1DNum returns the number of synapses for this prjn as a 1D array.
// This is the max idx for SynVal1D and the number of vals set by SynValues.
Syn1DNum() int
// SynVal1D returns value of given variable index (from SynVarIndex) on given SynIndex.
// Returns NaN on invalid index.
// This is the core synapse var access method used by other methods,
// so it is the only one that needs to be updated for derived layer types.
SynVal1D(varIndex int, synIndex int) float32
// SynValues sets values of given variable name for each synapse, using the natural ordering
// of the synapses (sender based for Leabra),
// into given float32 slice (only resized if not big enough).
// Returns error on invalid var name.
SynValues(vals *[]float32, varNm string) error
// SynValue returns value of given variable name on the synapse
// between given send, recv unit indexes (1D, flat indexes).
// Returns math32.NaN() for access errors.
SynValue(varNm string, sidx, ridx int) float32
// SetSynValue sets value of given variable name on the synapse
// between given send, recv unit indexes (1D, flat indexes).
// Typically only supports base synapse variables and is not extended
// for derived types.
// Returns error for access errors.
SetSynValue(varNm string, sidx, ridx int, val float32) error
// Defaults sets default parameter values for all Prjn parameters.
Defaults()
// UpdateParams updates parameter values for all Prjn parameters,
// based on any other params that might have changed.
UpdateParams()
// ApplyParams applies given parameter style Sheet to this projection.
// Calls UpdateParams if anything set to ensure derived parameters are all updated.
// If setMsg is true, then a message is printed to confirm each parameter that is set.
// It always prints a message if a parameter fails to be set.
// Returns true if any params were set, and error if there were any errors.
ApplyParams(pars *params.Sheet, setMsg bool) (bool, error)
// SetParam sets parameter at given path to given value.
// Returns error if path not found or value cannot be set.
SetParam(path, val string) error
// NonDefaultParams returns a listing of all parameters in the Projection that
// are not at their default values -- useful for setting param styles etc.
NonDefaultParams() string
// AllParams returns a listing of all parameters in the Projection.
AllParams() string
// WriteWtsJSON writes the weights from this projection from the receiver-side perspective
// in a JSON text format. We build in the indentation logic to make it much faster and
// more efficient.
WriteWtsJSON(w io.Writer, depth int)
// ReadWtsJSON reads the weights from this projection from the receiver-side perspective
// in a JSON text format. This is for a set of weights that were saved *for one prjn only*
// and is not used for the network-level ReadWtsJSON, which reads into a separate
// structure -- see SetWts method.
ReadWtsJSON(r io.Reader) error
// SetWts sets the weights for this projection from weights.Prjn decoded values.
SetWts(pw *weights.Prjn) error
// Build constructs the full connectivity among the layers as specified in this projection.
Build() error
}
// Prjns is a slice of projections (Prjn interface values).
// Its methods provide element labeling, appending, and lookup of a
// projection by its sending or receiving layer (by value or by name).
type Prjns []Prjn
// ElemLabel satisfies the gi.SliceLabeler interface by producing a display
// label for the element at position idx:
//   - "(empty)" when the list has no elements at all,
//   - "" when idx is outside the valid index range,
//   - "nil" when the element at idx is nil (per laser.AnyIsNil),
//   - otherwise the element's Name().
func (pl *Prjns) ElemLabel(idx int) string {
	list := *pl
	switch n := len(list); {
	case n == 0:
		return "(empty)"
	case idx < 0, idx >= n:
		return ""
	}
	if elem := list[idx]; !laser.AnyIsNil(elem) {
		return elem.Name()
	}
	return "nil"
}
// Add appends projection p to the end of the list, updating the
// slice in place through the pointer receiver.
func (pl *Prjns) Add(p Prjn) {
	grown := append(*pl, p)
	*pl = grown
}
// Send scans the list in order and returns the first projection whose
// sending layer (SendLay) equals send, along with true.
// When no projection matches, it returns nil and false.
func (pl *Prjns) Send(send Layer) (Prjn, bool) {
	list := *pl
	for i := range list {
		if list[i].SendLay() != send {
			continue
		}
		return list[i], true
	}
	return nil, false
}
// Recv scans the list in order and returns the first projection whose
// receiving layer (RecvLay) equals recv, along with true.
// When no projection matches, it returns nil and false.
func (pl *Prjns) Recv(recv Layer) (Prjn, bool) {
	list := *pl
	for i := range list {
		if list[i].RecvLay() != recv {
			continue
		}
		return list[i], true
	}
	return nil, false
}
// SendName looks up the projection whose sending layer is named sender,
// delegating to SendNameTry. It returns nil when nothing matches;
// call SendNameTry directly to receive an error describing the failure.
func (pl *Prjns) SendName(sender string) Prjn {
	// The error is dropped on purpose: a nil result is this method's not-found signal.
	match, _ := pl.SendNameTry(sender)
	return match
}
// RecvName looks up the projection whose receiving layer is named recv,
// delegating to RecvNameTry. It returns nil when nothing matches;
// call RecvNameTry directly to receive an error describing the failure.
func (pl *Prjns) RecvName(recv string) Prjn {
	// The error is dropped on purpose: a nil result is this method's not-found signal.
	match, _ := pl.RecvNameTry(recv)
	return match
}
// SendNameTry scans the list in order and returns the first projection
// whose sending layer's Name() equals sender.
// If no projection matches, it returns nil and an error naming the
// sending layer that was not found.
func (pl *Prjns) SendNameTry(sender string) (Prjn, error) {
	list := *pl
	for i := range list {
		if cand := list[i]; cand.SendLay().Name() == sender {
			return cand, nil
		}
	}
	return nil, fmt.Errorf("sending layer: %v not found in list of projections", sender)
}
// SendNameTypeTry scans the list in order and returns the first projection
// whose sending layer's Name() equals sender and whose PrjnTypeName()
// equals typ. PrjnTypeName is only consulted for projections whose
// sending layer name already matched.
// If no projection matches both, it returns nil and an error naming the
// requested sender and type.
func (pl *Prjns) SendNameTypeTry(sender, typ string) (Prjn, error) {
	for _, cand := range *pl {
		if cand.SendLay().Name() != sender {
			continue
		}
		if cand.PrjnTypeName() == typ {
			return cand, nil
		}
	}
	return nil, fmt.Errorf("sending layer: %v, type: %v not found in list of projections", sender, typ)
}
// RecvNameTry scans the list in order and returns the first projection
// whose receiving layer's Name() equals recv.
// If no projection matches, it returns nil and an error naming the
// receiving layer that was not found.
func (pl *Prjns) RecvNameTry(recv string) (Prjn, error) {
	list := *pl
	for i := range list {
		if cand := list[i]; cand.RecvLay().Name() == recv {
			return cand, nil
		}
	}
	return nil, fmt.Errorf("receiving layer: %v not found in list of projections", recv)
}
// RecvNameTypeTry scans the list in order and returns the first projection
// whose receiving layer's Name() equals recv and whose PrjnTypeName()
// equals typ. PrjnTypeName is only consulted for projections whose
// receiving layer name already matched.
// If no projection matches both, it returns nil and an error naming the
// requested receiver and type.
func (pl *Prjns) RecvNameTypeTry(recv, typ string) (Prjn, error) {
	for _, cand := range *pl {
		if cand.RecvLay().Name() != recv {
			continue
		}
		if cand.PrjnTypeName() == typ {
			return cand, nil
		}
	}
	return nil, fmt.Errorf("receiving layer: %v, type: %v not found in list of projections", recv, typ)
}
//////////////////////////////////////////////////////////////////////////////////////
// PrjnType
// PrjnType is the type of the projection (extensible for more specialized algorithms).
// Class parameter styles automatically key off of these types.
// It is an int32-based enumeration; the base values are declared in the
// const block that follows.
// NOTE(review): the trailing enums:enum directive presumably drives enum
// code generation -- keep it intact; confirm against the generator in use.
type PrjnType int32 //enums:enum
// The projection types.
// Values are assigned sequentially via iota starting at 0, so the
// declaration order below determines each constant's numeric value.
const (
// Forward is a feedforward, bottom-up projection from sensory inputs to higher layers
// (value 0).
Forward PrjnType = iota
// Back is a feedback, top-down projection from higher layers back to lower layers
// (value 1).
Back
// Lateral is a lateral projection within the same layer / area
// (value 2).
Lateral
// Inhib is an inhibitory projection that drives inhibitory synaptic inputs instead of excitatory
// (value 3).
Inhib
)