Skip to content

Commit

Permalink
defaults, args working; includes not working but open works.
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoreilly committed Jul 3, 2023
1 parent c9f026f commit 760ed8a
Show file tree
Hide file tree
Showing 10 changed files with 364 additions and 5 deletions.
27 changes: 26 additions & 1 deletion econfig/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ econfig provides methods to set values on a `Config` struct through a (TOML) con
# Special fields, supported types, and field tags

* A limited number of standard field types are supported, consistent with emer neural network usage:
+ `bool` and `[]bool`
+ `float32` and `[]float32`
+ `int` and `[]int`
+ `string` and `[]string`
Expand All @@ -51,7 +52,31 @@ econfig provides methods to set values on a `Config` struct through a (TOML) con
Here's a standard `Config` struct, corresponding to the `AddStd` args from `ecmd`, which can be used as a starting point.

```Go
type Config struct {
// Config is a standard Sim config -- use as a starting point.
// don't forget to update defaults, delete unused fields, etc.
type Config struct {
Includes []string `desc:"specify include files here, and after configuration, it contains list of include files added"`
GUI bool `desc:"open the GUI -- does not automatically run -- if false, then runs automatically and quits"`
GPU bool `desc:"use the GPU for computation"`
Debug bool `desc:"log debugging information"`
Network map[string]any `desc:"network parameters"`
ParamSet string `desc:"ParamSet name to use -- must be valid name as listed in compiled-in params or loaded params"`
ParamFile string `desc:"Name of the JSON file to input saved parameters from."`
ParamDocFile string `desc:"Name of the file to output all parameter data. If not empty string, program should write file(s) and then exit"`
Tag string `desc:"extra tag to add to file names and logs saved from this run"`
Note string `desc:"user note -- describe the run params etc -- like a git commit message for the run"`
Run int `def:"0" desc:"starting run number -- determines the random seed -- runs counts from there -- can do all runs in parallel by launching separate jobs with each run, runs = 1"`
Runs int `def:"10" desc:"total number of runs to do when running Train"`
Epochs int `def:"100" desc:"total number of epochs per run"`
NTrials int `def:"128" desc:"total number of trials per epoch. Should be an even multiple of NData."`
NData int `def:"16" desc:"number of data-parallel items to process in parallel per trial -- works (and is significantly faster) for both CPU and GPU. Results in an effective mini-batch of learning."`
SaveWts bool `desc:"if true, save final weights after each run"`
EpochLog bool `def:"true" desc:"if true, save train epoch log to file, as .epc.tsv typically"`
RunLog bool `def:"true" desc:"if true, save run log to file, as .run.tsv typically"`
TrialLog bool `def:"true" desc:"if true, save train trial log to file, as .trl.tsv typically. May be large."`
TestEpochLog bool `def:"false" desc:"if true, save testing epoch log to file, as .tst_epc.tsv typically. In general it is better to copy testing items over to the training epoch log and record there."`
TestTrialLog bool `def:"false" desc:"if true, save testing trial log to file, as .tst_trl.tsv typically. May be large."`
NetData bool `desc:"if true, save network activation etc data from testing trials, for later viewing in netview"`
}
```

Expand Down
167 changes: 163 additions & 4 deletions econfig/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,35 @@

package econfig

import (
"fmt"
"log"
"os"
"reflect"
"strings"

"github.com/goki/ki/kit"
"github.com/iancoleman/strcase"
)

// SetFromArgs sets Config values from command-line args,
// based on the field names in the Config struct.
// Returns any args that did not start with a `-` flag indicator.
// For more robust error processing, it is assumed that all flagged args (-)
// must refer to fields in the config, so any that fail to match trigger
// an error. Errors can also result from parsing.
// Errors are automatically logged because these are user-facing.
func SetFromArgs(cfg any) ([]string, error) {
return nil, nil
func SetFromArgs(cfg any) (leftovers []string, err error) {
	// Parse everything after the program name; when parsing fails,
	// print the result of Usage(cfg) before handing the error back.
	if leftovers, err = parseArgs(cfg, os.Args[1:]); err != nil {
		fmt.Println(Usage(cfg))
	}
	return leftovers, err
}

// parseArgs does the actual arg parsing
func parseArgs(cfg any, args []string) ([]string, error) {
longArgs, shortArgs := FieldArgNames(cfg)
var leftovers []string
var err error
for len(args) > 0 {
Expand All @@ -37,13 +53,156 @@ func parseArgs(cfg any, args []string) ([]string, error) {
leftovers = append(leftovers, args...)
break
}
// args, err = f.parseLongArg(s, args, fn)
args, err = parseLongArg(s, args, longArgs)
} else {
// args, err = f.parseShortArg(s, args, fn)
args, err = parseShortArg(s, args, shortArgs)
}
if err != nil {
return leftovers, err
}
}
return leftovers, nil
}

// parseLongArg processes one long-form flag s, accepting either
// "--name=value" or "--name value" (the value is then taken from the
// front of args). The name is looked up in longArgs and the value is
// applied through setArgValue. It returns the args remaining after any
// value has been consumed. All errors are logged before being returned.
// NOTE(review): assumes the caller has verified that s starts with "--".
func parseLongArg(s string, args []string, longArgs map[string]reflect.Value) (a []string, err error) {
	a = args
	body := s[2:]
	if body == "" || body[0] == '-' || body[0] == '=' {
		err = fmt.Errorf("SetFromArgs: bad flag syntax: %s", s)
		log.Println(err)
		return a, err
	}

	// Split off an inline "=value" part, if there is one.
	name, value, inline := strings.Cut(body, "=")
	fval, known := longArgs[name]
	if !known {
		err = fmt.Errorf("SetFromArgs: flag name not recognized: %s", name)
		log.Println(err)
		return a, err
	}

	switch {
	case inline:
		// '--flag=arg': value was already extracted above.
	case len(a) > 0:
		// '--flag arg': consume the next argument as the value.
		value, a = a[0], a[1:]
	default:
		// '--flag' (arg was required)
		err = fmt.Errorf("SetFromArgs: flag needs an argument: %s", s)
		log.Println(err)
		return a, err
	}

	return a, setArgValue(name, fval, value)
}

// setArgValue assigns the string value to the field held in fval,
// delegating the conversion to kit.SetRobust. name is used only for the
// error message. A failed assignment is logged and returned as an error.
func setArgValue(name string, fval reflect.Value, value string) error {
	// SetRobust is more general than strictly needed here, but it works.
	if kit.SetRobust(fval.Interface(), value) {
		return nil
	}
	failure := fmt.Errorf("SetFromArgs: not able to set field from arg: %s val: %s", name, value)
	log.Println(failure)
	return failure
}

// parseSingleShortArg processes the first shorthand letter in shorthands
// (the text following a single "-") and sets the matching field found in
// shortArgs.
//
// The value is determined as follows:
//   - "-f=arg": everything after the "=" is the value, and the rest of
//     the shorthand group is consumed.
//   - "-f arg": if the next element of args is non-empty and does not
//     start with "-", it is consumed as the value.
//   - otherwise the value is "true" (a bare switch).
//
// It returns the shorthand letters still to be processed, the args
// remaining after any value was consumed, and any error (already logged).
func parseSingleShortArg(shorthands string, args []string, shortArgs map[string]reflect.Value) (outShorts string, outArgs []string, err error) {
	outArgs = args
	outShorts = shorthands[1:]
	c := string(shorthands[0])

	fval, exists := shortArgs[c]
	if !exists {
		err = fmt.Errorf("SetFromArgs: unknown shorthand flag: %q in -%s", c, shorthands)
		log.Println(err)
		return outShorts, outArgs, err
	}

	var value string
	switch {
	case len(shorthands) > 2 && shorthands[1] == '=':
		// '-f=arg'
		value = shorthands[2:]
		outShorts = ""
	case len(args) > 0 && args[0] != "" && !strings.HasPrefix(args[0], "-"):
		// '-f arg': any non-empty, non-flag argument is a value, including
		// single-character ones such as "5" (previously these were skipped
		// and the flag was set to "true" instead).
		value = args[0]
		outArgs = args[1:]
	default:
		// '-f' with no usable value: treat as a boolean switch.
		value = "true"
	}

	err = setArgValue(c, fval, value)
	return outShorts, outArgs, err
}

// parseShortArg processes a short-form flag argument s (starting with a
// single "-"), which may bundle several shorthand letters (e.g. "-vvv").
// Each letter is handled in turn by parseSingleShortArg. It returns the
// args remaining after all values have been consumed, or the first error.
func parseShortArg(s string, args []string, shortArgs map[string]reflect.Value) (a []string, err error) {
	a = args
	shorthands := s[1:]

	// "shorthands" can be a series of shorthand letters of flags (e.g. "-vvv").
	for len(shorthands) > 0 {
		// Pass the running remainder a, not the original args, so that a
		// value consumed by one shorthand is not seen again by the next.
		shorthands, a, err = parseSingleShortArg(shorthands, a, shortArgs)
		if err != nil {
			return a, err
		}
	}
	return a, nil
}

// FieldArgNames builds two lookup tables for the fields of obj: longArgs
// maps every accepted long flag spelling to the field's reflect.Value,
// and shortArgs does the same for shorthand flags. The tables are filled
// in by FieldArgNamesStruct, starting from an empty path.
func FieldArgNames(obj any) (longArgs, shortArgs map[string]reflect.Value) {
	longArgs = map[string]reflect.Value{}
	shortArgs = map[string]reflect.Value{}
	FieldArgNamesStruct(obj, "", longArgs, shortArgs)
	return longArgs, shortArgs
}

// FieldArgNamesStruct walks the fields of the struct obj and records, in
// longArgs and shortArgs, the flag names under which each field can be
// set. path is the dotted prefix of enclosing field names ("" at the top
// level). Struct-typed fields are descended into recursively, extending
// the path with the field name. Every other field is registered under its
// dotted name as written, lower-cased, and in the kebab, snake and
// screaming-snake forms produced by strcase. A field with a non-empty
// `short` tag is also registered in shortArgs; a duplicate shorthand is
// logged and the later field wins.
func FieldArgNamesStruct(obj any, path string, longArgs, shortArgs map[string]reflect.Value) {
	structType := kit.NonPtrType(reflect.TypeOf(obj))
	structVal := kit.NonPtrValue(reflect.ValueOf(obj))
	for idx := 0; idx < structType.NumField(); idx++ {
		field := structType.Field(idx)
		fieldVal := structVal.Field(idx)

		// Dotted name of this field relative to the top-level config.
		fullName := field.Name
		if path != "" {
			fullName = path + "." + field.Name
		}

		if kit.NonPtrType(field.Type).Kind() == reflect.Struct {
			FieldArgNamesStruct(kit.PtrValue(fieldVal).Interface(), fullName, longArgs, shortArgs)
			continue
		}

		// NOTE(review): kit.PtrValue presumably yields a pointer to the
		// field so that it can be set later -- confirm against kit docs.
		ptr := kit.PtrValue(fieldVal)
		spellings := []string{
			fullName,
			strings.ToLower(fullName),
			strcase.ToKebab(fullName),
			strcase.ToSnake(fullName),
			strcase.ToScreamingSnake(fullName),
		}
		for _, key := range spellings {
			longArgs[key] = ptr
		}

		short, tagged := field.Tag.Lookup("short")
		if !tagged || short == "" {
			continue
		}
		if _, dup := shortArgs[short]; dup {
			log.Println("Short arg named:", short, "already defined")
		}
		shortArgs[short] = ptr
	}
}
1 change: 1 addition & 0 deletions econfig/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func SetFromDefaultsStruct(obj any) error {
fv := val.Field(i)
if kit.NonPtrType(f.Type).Kind() == reflect.Struct {
SetFromDefaultsStruct(kit.PtrValue(fv).Interface())
continue
}
def, ok := f.Tag.Lookup("def")
if !ok || def == "" {
Expand Down
116 changes: 116 additions & 0 deletions econfig/econfig_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright (c) 2023, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package econfig

import (
"fmt"
"sort"
"strings"
"testing"

"golang.org/x/exp/maps"
)

// TestSubConfig is a nested config struct used by the tests: it is
// embedded as the PatParams field of TestConfig so that defaults and
// dotted arg names (e.g. PatParams.Sparseness) are exercised on a
// sub-struct.
type TestSubConfig struct {
// NPats has a "def" tag of 10, which TestDefaults checks.
NPats int `def:"10" desc:"number of patterns to create"`
// Sparseness has a "def" tag of 0.15; TestArgs overrides it to 0.1.
Sparseness float32 `def:"0.15" desc:"proportion activity of created params"`
}

// TestConfig is the config struct used by the tests in this file.
// Field tags: "def" gives the default value as a string, "desc" is the
// field description, and "short" declares a single-letter shorthand flag.
type TestConfig struct {
Includes []string `desc:"specify include files here, and after configuration, it contains list of include files added"`
GUI bool `desc:"open the GUI -- does not automatically run -- if false, then runs automatically and quits"`
GPU bool `desc:"use the GPU for computation"`
Debug bool `desc:"log debugging information"`
// PatParams is a nested struct, reached from args via dotted names.
PatParams TestSubConfig `desc:"important for testing . notation etc"`
Network map[string]any `desc:"network parameters"`
ParamSet string `desc:"ParamSet name to use -- must be valid name as listed in compiled-in params or loaded params"`
ParamFile string `desc:"Name of the JSON file to input saved parameters from."`
ParamDocFile string `desc:"Name of the file to output all parameter data. If not empty string, program should write file(s) and then exit"`
Tag string `desc:"extra tag to add to file names and logs saved from this run"`
Note string `def:"testing is fun" desc:"user note -- describe the run params etc -- like a git commit message for the run"`
Run int `def:"0" desc:"starting run number -- determines the random seed -- runs counts from there -- can do all runs in parallel by launching separate jobs with each run, runs = 1"`
Runs int `def:"10" desc:"total number of runs to do when running Train"`
Epochs int `def:"100" desc:"total number of epochs per run"`
NTrials int `def:"128" desc:"total number of trials per epoch. Should be an even multiple of NData."`
NData int `def:"16" desc:"number of data-parallel items to process in parallel per trial -- works (and is significantly faster) for both CPU and GPU. Results in an effective mini-batch of learning."`
// SaveWts is the only field here with a shorthand flag ("s").
SaveWts bool `short:"s" desc:"if true, save final weights after each run"`
EpochLog bool `def:"true" desc:"if true, save train epoch log to file, as .epc.tsv typically"`
RunLog bool `def:"true" desc:"if true, save run log to file, as .run.tsv typically"`
TrialLog bool `def:"true" desc:"if true, save train trial log to file, as .trl.tsv typically. May be large."`
TestEpochLog bool `def:"false" desc:"if true, save testing epoch log to file, as .tst_epc.tsv typically. In general it is better to copy testing items over to the training epoch log and record there."`
TestTrialLog bool `def:"false" desc:"if true, save testing trial log to file, as .tst_trl.tsv typically. May be large."`
NetData bool `desc:"if true, save network activation etc data from testing trials, for later viewing in netview"`
}

// TestDefaults checks that SetFromDefaults applies the "def" tag values
// to both top-level fields and the fields of the nested PatParams struct.
func TestDefaults(t *testing.T) {
	cfg := &TestConfig{}
	SetFromDefaults(cfg)

	mainOK := cfg.Epochs == 100 && cfg.EpochLog && cfg.Note == "testing is fun"
	if !mainOK {
		t.Errorf("Main defaults failed to set")
	}

	subOK := cfg.PatParams.NPats == 10 && cfg.PatParams.Sparseness == 0.15
	if !subOK {
		t.Errorf("PatParams defaults failed to set")
	}
}

// TestArgsPrint prints every long and short arg name that FieldArgNames
// produces for TestConfig, each list sorted case-insensitively. It makes
// no assertions; add a t.Skip call at the top to silence the output.
func TestArgsPrint(t *testing.T) {
	cfg := &TestConfig{}
	longArgs, shortArgs := FieldArgNames(cfg)

	// printSorted sorts names in place, ignoring case, and prints them
	// one per line under the given header.
	printSorted := func(header string, names []string) {
		sort.Slice(names, func(i, j int) bool {
			return strings.ToLower(names[i]) < strings.ToLower(names[j])
		})
		fmt.Println(header)
		fmt.Println(strings.Join(names, "\n"))
	}

	printSorted("Long Args:", maps.Keys(longArgs))
	printSorted("\n\nShort Args:", maps.Keys(shortArgs))
}

// TestArgs runs parseArgs on a representative command line covering a
// shorthand switch (-s), "--name=value", "--name value", an upper-case
// name, and a dotted sub-struct field, then checks that the matching
// fields were set. The leftover (non-flag) args are printed.
func TestArgs(t *testing.T) {
	cfg := &TestConfig{}
	SetFromDefaults(cfg)
	// args := []string{"-s", "--runs=5", "--run", "1", "--TAG", "nice", "--Network", "Prjn.Learn.LRate:0.001", "--sparseness=0.1", "leftover1", "leftover2"}
	args := []string{"-s", "--runs=5", "--run", "1", "--TAG", "nice", "--PatParams.Sparseness=0.1", "leftover1", "leftover2"}
	leftovers, err := parseArgs(cfg, args)
	if err != nil {
		// Report the error as a value, not as a format string, so that
		// any '%' in the message is printed literally.
		t.Error(err)
	}
	fmt.Println(leftovers)
	if cfg.Runs != 5 || cfg.Run != 1 || cfg.Tag != "nice" || cfg.PatParams.Sparseness != 0.1 || !cfg.SaveWts {
		t.Errorf("args not set properly")
	}
}

// TestOpen opens testcfg.toml (searched for via IncludePaths) into a
// TestConfig and checks the GUI and Tag values, then calls
// SetFromIncludes and checks that NTrials has become 64.
// NOTE(review): the expected values presumably come from the files under
// testdata, which are not visible here -- confirm against those files.
func TestOpen(t *testing.T) {
	IncludePaths = []string{".", "testdata"}
	cfg := &TestConfig{}
	if err := Open(cfg, "testcfg.toml"); err != nil {
		// Pass the error as a value so '%' in the message is not
		// interpreted as a format verb.
		t.Error(err)
	}
	if !cfg.GUI || cfg.Tag != "testing" {
		// This check follows Open of testcfg.toml, so name that file.
		t.Error("testcfg.toml not parsed")
	}

	if err := SetFromIncludes(cfg); err != nil {
		t.Error(err)
	}
	fmt.Println("includes:", cfg.Includes)

	if cfg.NTrials != 64 {
		t.Error("testinc.toml not parsed")
	}
}

0 comments on commit 760ed8a

Please sign in to comment.