Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
feat(spanner/spansql): case insensitive parsing of keywords and funct…
…ions (#4034)

This handles the most important parts of the parser as it relates to
case insensitivity, namely permitting the lexical tokens for the grammar
(as opposed to elements like table/column names) to have arbitrary case.

This includes a complete audit of the parser to check all the lexical
token handling, and cleaned up a few places that were particularly
ancient.

The rendered output remains canonically uppercase.

Fixes #4032.
  • Loading branch information
dsymonds committed May 4, 2021
1 parent d4af6f7 commit ddb09d2
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 74 deletions.
2 changes: 1 addition & 1 deletion spanner/spannertest/README.md
Expand Up @@ -21,12 +21,12 @@ by ascending esotericism:
- NUMERIC
- more aggregation functions
- SELECT HAVING
- case insensitivity
- more literal types
- generated columns
- expression type casting, coercion
- multiple joins
- subselects
- case insensitivity of table and column names and query aliases
- transaction simulation
- FOREIGN KEY and CHECK constraints
- INSERT DML statements
Expand Down
138 changes: 67 additions & 71 deletions spanner/spansql/parser.go
Expand Up @@ -885,15 +885,21 @@ func (p *parser) next() *token {
return &p.cur
}

// caseEqual reports whether the token is valid, not a quoted identifier, and
// equal to the provided string under a case insensitive comparison.
// Use this (or sniff/eat/expect) instead of comparing a string directly for keywords, etc.
func (t *token) caseEqual(x string) bool {
return t.err == nil && t.typ != quotedID && strings.EqualFold(t.value, x)
}

// sniff reports whether the next N tokens are as specified.
func (p *parser) sniff(want ...string) bool {
// Store current parser state and restore on the way out.
orig := *p
defer func() { *p = orig }()

for _, w := range want {
tok := p.next()
if tok.err != nil || tok.value != w {
if !p.next().caseEqual(w) {
return false
}
}
Expand All @@ -907,8 +913,7 @@ func (p *parser) eat(want ...string) bool {
orig := *p

for _, w := range want {
tok := p.next()
if tok.err != nil || tok.value != w {
if !p.next().caseEqual(w) {
// Mismatch.
*p = orig
return false
Expand All @@ -922,7 +927,7 @@ func (p *parser) expect(want string) *parseError {
if tok.err != nil {
return tok.err
}
if tok.value != want {
if !tok.caseEqual(want) {
return p.errorf("got %q while expecting %q", tok.value, want)
}
return nil
Expand Down Expand Up @@ -958,18 +963,22 @@ func (p *parser) parseDDLStmt() (DDLStmt, *parseError) {
if tok.err != nil {
return nil, tok.err
}
kind := tok.value
if kind != "TABLE" && kind != "INDEX" {
return nil, p.errorf("got %q, want TABLE or INDEX", kind)
}
name, err := p.parseTableOrIndexOrColumnName()
if err != nil {
return nil, err
}
if kind == "TABLE" {
switch {
default:
return nil, p.errorf("got %q, want TABLE or INDEX", tok.value)
case tok.caseEqual("TABLE"):
name, err := p.parseTableOrIndexOrColumnName()
if err != nil {
return nil, err
}
return &DropTable{Name: name, Position: pos}, nil
case tok.caseEqual("INDEX"):
name, err := p.parseTableOrIndexOrColumnName()
if err != nil {
return nil, err
}
return &DropIndex{Name: name, Position: pos}, nil
}
return &DropIndex{Name: name, Position: pos}, nil
}

return nil, p.errorf("unknown DDL statement")
Expand Down Expand Up @@ -1200,10 +1209,10 @@ func (p *parser) parseAlterTable() (*AlterTable, *parseError) {
if tok.err != nil {
return nil, tok.err
}
switch tok.value {
switch {
default:
return nil, p.errorf("got %q, expected ADD or DROP or SET or ALTER", tok.value)
case "ADD":
case tok.caseEqual("ADD"):
if p.sniff("CONSTRAINT") || p.sniff("FOREIGN") || p.sniff("CHECK") {
tc, err := p.parseTableConstraint()
if err != nil {
Expand All @@ -1223,7 +1232,7 @@ func (p *parser) parseAlterTable() (*AlterTable, *parseError) {
}
a.Alteration = AddColumn{Def: cd}
return a, nil
case "DROP":
case tok.caseEqual("DROP"):
if p.eat("CONSTRAINT") {
name, err := p.parseTableOrIndexOrColumnName()
if err != nil {
Expand All @@ -1243,7 +1252,7 @@ func (p *parser) parseAlterTable() (*AlterTable, *parseError) {
}
a.Alteration = DropColumn{Name: name}
return a, nil
case "SET":
case tok.caseEqual("SET"):
if err := p.expect("ON"); err != nil {
return nil, err
}
Expand All @@ -1256,7 +1265,7 @@ func (p *parser) parseAlterTable() (*AlterTable, *parseError) {
}
a.Alteration = SetOnDelete{Action: od}
return a, nil
case "ALTER":
case tok.caseEqual("ALTER"):
// TODO: "COLUMN" is optional.
if err := p.expect("COLUMN"); err != nil {
return nil, err
Expand Down Expand Up @@ -1459,6 +1468,8 @@ func (p *parser) parseColumnOptions() (ColumnOptions, *parseError) {
return ColumnOptions{}, err
}

// TODO: Figure out if column options are case insensitive.
// We ignore case for the key (because it is easier) but not the value.
var co ColumnOptions
if p.eat("allow_commit_timestamp", "=") {
tok := p.next()
Expand Down Expand Up @@ -1512,18 +1523,10 @@ func (p *parser) parseKeyPart() (KeyPart, *parseError) {

kp := KeyPart{Column: name}

tok := p.next()
if tok.err != nil {
// End of the key_part.
p.back()
return kp, nil
}
switch tok.value {
case "ASC":
case "DESC":
if p.eat("ASC") {
// OK.
} else if p.eat("DESC") {
kp.Desc = true
default:
p.back()
}

return kp, nil
Expand Down Expand Up @@ -1679,7 +1682,7 @@ func (p *parser) parseType() (Type, *parseError) {
if tok.err != nil {
return Type{}, tok.err
}
if tok.value == "ARRAY" {
if tok.caseEqual("ARRAY") {
t.Array = true
if err := p.expect("<"); err != nil {
return Type{}, err
Expand All @@ -1689,7 +1692,7 @@ func (p *parser) parseType() (Type, *parseError) {
return Type{}, tok.err
}
}
base, ok := baseTypes[tok.value]
base, ok := baseTypes[strings.ToUpper(tok.value)] // baseTypes is keyed by upper case strings.
if !ok {
return Type{}, p.errorf("got %q, want scalar type", tok.value)
}
Expand All @@ -1704,7 +1707,7 @@ func (p *parser) parseType() (Type, *parseError) {
if tok.err != nil {
return Type{}, tok.err
}
if tok.value == "MAX" {
if tok.caseEqual("MAX") {
t.Len = MaxLen
} else if tok.typ == int64Token {
n, err := strconv.ParseInt(tok.value, tok.int64Base, 64)
Expand Down Expand Up @@ -1746,7 +1749,6 @@ func (p *parser) parseQuery() (Query, *parseError) {

// TODO: hints, sub-selects, etc.

// TODO: use a case-insensitive select.
if err := p.expect("SELECT"); err != nil {
return Query{}, err
}
Expand Down Expand Up @@ -1988,15 +1990,15 @@ func (p *parser) parseSelectFrom() (SelectFrom, *parseError) {
return sf, nil
}
var hashJoin bool // Special case for "HASH JOIN" syntax.
if tok.value == "HASH" {
if tok.caseEqual("HASH") {
hashJoin = true
tok = p.next()
if tok.err != nil {
return nil, err
}
}
var jt JoinType
if tok.value == "JOIN" {
if tok.caseEqual("JOIN") {
// This is implicitly an inner join.
jt = InnerJoin
} else if j, ok := joinKeywords[tok.value]; ok {
Expand Down Expand Up @@ -2103,9 +2105,9 @@ func (p *parser) parseTableSample() (TableSample, *parseError) {
switch {
case tok.err != nil:
return ts, tok.err
case tok.value == "BERNOULLI":
case tok.caseEqual("BERNOULLI"):
ts.Method = Bernoulli
case tok.value == "RESERVOIR":
case tok.caseEqual("RESERVOIR"):
ts.Method = Reservoir
default:
return ts, p.errorf("got %q, want BERNOULLI or RESERVOIR", tok.value)
Expand All @@ -2127,9 +2129,9 @@ func (p *parser) parseTableSample() (TableSample, *parseError) {
switch {
case tok.err != nil:
return ts, tok.err
case tok.value == "PERCENT":
case tok.caseEqual("PERCENT"):
ts.SizeType = PercentTableSample
case tok.value == "ROWS":
case tok.caseEqual("ROWS"):
ts.SizeType = RowsTableSample
default:
return ts, p.errorf("got %q, want PERCENT or ROWS", tok.value)
Expand All @@ -2153,13 +2155,10 @@ func (p *parser) parseOrder() (Order, *parseError) {
}
o := Order{Expr: expr}

tok := p.next()
switch {
case tok.err == nil && tok.value == "ASC":
case tok.err == nil && tok.value == "DESC":
if p.eat("ASC") {
// OK.
} else if p.eat("DESC") {
o.Desc = true
default:
p.back()
}

return o, nil
Expand Down Expand Up @@ -2352,9 +2351,7 @@ func (p *parser) parseIsOp() (Expr, *parseError) {
return nil, err
}

tok := p.next()
if tok.err != nil || tok.value != "IS" {
p.back()
if !p.eat("IS") {
return expr, nil
}

Expand All @@ -2363,16 +2360,16 @@ func (p *parser) parseIsOp() (Expr, *parseError) {
isOp.Neg = true
}

tok = p.next()
tok := p.next()
if tok.err != nil {
return nil, tok.err
}
switch tok.value {
case "NULL":
switch {
case tok.caseEqual("NULL"):
isOp.RHS = Null
case "TRUE":
case tok.caseEqual("TRUE"):
isOp.RHS = True
case "FALSE":
case tok.caseEqual("FALSE"):
isOp.RHS = False
default:
return nil, p.errorf("got %q, want NULL or TRUE or FALSE", tok.value)
Expand Down Expand Up @@ -2566,8 +2563,8 @@ func (p *parser) parseLit() (Expr, *parseError) {

// If the literal was an identifier, and there's an open paren next,
// this is a function invocation.
// TODO: Case-insensitivity.
if name := tok.value; funcs[name] && p.sniff("(") {
// The `funcs` map is keyed by upper case strings.
if name := strings.ToUpper(tok.value); funcs[name] && p.sniff("(") {
list, err := p.parseParenExprList()
if err != nil {
return nil, err
Expand All @@ -2579,30 +2576,28 @@ func (p *parser) parseLit() (Expr, *parseError) {
}

// Handle some reserved keywords and special tokens that become specific values.
switch tok.value {
case "TRUE":
switch {
case tok.caseEqual("TRUE"):
return True, nil
case "FALSE":
case tok.caseEqual("FALSE"):
return False, nil
case "NULL":
case tok.caseEqual("NULL"):
return Null, nil
case "*":
case tok.value == "*":
return Star, nil
default:
// TODO: Check IsKeyWord(tok.value), and return a distinguished type,
// then only accept that when parsing. That will also permit
// case insensitivity for keywords.
// TODO: Check IsKeyWord(tok.value), and return a good error?
}

// Handle typed literals.
switch tok.value {
case "ARRAY", "[":
switch {
case tok.caseEqual("ARRAY") || tok.value == "[":
p.back()
return p.parseArrayLit()
case "DATE":
case tok.caseEqual("DATE"):
p.back()
return p.parseDateLit()
case "TIMESTAMP":
case tok.caseEqual("TIMESTAMP"):
p.back()
return p.parseTimestampLit()
}
Expand Down Expand Up @@ -2787,10 +2782,10 @@ func (p *parser) parseOnDelete() (OnDelete, *parseError) {
if tok.err != nil {
return 0, tok.err
}
if tok.value == "CASCADE" {
if tok.caseEqual("CASCADE") {
return CascadeOnDelete, nil
}
if tok.value != "NO" {
if !tok.caseEqual("NO") {
return 0, p.errorf("got %q, want NO or CASCADE", tok.value)
}
if err := p.expect("ACTION"); err != nil {
Expand All @@ -2801,6 +2796,7 @@ func (p *parser) parseOnDelete() (OnDelete, *parseError) {

// parseCommaList parses a comma-separated list enclosed by bra and ket,
// delegating to f for the individual element parsing.
// Only invoke this with symbols as bra/ket; they are matched literally, not case insensitively.
func (p *parser) parseCommaList(bra, ket string, f func(*parser) *parseError) *parseError {
if err := p.expect(bra); err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions spanner/spansql/parser_test.go
Expand Up @@ -32,7 +32,7 @@ func TestParseQuery(t *testing.T) {
want Query
}{
{`SELECT 17`, Query{Select: Select{List: []Expr{IntegerLiteral(17)}}}},
{`SELECT Alias AS aka FROM Characters WHERE Age < @ageLimit AND Alias IS NOT NULL ORDER BY Age DESC LIMIT @limit OFFSET 3` + "\n\t",
{`SELECT Alias AS aka From Characters WHERE Age < @ageLimit AND Alias IS NOT NULL ORDER BY Age DESC LIMIT @limit OFFSET 3` + "\n\t",
Query{
Select: Select{
List: []Expr{ID("Alias")},
Expand Down Expand Up @@ -410,7 +410,7 @@ func TestParseDDL(t *testing.T) {
-- Table with generated column.
CREATE TABLE GenCol (
Name STRING(MAX) NOT NULL,
NameLen INT64 AS (CHAR_LENGTH(Name)) STORED,
NameLen INT64 AS (char_length(Name)) STORED,
) PRIMARY KEY (Name);
-- Trailing comment at end of file.
Expand Down

0 comments on commit ddb09d2

Please sign in to comment.