From 80f55c7d3c8e6cae519df4e2c6a29e26599b0292 Mon Sep 17 00:00:00 2001 From: David Symonds Date: Fri, 18 Sep 2020 18:32:03 +1000 Subject: [PATCH] chore(spanner/spansql): use ID type for identifiers throughout (#2889) This was a long-planned cleanup, which adds some rigour to the type definitions, and avoids a bunch of unnecessary casting in various places. --- spanner/spannertest/db.go | 42 ++++++++++++++--------------- spanner/spannertest/db_eval.go | 14 +++++----- spanner/spannertest/db_query.go | 6 ++--- spanner/spannertest/db_test.go | 32 +++++++++++----------- spanner/spannertest/inmem.go | 21 ++++++++++----- spanner/spansql/parser.go | 16 +++++------ spanner/spansql/parser_test.go | 22 +++++++-------- spanner/spansql/sql.go | 36 ++++++++++++------------- spanner/spansql/sql_test.go | 2 +- spanner/spansql/types.go | 48 ++++++++++++++++----------------- 10 files changed, 123 insertions(+), 116 deletions(-) diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go index 4f24b12f4a0..8c8fbe711d7 100644 --- a/spanner/spannertest/db.go +++ b/spanner/spannertest/db.go @@ -43,8 +43,8 @@ import ( type database struct { mu sync.Mutex lastTS time.Time // last commit timestamp - tables map[string]*table - indexes map[string]struct{} // only record their existence + tables map[spansql.ID]*table + indexes map[spansql.ID]struct{} // only record their existence rwMu sync.Mutex // held by read-write transactions } @@ -55,9 +55,9 @@ type table struct { // Information about the table columns. // They are reordered on table creation so the primary key columns come first. cols []colInfo - colIndex map[string]int // col name to index - pkCols int // number of primary key columns (may be 0) - pkDesc []bool // whether each primary key column is in descending order + colIndex map[spansql.ID]int // col name to index + pkCols int // number of primary key columns (may be 0) + pkDesc []bool // whether each primary key column is in descending order // Rows are stored in primary key order. rows []row @@ -65,7 +65,7 @@ type table struct { // colInfo represents information about a column in a table or result set. type colInfo struct { - Name string + Name spansql.ID Type spansql.Type AggIndex int // Index+1 of SELECT list for which this is an aggregate value. } @@ -239,10 +239,10 @@ func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status { // Lazy init. if d.tables == nil { - d.tables = make(map[string]*table) + d.tables = make(map[spansql.ID]*table) } if d.indexes == nil { - d.indexes = make(map[string]struct{}) + d.indexes = make(map[spansql.ID]struct{}) } switch stmt := stmt.(type) { @@ -259,7 +259,7 @@ func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status { // TODO: check stmt.Interleave details. // Move primary keys first, preserving their order. - pk := make(map[string]int) + pk := make(map[spansql.ID]int) var pkDesc []bool for i, kp := range stmt.PrimaryKey { pk[kp.Column] = -1000 + i @@ -271,7 +271,7 @@ func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status { }) t := &table{ - colIndex: make(map[string]int), + colIndex: make(map[spansql.ID]int), pkCols: len(pk), pkDesc: pkDesc, } @@ -334,7 +334,7 @@ func (d *database) ApplyDDL(stmt spansql.DDLStmt) *status.Status { } -func (d *database) table(tbl string) (*table, error) { +func (d *database) table(tbl spansql.ID) (*table, error) { d.mu.Lock() defer d.mu.Unlock() @@ -346,7 +346,7 @@ func (d *database) table(tbl string) (*table, error) { } // writeValues executes a write option (Insert, Update, etc.). -func (d *database) writeValues(tx *transaction, tbl string, cols []string, values []*structpb.ListValue, f func(t *table, colIndexes []int, r row) error) error { +func (d *database) writeValues(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue, f func(t *table, colIndexes []int, r row) error) error { if err := tx.checkMutable(); err != nil { return err } @@ -406,7 +406,7 @@ func (d *database) writeValues(tx *transaction, tbl string, cols []string, value return nil } -func (d *database) Insert(tx *transaction, tbl string, cols []string, values []*structpb.ListValue) error { +func (d *database) Insert(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error { return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error { pk := r[:t.pkCols] rowNum, found := t.rowForPK(pk) @@ -418,7 +418,7 @@ func (d *database) Insert(tx *transaction, tbl string, cols []string, values []* }) } -func (d *database) Update(tx *transaction, tbl string, cols []string, values []*structpb.ListValue) error { +func (d *database) Update(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error { return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error { if t.pkCols == 0 { return status.Errorf(codes.InvalidArgument, "cannot update table %s with no columns in primary key", tbl) @@ -437,7 +437,7 @@ func (d *database) Update(tx *transaction, tbl string, cols []string, values []* }) } -func (d *database) InsertOrUpdate(tx *transaction, tbl string, cols []string, values []*structpb.ListValue) error { +func (d *database) InsertOrUpdate(tx *transaction, tbl spansql.ID, cols []spansql.ID, values []*structpb.ListValue) error { return d.writeValues(tx, tbl, cols, values, func(t *table, colIndexes []int, r row) error { pk := r[:t.pkCols] rowNum, found := t.rowForPK(pk) @@ -456,7 +456,7 @@ func (d *database) InsertOrUpdate(tx *transaction, tbl string, cols []string, va // TODO: Replace -func (d *database) Delete(tx *transaction, table string, keys []*structpb.ListValue, keyRanges keyRangeList, all bool) error { +func (d *database) Delete(tx *transaction, table spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, all bool) error { if err := tx.checkMutable(); err != nil { return err } @@ -507,7 +507,7 @@ func (d *database) Delete(tx *transaction, table string, keys []*structpb.ListVa } // readTable executes a read option (Read, ReadAll). -func (d *database) readTable(table string, cols []string, f func(*table, *rawIter, []int) error) (*rawIter, error) { +func (d *database) readTable(table spansql.ID, cols []spansql.ID, f func(*table, *rawIter, []int) error) (*rawIter, error) { t, err := d.table(table) if err != nil { return nil, err @@ -528,7 +528,7 @@ func (d *database) readTable(table string, cols []string, f func(*table, *rawIte return ri, f(t, ri, colIndexes) } -func (d *database) Read(tbl string, cols []string, keys []*structpb.ListValue, keyRanges keyRangeList, limit int64) (rowIter, error) { +func (d *database) Read(tbl spansql.ID, cols []spansql.ID, keys []*structpb.ListValue, keyRanges keyRangeList, limit int64) (rowIter, error) { // The real Cloud Spanner returns an error if the key set is empty by definition. // That doesn't seem to be well-defined, but it is a common error to attempt a read with no keys, // so catch that here and return a representative error. @@ -592,7 +592,7 @@ func (d *database) Read(tbl string, cols []string, keys []*structpb.ListValue, k }) } -func (d *database) ReadAll(tbl string, cols []string, limit int64) (*rawIter, error) { +func (d *database) ReadAll(tbl spansql.ID, cols []spansql.ID, limit int64) (*rawIter, error) { return d.readTable(tbl, cols, func(t *table, ri *rawIter, colIndexes []int) error { for _, r := range t.rows { ri.add(r, colIndexes) @@ -631,7 +631,7 @@ func (t *table) addColumn(cd spansql.ColumnDef, newTable bool) *status.Status { return nil } -func (t *table) dropColumn(name string) *status.Status { +func (t *table) dropColumn(name spansql.ID) *status.Status { // Only permit dropping non-key columns that aren't part of a secondary index. // We don't support indexes, so only check that it isn't part of the primary key. @@ -753,7 +753,7 @@ func (t *table) findRange(r *keyRange) (int, int) { } // colIndexes returns the indexes for the named columns. -func (t *table) colIndexes(cols []string) ([]int, error) { +func (t *table) colIndexes(cols []spansql.ID) ([]int, error) { var is []int for _, col := range cols { i, ok := t.colIndex[col] diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go index 25b0a4cd47b..48cf5837ac0 100644 --- a/spanner/spannertest/db_eval.go +++ b/spanner/spannertest/db_eval.go @@ -37,7 +37,7 @@ type evalContext struct { row row // If there are visible aliases, they are populated here. - aliases map[string]spansql.Expr + aliases map[spansql.ID]spansql.Expr params queryParams } @@ -451,23 +451,23 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) { func (ec evalContext) evalID(id spansql.ID) (interface{}, error) { for i, col := range ec.cols { - if col.Name == string(id) { + if col.Name == id { return ec.row.copyDataElem(i), nil } } - if e, ok := ec.aliases[string(id)]; ok { + if e, ok := ec.aliases[id]; ok { // Make a copy of the context without this alias // to prevent an evaluation cycle. innerEC := ec - innerEC.aliases = make(map[string]spansql.Expr) + innerEC.aliases = make(map[spansql.ID]spansql.Expr) for alias, e := range ec.aliases { - if alias != string(id) { + if alias != id { innerEC.aliases[alias] = e } } return innerEC.evalExpr(e) } - return nil, fmt.Errorf("couldn't resolve identifier %s", string(id)) + return nil, fmt.Errorf("couldn't resolve identifier %s", id) } func (ec evalContext) coerceComparisonOpArgs(co spansql.ComparisonOp) (spansql.ComparisonOp, error) { @@ -684,7 +684,7 @@ func (ec evalContext) colInfo(e spansql.Expr) (colInfo, error) { case spansql.ID: // TODO: support more than only naming a table column. for _, col := range ec.cols { - if col.Name == string(e) { + if col.Name == e { return col, nil } } diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go index 5b220510c1a..173eea46e87 100644 --- a/spanner/spannertest/db_query.go +++ b/spanner/spannertest/db_query.go @@ -269,7 +269,7 @@ type queryParam struct { Type spansql.Type } -type queryParams map[string]queryParam +type queryParams map[string]queryParam // TODO: change key to spansql.Param? func (d *database) Query(q spansql.Query, params queryParams) (rowIter, error) { // If there's an ORDER BY clause, extend the query to include the expressions we need @@ -377,7 +377,7 @@ func (d *database) evalSelect(sel spansql.Select, params queryParams) (ri rowIte var rowGroups [][2]int // Sequence of half-open intervals of row numbers. if len(sel.GroupBy) > 0 { // Load aliases visible to this GROUP BY. - ec.aliases = make(map[string]spansql.Expr) + ec.aliases = make(map[spansql.ID]spansql.Expr) for i, alias := range sel.ListAliases { ec.aliases[alias] = sel.List[i] } @@ -525,7 +525,7 @@ func (d *database) evalSelect(sel spansql.Select, params queryParams) (ri rowIte aggType = int64Type } rawOut.cols = append(raw.cols, colInfo{ - Name: fexpr.SQL(), + Name: spansql.ID(fexpr.SQL()), // TODO: this is a bit hokey, but it is output only Type: aggType, AggIndex: aggI + 1, }) diff --git a/spanner/spannertest/db_test.go b/spanner/spannertest/db_test.go index 1c42e6fabf2..81571d0ed35 100644 --- a/spanner/spannertest/db_test.go +++ b/spanner/spannertest/db_test.go @@ -63,7 +63,7 @@ func TestTableCreation(t *testing.T) { {Name: "Cool", Type: spansql.Type{Base: spansql.Bool}}, {Name: "Height", Type: spansql.Type{Base: spansql.Float64}}, }, - colIndex: map[string]int{ + colIndex: map[spansql.ID]int{ "Tenure": 2, "ID": 1, "Cool": 3, "Name": 0, "Height": 4, }, pkCols: 2, @@ -89,7 +89,7 @@ func TestTableData(t *testing.T) { // Insert a subset of columns. tx := db.NewTransaction() tx.Start() - err := db.Insert(tx, "Staff", []string{"ID", "Name", "Tenure", "Height"}, []*structpb.ListValue{ + err := db.Insert(tx, "Staff", []spansql.ID{"ID", "Name", "Tenure", "Height"}, []*structpb.ListValue{ // int64 arrives as a decimal string. listV(stringV("1"), stringV("Jack"), stringV("10"), floatV(1.85)), listV(stringV("2"), stringV("Daniel"), stringV("11"), floatV(1.83)), @@ -98,7 +98,7 @@ func TestTableData(t *testing.T) { t.Fatalf("Inserting data: %v", err) } // Insert a different set of columns. - err = db.Insert(tx, "Staff", []string{"Name", "ID", "Cool", "Tenure", "Height"}, []*structpb.ListValue{ + err = db.Insert(tx, "Staff", []spansql.ID{"Name", "ID", "Cool", "Tenure", "Height"}, []*structpb.ListValue{ listV(stringV("Sam"), stringV("3"), boolV(false), stringV("9"), floatV(1.75)), listV(stringV("Teal'c"), stringV("4"), boolV(true), stringV("8"), floatV(1.91)), listV(stringV("George"), stringV("5"), nullV(), stringV("6"), floatV(1.73)), @@ -113,7 +113,7 @@ func TestTableData(t *testing.T) { t.Fatalf("Deleting a row: %v", err) } // Turns out this guy isn't cool after all. - err = db.Update(tx, "Staff", []string{"Name", "ID", "Cool"}, []*structpb.ListValue{ + err = db.Update(tx, "Staff", []spansql.ID{"Name", "ID", "Cool"}, []*structpb.ListValue{ // Missing columns should be left alone. listV(stringV("Daniel"), stringV("2"), boolV(false)), }) @@ -125,7 +125,7 @@ func TestTableData(t *testing.T) { } // Read some specific keys. - ri, err := db.Read("Staff", []string{"Name", "Tenure"}, []*structpb.ListValue{ + ri, err := db.Read("Staff", []spansql.ID{"Name", "Tenure"}, []*structpb.ListValue{ listV(stringV("George"), stringV("5")), listV(stringV("Harry"), stringV("6")), // Missing key should be silently ignored. listV(stringV("Sam"), stringV("3")), @@ -143,7 +143,7 @@ func TestTableData(t *testing.T) { t.Errorf("Read data by keys wrong.\n got %v\nwant %v", all, wantAll) } // Read the same, but by key range. - ri, err = db.Read("Staff", []string{"Name", "Tenure"}, nil, keyRangeList{ + ri, err = db.Read("Staff", []spansql.ID{"Name", "Tenure"}, nil, keyRangeList{ {start: listV(stringV("Gabriel")), end: listV(stringV("Harpo"))}, // open/open { // closed/open @@ -162,7 +162,7 @@ func TestTableData(t *testing.T) { } // Read a subset of all rows, with a limit. - ri, err = db.ReadAll("Staff", []string{"Tenure", "Name", "Height"}, 4) + ri, err = db.ReadAll("Staff", []spansql.ID{"Tenure", "Name", "Height"}, 4) if err != nil { t.Fatalf("ReadAll: %v", err) } @@ -209,7 +209,7 @@ func TestTableData(t *testing.T) { } tx = db.NewTransaction() tx.Start() - err = db.Update(tx, "Staff", []string{"Name", "ID", "FirstSeen", "To"}, []*structpb.ListValue{ + err = db.Update(tx, "Staff", []spansql.ID{"Name", "ID", "FirstSeen", "To"}, []*structpb.ListValue{ listV(stringV("Jack"), stringV("1"), stringV("1994-10-28"), nullV()), listV(stringV("Daniel"), stringV("2"), stringV("1994-10-28"), nullV()), listV(stringV("George"), stringV("5"), stringV("1997-07-27"), stringV("2008-07-29T11:22:43Z")), @@ -225,7 +225,7 @@ func TestTableData(t *testing.T) { // The queries below ensure that this was all deleted. tx = db.NewTransaction() tx.Start() - err = db.Insert(tx, "Staff", []string{"Name", "ID"}, []*structpb.ListValue{ + err = db.Insert(tx, "Staff", []spansql.ID{"Name", "ID"}, []*structpb.ListValue{ listV(stringV("01"), stringV("1")), listV(stringV("03"), stringV("3")), listV(stringV("06"), stringV("6")), @@ -245,7 +245,7 @@ func TestTableData(t *testing.T) { t.Fatalf("Committing changes: %v", err) } // Re-add the data and delete with DML. - err = db.Insert(tx, "Staff", []string{"Name", "ID"}, []*structpb.ListValue{ + err = db.Insert(tx, "Staff", []spansql.ID{"Name", "ID"}, []*structpb.ListValue{ listV(stringV("01"), stringV("1")), listV(stringV("03"), stringV("3")), listV(stringV("06"), stringV("6")), @@ -292,7 +292,7 @@ func TestTableData(t *testing.T) { } tx = db.NewTransaction() tx.Start() - err = db.Update(tx, "Staff", []string{"Name", "ID", "RawBytes"}, []*structpb.ListValue{ + err = db.Update(tx, "Staff", []spansql.ID{"Name", "ID", "RawBytes"}, []*structpb.ListValue{ // bytes {0x01 0x00 0x01} encode as base-64 AQAB. listV(stringV("Jack"), stringV("1"), stringV("AQAB")), }) @@ -324,7 +324,7 @@ func TestTableData(t *testing.T) { } tx = db.NewTransaction() tx.Start() - err = db.Insert(tx, "PlayerStats", []string{"LastName", "OpponentID", "PointsScored"}, []*structpb.ListValue{ + err = db.Insert(tx, "PlayerStats", []spansql.ID{"LastName", "OpponentID", "PointsScored"}, []*structpb.ListValue{ listV(stringV("Adams"), stringV("51"), stringV("3")), listV(stringV("Buchanan"), stringV("77"), stringV("0")), listV(stringV("Coolidge"), stringV("77"), stringV("1")), @@ -607,7 +607,7 @@ func TestTableDescendingKey(t *testing.T) { tx := db.NewTransaction() tx.Start() - err := db.Insert(tx, "Timeseries", []string{"Name", "Observed", "Value"}, []*structpb.ListValue{ + err := db.Insert(tx, "Timeseries", []spansql.ID{"Name", "Observed", "Value"}, []*structpb.ListValue{ listV(stringV("box"), stringV("1"), floatV(1.1)), listV(stringV("cupcake"), stringV("1"), floatV(6)), listV(stringV("box"), stringV("2"), floatV(1.2)), @@ -665,7 +665,7 @@ func TestTableSchemaConvertNull(t *testing.T) { // Populate with data including a NULL for the STRING field. tx := db.NewTransaction() tx.Start() - err := db.Insert(tx, "Songwriters", []string{"ID", "Nickname"}, []*structpb.ListValue{ + err := db.Insert(tx, "Songwriters", []spansql.ID{"ID", "Nickname"}, []*structpb.ListValue{ listV(stringV("6"), stringV("Tiger")), listV(stringV("7"), nullV()), }) @@ -814,7 +814,7 @@ func TestConcurrentReadInsert(t *testing.T) { // Insert some initial data. tx := db.NewTransaction() tx.Start() - err := db.Insert(tx, "Tablino", []string{"A"}, []*structpb.ListValue{ + err := db.Insert(tx, "Tablino", []spansql.ID{"A"}, []*structpb.ListValue{ listV(stringV("1")), listV(stringV("2")), listV(stringV("4")), @@ -850,7 +850,7 @@ func TestConcurrentReadInsert(t *testing.T) { tx := db.NewTransaction() tx.Start() - err := db.Insert(tx, "Tablino", []string{"A"}, []*structpb.ListValue{ + err := db.Insert(tx, "Tablino", []spansql.ID{"A"}, []*structpb.ListValue{ listV(stringV("3")), }) if err != nil { diff --git a/spanner/spannertest/inmem.go b/spanner/spannertest/inmem.go index ba2b2d874f6..a902ae28a58 100644 --- a/spanner/spannertest/inmem.go +++ b/spanner/spannertest/inmem.go @@ -565,10 +565,10 @@ func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Span var ri rowIter if req.KeySet.All { s.logf("Reading all from %s (cols: %v)", req.Table, req.Columns) - ri, err = s.db.ReadAll(req.Table, req.Columns, req.Limit) + ri, err = s.db.ReadAll(spansql.ID(req.Table), idList(req.Columns), req.Limit) } else { s.logf("Reading rows from %d keys and %d ranges from %s (cols: %v)", len(req.KeySet.Keys), len(req.KeySet.Ranges), req.Table, req.Columns) - ri, err = s.db.Read(req.Table, req.Columns, req.KeySet.Keys, makeKeyRangeList(req.KeySet.Ranges), req.Limit) + ri, err = s.db.Read(spansql.ID(req.Table), idList(req.Columns), req.KeySet.Keys, makeKeyRangeList(req.KeySet.Ranges), req.Limit) } if err != nil { return err @@ -593,7 +593,7 @@ func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spa return err } rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{ - Name: ci.Name, + Name: string(ci.Name), Type: st, }) } @@ -686,19 +686,19 @@ func (s *server) Commit(ctx context.Context, req *spannerpb.CommitRequest) (resp return nil, fmt.Errorf("unsupported mutation operation type %T", op) case *spannerpb.Mutation_Insert: ins := op.Insert - err := s.db.Insert(tx, ins.Table, ins.Columns, ins.Values) + err := s.db.Insert(tx, spansql.ID(ins.Table), idList(ins.Columns), ins.Values) if err != nil { return nil, err } case *spannerpb.Mutation_Update: up := op.Update - err := s.db.Update(tx, up.Table, up.Columns, up.Values) + err := s.db.Update(tx, spansql.ID(up.Table), idList(up.Columns), up.Values) if err != nil { return nil, err } case *spannerpb.Mutation_InsertOrUpdate: iou := op.InsertOrUpdate - err := s.db.InsertOrUpdate(tx, iou.Table, iou.Columns, iou.Values) + err := s.db.InsertOrUpdate(tx, spansql.ID(iou.Table), idList(iou.Columns), iou.Values) if err != nil { return nil, err } @@ -706,7 +706,7 @@ func (s *server) Commit(ctx context.Context, req *spannerpb.CommitRequest) (resp del := op.Delete ks := del.KeySet - err := s.db.Delete(tx, del.Table, ks.Keys, makeKeyRangeList(ks.Ranges), ks.All) + err := s.db.Delete(tx, spansql.ID(del.Table), ks.Keys, makeKeyRangeList(ks.Ranges), ks.All) if err != nil { return nil, err } @@ -917,3 +917,10 @@ func makeKeyRange(r *spannerpb.KeyRange) *keyRange { } return &kr } + +func idList(ss []string) (ids []spansql.ID) { + for _, s := range ss { + ids = append(ids, spansql.ID(s)) + } + return +} diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index beee8069c8b..e50273d76c6 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -1502,8 +1502,8 @@ func (p *parser) parseForeignKey() (ForeignKey, *parseError) { return fk, nil } -func (p *parser) parseColumnNameList() ([]string, *parseError) { - var list []string +func (p *parser) parseColumnNameList() ([]ID, *parseError) { + var list []ID err := p.parseCommaList(func(p *parser) *parseError { n, err := p.parseTableOrIndexOrColumnName() if err != nil { @@ -1740,9 +1740,9 @@ func (p *parser) parseSelect() (Select, *parseError) { return sel, nil } -func (p *parser) parseSelectList() ([]Expr, []string, *parseError) { +func (p *parser) parseSelectList() ([]Expr, []ID, *parseError) { var list []Expr - var aliases []string // Only set if any aliases are seen. + var aliases []ID // Only set if any aliases are seen. padAliases := func() { for len(aliases) < len(list) { aliases = append(aliases, "") @@ -2297,13 +2297,13 @@ func (p *parser) parseBoolExpr() (BoolExpr, *parseError) { return be, nil } -func (p *parser) parseAlias() (string, *parseError) { +func (p *parser) parseAlias() (ID, *parseError) { // The docs don't specify what lexical token is valid for an alias, // but it seems likely that it is an identifier. return p.parseTableOrIndexOrColumnName() } -func (p *parser) parseTableOrIndexOrColumnName() (string, *parseError) { +func (p *parser) parseTableOrIndexOrColumnName() (ID, *parseError) { /* table_name and column_name and index_name: {a—z|A—Z}[{a—z|A—Z|0—9|_}+] @@ -2314,10 +2314,10 @@ func (p *parser) parseTableOrIndexOrColumnName() (string, *parseError) { return "", tok.err } if tok.typ == quotedID { - return tok.string, nil + return ID(tok.string), nil } // TODO: enforce restrictions - return tok.value, nil + return ID(tok.value), nil } func (p *parser) parseOnDelete() (OnDelete, *parseError) { diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index a9447bdf47a..8b6b4bfaf9b 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -49,7 +49,7 @@ func TestParseQuery(t *testing.T) { RHS: Null, }, }, - ListAliases: []string{"aka"}, + ListAliases: []ID{"aka"}, }, Order: []Order{{ Expr: ID("Age"), @@ -90,7 +90,7 @@ func TestParseQuery(t *testing.T) { }, From: []SelectFrom{SelectFromTable{Table: "PlayerStats"}}, GroupBy: []Expr{ID("FirstName"), ID("LastName")}, - ListAliases: []string{"total_points", "", "surname"}, + ListAliases: []ID{"total_points", "", "surname"}, }, }, }, @@ -109,7 +109,7 @@ func TestParseQuery(t *testing.T) { LHS: ID("l_user_id"), RHS: Param("userID"), }, - ListAliases: []string{"count"}, + ListAliases: []ID{"count"}, }, }, }, @@ -318,7 +318,7 @@ func TestParseDDL(t *testing.T) { Table: "FooBar", Columns: []KeyPart{{Column: "Count", Desc: true}}, Unique: true, - Storing: []string{"Count"}, + Storing: []ID{"Count"}, Interleave: "SomeTable", Position: line(8), }, @@ -334,18 +334,18 @@ func TestParseDDL(t *testing.T) { { Name: "Con1", ForeignKey: ForeignKey{ - Columns: []string{"System"}, + Columns: []ID{"System"}, RefTable: "FooBar", - RefColumns: []string{"System"}, + RefColumns: []ID{"System"}, Position: line(13), }, Position: line(13), }, { ForeignKey: ForeignKey{ - Columns: []string{"System", "RepoPath"}, + Columns: []ID{"System", "RepoPath"}, RefTable: "Stranger", - RefColumns: []string{"Sys", "RPath"}, + RefColumns: []ID{"Sys", "RPath"}, Position: line(15), }, Position: line(15), @@ -377,9 +377,9 @@ func TestParseDDL(t *testing.T) { Alteration: AddConstraint{Constraint: TableConstraint{ Name: "Con2", ForeignKey: ForeignKey{ - Columns: []string{"RepoPath"}, + Columns: []ID{"RepoPath"}, RefTable: "Repos", - RefColumns: []string{"RPath"}, + RefColumns: []ID{"RPath"}, Position: line(23), }, Position: line(23), @@ -523,7 +523,7 @@ func TestParseDDL(t *testing.T) { } } -func tableByName(t *testing.T, ddl *DDL, name string) *CreateTable { +func tableByName(t *testing.T, ddl *DDL, name ID) *CreateTable { t.Helper() for _, stmt := range ddl.List { if ct, ok := stmt.(*CreateTable); ok && ct.Name == name { diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index 2d59ac1205b..deae3135cbb 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -25,7 +25,7 @@ import ( ) func (ct CreateTable) SQL() string { - str := "CREATE TABLE " + ID(ct.Name).SQL() + " (\n" + str := "CREATE TABLE " + ct.Name.SQL() + " (\n" for _, c := range ct.Columns { str += " " + c.SQL() + ",\n" } @@ -41,7 +41,7 @@ func (ct CreateTable) SQL() string { } str += ")" if il := ct.Interleave; il != nil { - str += ",\n INTERLEAVE IN PARENT " + ID(il.Parent).SQL() + " ON DELETE " + il.OnDelete.SQL() + str += ",\n INTERLEAVE IN PARENT " + il.Parent.SQL() + " ON DELETE " + il.OnDelete.SQL() } return str } @@ -54,7 +54,7 @@ func (ci CreateIndex) SQL() string { if ci.NullFiltered { str += " NULL_FILTERED" } - str += " INDEX " + ID(ci.Name).SQL() + " ON " + ID(ci.Table).SQL() + "(" + str += " INDEX " + ci.Name.SQL() + " ON " + ci.Table.SQL() + "(" for i, c := range ci.Columns { if i > 0 { str += ", " @@ -66,21 +66,21 @@ func (ci CreateIndex) SQL() string { str += " STORING (" + idList(ci.Storing) + ")" } if ci.Interleave != "" { - str += ", INTERLEAVE IN " + ID(ci.Interleave).SQL() + str += ", INTERLEAVE IN " + ci.Interleave.SQL() } return str } func (dt DropTable) SQL() string { - return "DROP TABLE " + ID(dt.Name).SQL() + return "DROP TABLE " + dt.Name.SQL() } func (di DropIndex) SQL() string { - return "DROP INDEX " + ID(di.Name).SQL() + return "DROP INDEX " + di.Name.SQL() } func (at AlterTable) SQL() string { - return "ALTER TABLE " + ID(at.Name).SQL() + " " + at.Alteration.SQL() + return "ALTER TABLE " + at.Name.SQL() + " " + at.Alteration.SQL() } func (ac AddColumn) SQL() string { @@ -88,7 +88,7 @@ func (ac AddColumn) SQL() string { } func (dc DropColumn) SQL() string { - return "DROP COLUMN " + ID(dc.Name).SQL() + return "DROP COLUMN " + dc.Name.SQL() } func (ac AddConstraint) SQL() string { @@ -96,7 +96,7 @@ func (ac AddConstraint) SQL() string { } func (dc DropConstraint) SQL() string { - return "DROP CONSTRAINT " + ID(dc.Name).SQL() + return "DROP CONSTRAINT " + dc.Name.SQL() } func (sod SetOnDelete) SQL() string { @@ -114,7 +114,7 @@ func (od OnDelete) SQL() string { } func (ac AlterColumn) SQL() string { - return "ALTER COLUMN " + ID(ac.Name).SQL() + " " + ac.Alteration.SQL() + return "ALTER COLUMN " + ac.Name.SQL() + " " + ac.Alteration.SQL() } func (sct SetColumnType) SQL() string { @@ -144,11 +144,11 @@ func (co ColumnOptions) SQL() string { } func (d *Delete) SQL() string { - return "DELETE FROM " + ID(d.Table).SQL() + " WHERE " + d.Where.SQL() + return "DELETE FROM " + d.Table.SQL() + " WHERE " + d.Where.SQL() } func (cd ColumnDef) SQL() string { - str := ID(cd.Name).SQL() + " " + cd.Type.SQL() + str := cd.Name.SQL() + " " + cd.Type.SQL() if cd.NotNull { str += " NOT NULL" } @@ -161,7 +161,7 @@ func (cd ColumnDef) SQL() string { func (tc TableConstraint) SQL() string { var str string if tc.Name != "" { - str += "CONSTRAINT " + ID(tc.Name).SQL() + " " + str += "CONSTRAINT " + tc.Name.SQL() + " " } str += tc.ForeignKey.SQL() return str @@ -169,7 +169,7 @@ func (tc TableConstraint) SQL() string { func (fk ForeignKey) SQL() string { str := "FOREIGN KEY (" + idList(fk.Columns) - str += ") REFERENCES " + ID(fk.RefTable).SQL() + " (" + str += ") REFERENCES " + fk.RefTable.SQL() + " (" str += idList(fk.RefColumns) + ")" return str } @@ -212,7 +212,7 @@ func (tb TypeBase) SQL() string { } func (kp KeyPart) SQL() string { - str := ID(kp.Column).SQL() + str := kp.Column.SQL() if kp.Desc { str += " DESC" } @@ -252,7 +252,7 @@ func (sel Select) SQL() string { if len(sel.ListAliases) > 0 { alias := sel.ListAliases[i] if alias != "" { - str += " AS " + ID(alias).SQL() + str += " AS " + alias.SQL() } } } @@ -405,10 +405,10 @@ func (f Func) SQL() string { return str } -func idList(l []string) string { +func idList(l []ID) string { var ss []string for _, s := range l { - ss = append(ss, ID(s).SQL()) + ss = append(ss, s.SQL()) } return strings.Join(ss, ", ") } diff --git a/spanner/spansql/sql_test.go b/spanner/spansql/sql_test.go index 8d47b1e7b88..c9c152f65e8 100644 --- a/spanner/spansql/sql_test.go +++ b/spanner/spansql/sql_test.go @@ -250,7 +250,7 @@ func TestSQL(t *testing.T) { RHS: Null, }, }, - ListAliases: []string{"", "banana"}, + ListAliases: []ID{"", "banana"}, }, Order: []Order{{Expr: ID("OCol"), Desc: true}}, Limit: IntegerLiteral(1000), diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index ce020e4488d..f1222d505c0 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -25,12 +25,11 @@ import ( ) // TODO: More Position fields throughout; maybe in Query/Select. -// TODO: Perhaps identifiers in the AST should be ID-typed. // CreateTable represents a CREATE TABLE statement. // https://cloud.google.com/spanner/docs/data-definition-language#create_table type CreateTable struct { - Name string + Name ID Columns []ColumnDef Constraints []TableConstraint PrimaryKey []KeyPart @@ -56,7 +55,7 @@ func (ct *CreateTable) clearOffset() { // TableConstraint represents a constraint on a table. type TableConstraint struct { - Name string // may be empty + Name ID // may be empty ForeignKey ForeignKey Position Position // position of the "CONSTRAINT" or "FOREIGN" token @@ -70,22 +69,22 @@ func (tc *TableConstraint) clearOffset() { // Interleave represents an interleave clause of a CREATE TABLE statement. type Interleave struct { - Parent string + Parent ID OnDelete OnDelete } // CreateIndex represents a CREATE INDEX statement. // https://cloud.google.com/spanner/docs/data-definition-language#create-index type CreateIndex struct { - Name string - Table string + Name ID + Table ID Columns []KeyPart Unique bool NullFiltered bool - Storing []string - Interleave string + Storing []ID + Interleave ID Position Position // position of the "CREATE" token } @@ -98,7 +97,7 @@ func (ci *CreateIndex) clearOffset() { ci.Position.Offset = 0 } // DropTable represents a DROP TABLE statement. // https://cloud.google.com/spanner/docs/data-definition-language#drop_table type DropTable struct { - Name string + Name ID Position Position // position of the "DROP" token } @@ -111,7 +110,7 @@ func (dt *DropTable) clearOffset() { dt.Position.Offset = 0 } // DropIndex represents a DROP INDEX statement. // https://cloud.google.com/spanner/docs/data-definition-language#drop-index type DropIndex struct { - Name string + Name ID Position Position // position of the "DROP" token } @@ -124,7 +123,7 @@ func (di *DropIndex) clearOffset() { di.Position.Offset = 0 } // AlterTable represents an ALTER TABLE statement. // https://cloud.google.com/spanner/docs/data-definition-language#alter_table type AlterTable struct { - Name string + Name ID Alteration TableAlteration Position Position // position of the "ALTER" token @@ -160,12 +159,12 @@ func (SetOnDelete) isTableAlteration() {} func (AlterColumn) isTableAlteration() {} type AddColumn struct{ Def ColumnDef } -type DropColumn struct{ Name string } +type DropColumn struct{ Name ID } type AddConstraint struct{ Constraint TableConstraint } -type DropConstraint struct{ Name string } +type DropConstraint struct{ Name ID } type SetOnDelete struct{ Action OnDelete } type AlterColumn struct { - Name string + Name ID Alteration ColumnAlteration } @@ -195,7 +194,7 @@ const ( // Delete represents a DELETE statement. // https://cloud.google.com/spanner/docs/dml-syntax#delete-statement type Delete struct { - Table string + Table ID Where BoolExpr // TODO: Alias @@ -209,7 +208,7 @@ func (*Delete) isDMLStmt() {} // ColumnDef represents a column definition as part of a CREATE TABLE // or ALTER TABLE statement. type ColumnDef struct { - Name string + Name ID Type Type NotNull bool @@ -234,9 +233,9 @@ type ColumnOptions struct { // ForeignKey represents a foreign key definition as part of a CREATE TABLE // or ALTER TABLE statement. type ForeignKey struct { - Columns []string - RefTable string - RefColumns []string + Columns []ID + RefTable ID + RefColumns []ID Position Position // position of the "FOREIGN" token } @@ -268,7 +267,7 @@ const ( // KeyPart represents a column specification as part of a primary key or index definition. type KeyPart struct { - Column string + Column ID Desc bool } @@ -299,7 +298,7 @@ type Select struct { // If the SELECT list has explicit aliases ("AS alias"), // ListAliases will be populated 1:1 with List; // aliases that are present will be non-empty. - ListAliases []string + ListAliases []ID } // SelectFrom represents the FROM clause of a SELECT. @@ -311,8 +310,8 @@ type SelectFrom interface { // SelectFromTable is a SelectFrom that specifies a table to read from. type SelectFromTable struct { - Table string - Alias string // empty if not aliased + Table ID + Alias ID // empty if not aliased } func (SelectFromTable) isSelectFrom() {} @@ -456,7 +455,7 @@ type IsExpr interface { // Func represents a function call. type Func struct { - Name string + Name string // not ID Args []Expr // TODO: various functions permit as-expressions, which might warrant different types in here. @@ -474,6 +473,7 @@ func (Paren) isBoolExpr() {} // possibly bool func (Paren) isExpr() {} // ID represents an identifier. +// https://cloud.google.com/spanner/docs/lexical#identifiers type ID string func (ID) isBoolExpr() {} // possibly bool