diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go index 238d46a80ff..b0ca0a40c9b 100644 --- a/spanner/spannertest/db_eval.go +++ b/spanner/spannertest/db_eval.go @@ -428,6 +428,17 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) { return bool(e), nil case spansql.Paren: return ec.evalExpr(e.Expr) + case spansql.Array: + var arr []interface{} + for _, elt := range e { + v, err := ec.evalExpr(elt) + if err != nil { + return nil, err + } + arr = append(arr, v) + } + // TODO: enforce or coerce to consistent types. + return arr, nil case spansql.ArithOp: return ec.evalArithOp(e) case spansql.LogicalOp: @@ -774,6 +785,21 @@ func (ec evalContext) colInfo(e spansql.Expr) (colInfo, error) { return colInfo{Type: qp.Type}, nil case spansql.Paren: return ec.colInfo(e.Expr) + case spansql.Array: + // Assume all element of an array literal have the same type. + if len(e) == 0 { + // TODO: What does the real Spanner do here? + return colInfo{Type: spansql.Type{Base: spansql.Int64, Array: true}}, nil + } + ci, err := ec.colInfo(e[0]) + if err != nil { + return colInfo{}, err + } + if ci.Type.Array { + return colInfo{}, fmt.Errorf("can't nest array literals") + } + ci.Type.Array = true + return ci, nil case spansql.NullLiteral: // There isn't necessarily something sensible here. // Empirically, though, the real Spanner returns Int64. @@ -781,7 +807,7 @@ func (ec evalContext) colInfo(e spansql.Expr) (colInfo, error) { case aggSentinel: return colInfo{Type: e.Type, AggIndex: e.AggIndex}, nil } - return colInfo{}, fmt.Errorf("can't deduce column type from expression [%s]", e.SQL()) + return colInfo{}, fmt.Errorf("can't deduce column type from expression [%s] (type %T)", e.SQL(), e) } func (ec evalContext) arithColType(ao spansql.ArithOp) (spansql.Type, error) { diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index 5af70b6c5ab..3ab6a5821fc 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -890,10 +890,9 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { }, }, { - // From https://cloud.google.com/spanner/docs/aggregate_functions#avg - // but using a param for the array since array literals aren't supported yet. - `SELECT AVG(x) AS avg FROM UNNEST(@p) AS x`, - map[string]interface{}{"p": []int64{0, 2, 4, 4, 5}}, + // From https://cloud.google.com/spanner/docs/aggregate_functions#avg. + `SELECT AVG(x) AS avg FROM UNNEST([0, 2, 4, 4, 5]) AS x`, + nil, [][]interface{}{ {float64(3)}, }, diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index 3e59171ab99..ed93afb4d4e 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -780,9 +780,9 @@ func (p *parser) advance() { p.cur.err = nil p.cur.line, p.cur.offset = p.line, p.offset p.cur.typ = unknownToken - // TODO: array, struct, date, timestamp literals + // TODO: struct, date, timestamp literals switch p.s[0] { - case ',', ';', '(', ')', '{', '}', '*', '+', '-': + case ',', ';', '(', ')', '{', '}', '[', ']', '*', '+', '-': // Single character symbol. p.cur.value, p.s = p.s[:1], p.s[1:] p.offset++ @@ -1000,7 +1000,7 @@ func (p *parser) parseCreateTable() (*CreateTable, *parseError) { } ct := &CreateTable{Name: tname, Position: pos} - err = p.parseCommaList(func(p *parser) *parseError { + err = p.parseCommaList("(", ")", func(p *parser) *parseError { if p.sniffTableConstraint() { tc, err := p.parseTableConstraint() if err != nil { @@ -1483,7 +1483,7 @@ func (p *parser) parseColumnOptions() (ColumnOptions, *parseError) { func (p *parser) parseKeyPartList() ([]KeyPart, *parseError) { var list []KeyPart - err := p.parseCommaList(func(p *parser) *parseError { + err := p.parseCommaList("(", ")", func(p *parser) *parseError { kp, err := p.parseKeyPart() if err != nil { return err @@ -1635,7 +1635,7 @@ func (p *parser) parseCheck() (Check, *parseError) { func (p *parser) parseColumnNameList() ([]ID, *parseError) { var list []ID - err := p.parseCommaList(func(p *parser) *parseError { + err := p.parseCommaList("(", ")", func(p *parser) *parseError { n, err := p.parseTableOrIndexOrColumnName() if err != nil { return err @@ -2200,7 +2200,7 @@ func (p *parser) parseExprList() ([]Expr, *parseError) { func (p *parser) parseParenExprList() ([]Expr, *parseError) { var list []Expr - err := p.parseCommaList(func(p *parser) *parseError { + err := p.parseCommaList("(", ")", func(p *parser) *parseError { e, err := p.parseExpr() if err != nil { return err @@ -2601,7 +2601,13 @@ func (p *parser) parseLit() (Expr, *parseError) { // case insensitivity for keywords. } - // TODO: more types of literals (array, struct, date, timestamp). + // Handle array literals. + if tok.value == "ARRAY" || tok.value == "[" { + p.back() + return p.parseArrayLit() + } + + // TODO: more types of literals (struct, date, timestamp). // Try a parameter. // TODO: check character sets. @@ -2621,6 +2627,24 @@ func (p *parser) parseLit() (Expr, *parseError) { return pe, nil } +func (p *parser) parseArrayLit() (Array, *parseError) { + // ARRAY keyword is optional. + // TODO: If it is present, consume any after it. + p.eat("ARRAY") + + var arr Array + err := p.parseCommaList("[", "]", func(p *parser) *parseError { + e, err := p.parseLit() + if err != nil { + return err + } + // TODO: Do type consistency checking here? + arr = append(arr, e) + return nil + }) + return arr, err +} + func (p *parser) parsePathExp() (PathExp, *parseError) { var pe PathExp for { @@ -2705,14 +2729,14 @@ func (p *parser) parseOnDelete() (OnDelete, *parseError) { return NoActionOnDelete, nil } -// parseCommaList parses a parenthesized comma-separated list, +// parseCommaList parses a comma-separated list enclosed by bra and ket, // delegating to f for the individual element parsing. -func (p *parser) parseCommaList(f func(*parser) *parseError) *parseError { - if err := p.expect("("); err != nil { +func (p *parser) parseCommaList(bra, ket string, f func(*parser) *parseError) *parseError { + if err := p.expect(bra); err != nil { return err } for { - if p.eat(")") { + if p.eat(ket) { return nil } @@ -2721,17 +2745,17 @@ func (p *parser) parseCommaList(f func(*parser) *parseError) *parseError { return err } - // ")" or "," should be next. + // ket or "," should be next. tok := p.next() if tok.err != nil { return err } - if tok.value == ")" { + if tok.value == ket { return nil } else if tok.value == "," { continue } else { - return p.errorf(`got %q, want ")" or ","`, tok.value) + return p.errorf(`got %q, want %q or ","`, tok.value, ket) } } } diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index f0be4bcadc7..ae5577d5844 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -190,12 +190,16 @@ func TestParseQuery(t *testing.T) { }, }, }, - {`SELECT * FROM UNNEST (@p) AS data`, // array literals aren't yet supported + {`SELECT * FROM UNNEST ([1, 2, 3]) AS data`, Query{ Select: Select{ List: []Expr{Star}, From: []SelectFrom{SelectFromUnnest{ - Expr: Param("p"), + Expr: Array{ + IntegerLiteral(1), + IntegerLiteral(2), + IntegerLiteral(3), + }, Alias: ID("data"), }}, }, @@ -298,6 +302,12 @@ func TestParseExpr(t *testing.T) { {`RB"""\\//\\//"""`, BytesLiteral("\\\\//\\\\//")}, {"RB'''\\\\//\n\\\\//'''", BytesLiteral("\\\\//\n\\\\//")}, + // Array literals: + // https://cloud.google.com/spanner/docs/lexical#array_literals + {`[1, 2, 3]`, Array{IntegerLiteral(1), IntegerLiteral(2), IntegerLiteral(3)}}, + {`['x', 'y', 'xy']`, Array{StringLiteral("x"), StringLiteral("y"), StringLiteral("xy")}}, + {`ARRAY[1, 2, 3]`, Array{IntegerLiteral(1), IntegerLiteral(2), IntegerLiteral(3)}}, + // OR is lower precedence than AND. {`A AND B OR C`, LogicalOp{LHS: LogicalOp{LHS: ID("A"), Op: And, RHS: ID("B")}, Op: Or, RHS: ID("C")}}, {`A OR B AND C`, LogicalOp{LHS: ID("A"), Op: Or, RHS: LogicalOp{LHS: ID("B"), Op: And, RHS: ID("C")}}}, diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index b41cbed3bc2..99160c83d91 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -527,6 +527,13 @@ func (p Paren) addSQL(sb *strings.Builder) { sb.WriteString(")") } +func (a Array) SQL() string { return buildSQL(a) } +func (a Array) addSQL(sb *strings.Builder) { + sb.WriteString("[") + addExprList(sb, []Expr(a), ", ") + sb.WriteString("]") +} + func (id ID) SQL() string { return buildSQL(id) } func (id ID) addSQL(sb *strings.Builder) { // https://cloud.google.com/spanner/docs/lexical#identifiers diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index f169151419f..940df82e4b6 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -568,6 +568,11 @@ type Paren struct { func (Paren) isBoolExpr() {} // possibly bool func (Paren) isExpr() {} +// Array represents an array literal. +type Array []Expr + +func (Array) isExpr() {} + // ID represents an identifier. // https://cloud.google.com/spanner/docs/lexical#identifiers type ID string