-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
defaults, args working; includes not working but open works.
- Loading branch information
Showing
10 changed files
with
364 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
// Copyright (c) 2023, The Emergent Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package econfig | ||
|
||
import ( | ||
"fmt" | ||
"sort" | ||
"strings" | ||
"testing" | ||
|
||
"golang.org/x/exp/maps" | ||
) | ||
|
||
// TestSubConfig is a sub-struct with special params | ||
type TestSubConfig struct { | ||
NPats int `def:"10" desc:"number of patterns to create"` | ||
Sparseness float32 `def:"0.15" desc:"proportion activity of created params"` | ||
} | ||
|
||
// TestConfig is a testing config | ||
type TestConfig struct { | ||
Includes []string `desc:"specify include files here, and after configuration, it contains list of include files added"` | ||
GUI bool `desc:"open the GUI -- does not automatically run -- if false, then runs automatically and quits"` | ||
GPU bool `desc:"use the GPU for computation"` | ||
Debug bool `desc:"log debugging information"` | ||
PatParams TestSubConfig `desc:"important for testing . notation etc"` | ||
Network map[string]any `desc:"network parameters"` | ||
ParamSet string `desc:"ParamSet name to use -- must be valid name as listed in compiled-in params or loaded params"` | ||
ParamFile string `desc:"Name of the JSON file to input saved parameters from."` | ||
ParamDocFile string `desc:"Name of the file to output all parameter data. If not empty string, program should write file(s) and then exit"` | ||
Tag string `desc:"extra tag to add to file names and logs saved from this run"` | ||
Note string `def:"testing is fun" desc:"user note -- describe the run params etc -- like a git commit message for the run"` | ||
Run int `def:"0" desc:"starting run number -- determines the random seed -- runs counts from there -- can do all runs in parallel by launching separate jobs with each run, runs = 1"` | ||
Runs int `def:"10" desc:"total number of runs to do when running Train"` | ||
Epochs int `def:"100" desc:"total number of epochs per run"` | ||
NTrials int `def:"128" desc:"total number of trials per epoch. Should be an even multiple of NData."` | ||
NData int `def:"16" desc:"number of data-parallel items to process in parallel per trial -- works (and is significantly faster) for both CPU and GPU. Results in an effective mini-batch of learning."` | ||
SaveWts bool `short:"s" desc:"if true, save final weights after each run"` | ||
EpochLog bool `def:"true" desc:"if true, save train epoch log to file, as .epc.tsv typically"` | ||
RunLog bool `def:"true" desc:"if true, save run log to file, as .run.tsv typically"` | ||
TrialLog bool `def:"true" desc:"if true, save train trial log to file, as .trl.tsv typically. May be large."` | ||
TestEpochLog bool `def:"false" desc:"if true, save testing epoch log to file, as .tst_epc.tsv typically. In general it is better to copy testing items over to the training epoch log and record there."` | ||
TestTrialLog bool `def:"false" desc:"if true, save testing trial log to file, as .tst_trl.tsv typically. May be large."` | ||
NetData bool `desc:"if true, save network activation etc data from testing trials, for later viewing in netview"` | ||
} | ||
|
||
func TestDefaults(t *testing.T) { | ||
cfg := &TestConfig{} | ||
SetFromDefaults(cfg) | ||
if cfg.Epochs != 100 || cfg.EpochLog != true || cfg.Note != "testing is fun" { | ||
t.Errorf("Main defaults failed to set") | ||
} | ||
if cfg.PatParams.NPats != 10 || cfg.PatParams.Sparseness != 0.15 { | ||
t.Errorf("PatParams defaults failed to set") | ||
} | ||
} | ||
|
||
func TestArgsPrint(t *testing.T) { | ||
// t.Skip("just prints all possible args") | ||
|
||
cfg := &TestConfig{} | ||
longArgs, shortArgs := FieldArgNames(cfg) | ||
|
||
keys := maps.Keys(longArgs) | ||
sort.Slice(keys, func(i, j int) bool { | ||
return strings.ToLower(keys[i]) < strings.ToLower(keys[j]) | ||
}) | ||
fmt.Println("Long Args:") | ||
fmt.Println(strings.Join(keys, "\n")) | ||
|
||
keys = maps.Keys(shortArgs) | ||
sort.Slice(keys, func(i, j int) bool { | ||
return strings.ToLower(keys[i]) < strings.ToLower(keys[j]) | ||
}) | ||
fmt.Println("\n\nShort Args:") | ||
fmt.Println(strings.Join(keys, "\n")) | ||
} | ||
|
||
func TestArgs(t *testing.T) { | ||
cfg := &TestConfig{} | ||
SetFromDefaults(cfg) | ||
// args := []string{"-s", "--runs=5", "--run", "1", "--TAG", "nice", "--Network", "Prjn.Learn.LRate:0.001", "--sparseness=0.1", "leftover1", "leftover2"} | ||
args := []string{"-s", "--runs=5", "--run", "1", "--TAG", "nice", "--PatParams.Sparseness=0.1", "leftover1", "leftover2"} | ||
leftovers, err := parseArgs(cfg, args) | ||
if err != nil { | ||
t.Errorf(err.Error()) | ||
} | ||
fmt.Println(leftovers) | ||
if cfg.Runs != 5 || cfg.Run != 1 || cfg.Tag != "nice" || cfg.PatParams.Sparseness != 0.1 || cfg.SaveWts != true { | ||
t.Errorf("args not set properly") | ||
} | ||
} | ||
|
||
func TestOpen(t *testing.T) { | ||
IncludePaths = []string{".", "testdata"} | ||
cfg := &TestConfig{} | ||
err := Open(cfg, "testcfg.toml") | ||
if err != nil { | ||
t.Errorf(err.Error()) | ||
} | ||
if cfg.GUI != true || cfg.Tag != "testing" { | ||
t.Errorf("testinc.toml not parsed\n") | ||
} | ||
|
||
err = SetFromIncludes(cfg) | ||
if err != nil { | ||
t.Errorf(err.Error()) | ||
} | ||
fmt.Println("includes:", cfg.Includes) | ||
|
||
if cfg.NTrials != 64 { | ||
t.Errorf("testinc.toml not parsed\n") | ||
} | ||
} |
Oops, something went wrong.