diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go index b97e41a52d8..41d5094fe82 100644 --- a/spanner/spannertest/db_eval.go +++ b/spanner/spannertest/db_eval.go @@ -385,16 +385,19 @@ func (ec evalContext) evalArithOp(e spansql.ArithOp) (interface{}, error) { } func (ec evalContext) evalFunc(e spansql.Func) (interface{}, spansql.Type, error) { + var err error if f, ok := functions[e.Name]; ok { args := make([]interface{}, len(e.Args)) + types := make([]spansql.Type, len(e.Args)) for i, arg := range e.Args { - val, err := ec.evalExpr(arg) - if err != nil { + if args[i], err = ec.evalExpr(arg); err != nil { return nil, spansql.Type{}, err } - args[i] = val + if te, ok := arg.(spansql.TypedExpr); ok { + types[i] = te.Type + } } - return f.Eval(args) + return f.Eval(args, types) } return nil, spansql.Type{}, status.Errorf(codes.Unimplemented, "function %q is not implemented", e.Name) } @@ -464,6 +467,8 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) { return bool(e), nil case spansql.Paren: return ec.evalExpr(e.Expr) + case spansql.TypedExpr: + return ec.evalTypedExpr(e) case spansql.Func: v, _, err := ec.evalFunc(e) if err != nil { @@ -662,6 +667,14 @@ func (ec evalContext) coerceString(target spansql.Expr, slit spansql.StringLiter return nil, fmt.Errorf("unable to coerce string literal %q to match %v", slit, ci.Type) } +func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (result interface{}, err error) { + val, err := ec.evalExpr(expr.Expr) + if err != nil { + return nil, err + } + return convert(val, expr.Type) +} + func evalLiteralOrParam(lop spansql.LiteralOrParam, params queryParams) (int64, error) { switch v := lop.(type) { case spansql.IntegerLiteral: diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go index 6a11718cafd..a27eab5e7be 100644 --- a/spanner/spannertest/funcs.go +++ b/spanner/spannertest/funcs.go @@ -19,8 +19,11 @@ package spannertest import ( "fmt" "math" + "strconv" "strings" + "time" + "cloud.google.com/go/civil" "cloud.google.com/go/spanner/spansql" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -30,12 +33,21 @@ import ( type function struct { // Eval evaluates the result of the function using the given input. - Eval func(values []interface{}) (interface{}, spansql.Type, error) + Eval func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) +} + +func firstErr(errors []error) error { + for _, err := range errors { + if err != nil { + return err + } + } + return nil } var functions = map[string]function{ "STARTS_WITH": { - Eval: func(values []interface{}) (interface{}, spansql.Type, error) { + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { // TODO: Refine error messages to exactly match Spanner. // Check input values first. if len(values) != 2 { @@ -53,7 +65,7 @@ var functions = map[string]function{ }, }, "LOWER": { - Eval: func(values []interface{}) (interface{}, spansql.Type, error) { + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { if len(values) != 1 { return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function LOWER for the given argument types") } @@ -66,6 +78,159 @@ var functions = map[string]function{ return strings.ToLower(values[0].(string)), spansql.Type{Base: spansql.String}, nil }, }, + "CAST": { + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { + return cast(values, types, false) + }, + }, + "SAFE_CAST": { + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { + return cast(values, types, true) + }, + }, +} + +func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, spansql.Type, error) { + name := "CAST" + if safe { + name = "SAFE_CAST" + } + if len(types) != 1 { + return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No type information for function %s for the given arguments", name) + } + if len(values) != 1 { + return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No matching signature for function %s for the given arguments", name) + } + // If the input type is an error, then the conversion itself failed. + if err, ok := values[0].(error); ok { + if safe { + return nil, types[0], nil + } + return nil, types[0], err + } + return values[0], types[0], nil +} + +func convert(val interface{}, tp spansql.Type) (interface{}, error) { + // TODO: Implement more conversions. + if tp.Array { + return nil, status.Errorf(codes.Unimplemented, "conversion to ARRAY types is not implemented") + } + var res interface{} + var convertErr, err error + switch tp.Base { + case spansql.Int64: + res, convertErr, err = convertToInt64(val) + case spansql.Float64: + res, convertErr, err = convertToFloat64(val) + case spansql.String: + res, convertErr, err = convertToString(val) + case spansql.Bool: + res, convertErr, err = convertToBool(val) + case spansql.Date: + res, convertErr, err = convertToDate(val) + case spansql.Timestamp: + res, convertErr, err = convertToTimestamp(val) + case spansql.Numeric: + case spansql.JSON: + } + if err != nil { + return nil, err + } + if convertErr != nil { + res = convertErr + } + if res != nil { + return res, nil + } + + return nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to %v", val, tp.Base.SQL()) +} + +func convertToInt64(val interface{}) (res int64, convertErr error, err error) { + switch v := val.(type) { + case int64: + return v, nil, nil + case string: + res, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, status.Errorf(codes.InvalidArgument, "invalid value for INT64: %q", v), nil + } + return res, nil, nil + } + return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to INT64", val) +} + +func convertToFloat64(val interface{}) (res float64, convertErr error, err error) { + switch v := val.(type) { + case int64: + return float64(v), nil, nil + case float64: + return v, nil, nil + case string: + res, err := strconv.ParseFloat(v, 64) + if err != nil { + return 0, status.Errorf(codes.InvalidArgument, "invalid value for FLOAT64: %q", v), nil + } + return res, nil, nil + } + return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to FLOAT64", val) +} + +func convertToString(val interface{}) (res string, convertErr error, err error) { + switch v := val.(type) { + case string: + return v, nil, nil + case bool, int64, float64: + return fmt.Sprintf("%v", v), nil, nil + case civil.Date: + return v.String(), nil, nil + case time.Time: + return v.UTC().Format(time.RFC3339Nano), nil, nil + } + return "", nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to STRING", val) +} + +func convertToBool(val interface{}) (res bool, convertErr error, err error) { + switch v := val.(type) { + case bool: + return v, nil, nil + case string: + res, err := strconv.ParseBool(v) + if err != nil { + return false, status.Errorf(codes.InvalidArgument, "invalid value for BOOL: %q", v), nil + } + return res, nil, nil + } + return false, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to BOOL", val) +} + +func convertToDate(val interface{}) (res civil.Date, convertErr error, err error) { + switch v := val.(type) { + case civil.Date: + return v, nil, nil + case string: + res, err := civil.ParseDate(v) + if err != nil { + return civil.Date{}, status.Errorf(codes.InvalidArgument, "invalid value for DATE: %q", v), nil + } + return res, nil, nil + } + return civil.Date{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to DATE", val) +} + +func convertToTimestamp(val interface{}) (res time.Time, convertErr error, err error) { + switch v := val.(type) { + case time.Time: + return v, nil, nil + case string: + res, err := time.Parse(time.RFC3339Nano, v) + if err != nil { + return time.Time{}, status.Errorf(codes.InvalidArgument, "invalid value for TIMESTAMP: %q", v), nil + } + return res, nil, nil + } + return time.Time{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to TIMESTAMP", val) } type aggregateFunc struct { diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index 21f8d27ecb1..c5d5c81a328 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -741,6 +741,13 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { t.Errorf("Updating with DML affected %d rows, want 3", n) } + rows := client.Single().Query(ctx, spanner.NewStatement("SELECT CAST('Foo' AS INT64)")) + _, err = rows.Next() + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Errorf("error code mismatch for invalid CAST\n Got: %v\nWant: %v", g, w) + } + rows.Stop() + // Do some complex queries. tests := []struct { q string @@ -748,9 +755,9 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { want [][]interface{} }{ { - `SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B')`, + `SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64)`, nil, - [][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true}}, + [][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil}}, }, // Check handling of NULL values for the IS operator. // There was a bug that returned errors for some of these cases. diff --git a/spanner/spansql/keywords.go b/spanner/spansql/keywords.go index 72f79cda933..61ac0657c49 100644 --- a/spanner/spansql/keywords.go +++ b/spanner/spansql/keywords.go @@ -128,11 +128,15 @@ var keywords = map[string]bool{ // funcs is the set of reserved keywords that are functions. // https://cloud.google.com/spanner/docs/functions-and-operators var funcs = make(map[string]bool) +var funcArgParsers = make(map[string]func(*parser) (Expr, *parseError)) func init() { for _, f := range allFuncs { funcs[f] = true } + // Special case for CAST and SAFE_CAST + funcArgParsers["CAST"] = typedArgParser + funcArgParsers["SAFE_CAST"] = typedArgParser } var allFuncs = []string{ @@ -148,6 +152,10 @@ var allFuncs = []string{ "MIN", "SUM", + // Cast functions. + "CAST", + "SAFE_CAST", + // Mathematical functions. "ABS", diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index f69ee620cc3..f4467a306ae 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -1893,8 +1893,16 @@ var baseTypes = map[string]TypeBase{ "JSON": JSON, } +func (p *parser) parseBaseType() (Type, *parseError) { + return p.parseBaseOrParameterizedType(false) +} + func (p *parser) parseType() (Type, *parseError) { - debugf("parseType: %v", p) + return p.parseBaseOrParameterizedType(true) +} + +func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError) { + debugf("parseBaseOrParameterizedType: %v", p) /* array_type: @@ -1928,7 +1936,7 @@ func (p *parser) parseType() (Type, *parseError) { } t.Base = base - if t.Base == String || t.Base == Bytes { + if withParam && (t.Base == String || t.Base == Bytes) { if err := p.expect("("); err != nil { return Type{}, err } @@ -2436,9 +2444,15 @@ func (p *parser) parseExprList() ([]Expr, *parseError) { } func (p *parser) parseParenExprList() ([]Expr, *parseError) { + return p.parseParenExprListWithParseFunc(func(p *parser) (Expr, *parseError) { + return p.parseExpr() + }) +} + +func (p *parser) parseParenExprListWithParseFunc(f func(*parser) (Expr, *parseError)) ([]Expr, *parseError) { var list []Expr err := p.parseCommaList("(", ")", func(p *parser) *parseError { - e, err := p.parseExpr() + e, err := f(p) if err != nil { return err } @@ -2448,6 +2462,26 @@ func (p *parser) parseParenExprList() ([]Expr, *parseError) { return list, err } +// Special argument parser for CAST and SAFE_CAST +var typedArgParser = func(p *parser) (Expr, *parseError) { + e, err := p.parseExpr() + if err != nil { + return nil, err + } + if err := p.expect("AS"); err != nil { + return nil, err + } + // typename in cast function must not be parameterized types + toType, err := p.parseBaseType() + if err != nil { + return nil, err + } + return TypedExpr{ + Expr: e, + Type: toType, + }, nil +} + /* Expressions @@ -2800,7 +2834,13 @@ func (p *parser) parseLit() (Expr, *parseError) { // this is a function invocation. // The `funcs` map is keyed by upper case strings. if name := strings.ToUpper(tok.value); funcs[name] && p.sniff("(") { - list, err := p.parseParenExprList() + var list []Expr + var err *parseError + if f, ok := funcArgParsers[name]; ok { + list, err = p.parseParenExprListWithParseFunc(f) + } else { + list, err = p.parseParenExprList() + } if err != nil { return nil, err } diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index 6aca49eff8e..966e4c5d1f0 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -121,6 +121,19 @@ func TestParseQuery(t *testing.T) { }, }, }, + {`SELECT * FROM Foo WHERE CAST(Bar AS STRING)='Bar'`, + Query{ + Select: Select{ + List: []Expr{Star}, + From: []SelectFrom{SelectFromTable{Table: "Foo"}}, + Where: ComparisonOp{ + Op: Eq, + LHS: Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}, + RHS: StringLiteral("Bar"), + }, + }, + }, + }, {`SELECT SUM(PointsScored) AS total_points, FirstName, LastName AS surname FROM PlayerStats GROUP BY FirstName, LastName`, Query{ Select: Select{ @@ -325,6 +338,8 @@ func TestParseExpr(t *testing.T) { // Functions {`STARTS_WITH(Bar, 'B')`, Func{Name: "STARTS_WITH", Args: []Expr{ID("Bar"), StringLiteral("B")}}}, + {`CAST(Bar AS STRING)`, Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}}, + {`SAFE_CAST(Bar AS INT64)`, Func{Name: "SAFE_CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: Int64}}}}}, // String literal: // Accept double quote and single quote. @@ -506,6 +521,12 @@ func TestParseDDL(t *testing.T) { FROM Singers ORDER BY LastName, FirstName; + CREATE TABLE users ( + user_id STRING(36) NOT NULL, + some_string STRING(16) NOT NULL, + number_key INT64 AS (SAFE_CAST(SUBSTR(some_string, 2) AS INT64)) STORED, + ) PRIMARY KEY(user_id); + -- Trailing comment at end of file. `, &DDL{Filename: "filename", List: []DDLStmt{ &CreateTable{ @@ -712,6 +733,28 @@ func TestParseDDL(t *testing.T) { }, Position: line(60), }, + + // CREATE TABLE users ( + // user_id STRING(36) NOT NULL, + // some_string STRING(16) NOT NULL, + // number_key INT64 AS (SAFE_CAST(SUBSTR(some_string, 2) AS INT64)) STORED, + //) PRIMARY KEY(user_id); + &CreateTable{ + Name: "users", + Columns: []ColumnDef{ + {Name: "user_id", Type: Type{Base: String, Len: 36}, NotNull: true, Position: line(67)}, + {Name: "some_string", Type: Type{Base: String, Len: 16}, NotNull: true, Position: line(68)}, + { + Name: "number_key", Type: Type{Base: Int64}, + Generated: Func{Name: "SAFE_CAST", Args: []Expr{ + TypedExpr{Expr: Func{Name: "SUBSTR", Args: []Expr{ID("some_string"), IntegerLiteral(2)}}, Type: Type{Base: Int64}}, + }}, + Position: line(69), + }, + }, + PrimaryKey: []KeyPart{{Column: "user_id"}}, + Position: line(66), + }, }, Comments: []*Comment{ {Marker: "#", Start: line(2), End: line(2), Text: []string{"This is a comment."}}, @@ -734,7 +777,7 @@ func TestParseDDL(t *testing.T) { {Marker: "--", Isolated: true, Start: line(49), End: line(49), Text: []string{"Table with row deletion policy."}}, // Comment after everything else. - {Marker: "--", Isolated: true, Start: line(66), End: line(66), Text: []string{"Trailing comment at end of file."}}, + {Marker: "--", Isolated: true, Start: line(72), End: line(72), Text: []string{"Trailing comment at end of file."}}, }}}, // No trailing comma: {`ALTER TABLE T ADD COLUMN C2 INT64`, &DDL{Filename: "filename", List: []DDLStmt{ diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index 5032b57ad99..29c5a5dac54 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -289,7 +289,7 @@ func (c Check) SQL() string { func (t Type) SQL() string { str := t.Base.SQL() - if t.Base == String || t.Base == Bytes { + if t.Len > 0 && (t.Base == String || t.Base == Bytes) { str += "(" if t.Len == MaxLen { str += "MAX" @@ -582,6 +582,13 @@ func (f Func) addSQL(sb *strings.Builder) { sb.WriteString(")") } +func (te TypedExpr) SQL() string { return buildSQL(te) } +func (te TypedExpr) addSQL(sb *strings.Builder) { + te.Expr.addSQL(sb) + sb.WriteString(" AS ") + sb.WriteString(te.Type.SQL()) +} + func idList(l []ID, join string) string { var ss []string for _, s := range l { diff --git a/spanner/spansql/sql_test.go b/spanner/spansql/sql_test.go index 9f3b13d45b4..b442506c719 100644 --- a/spanner/spansql/sql_test.go +++ b/spanner/spansql/sql_test.go @@ -465,6 +465,30 @@ func TestSQL(t *testing.T) { `SELECT 7`, reparseQuery, }, + { + Query{ + Select: Select{ + List: []Expr{Func{ + Name: "CAST", + Args: []Expr{TypedExpr{Expr: IntegerLiteral(7), Type: Type{Base: String}}}, + }}, + }, + }, + `SELECT CAST(7 AS STRING)`, + reparseQuery, + }, + { + Query{ + Select: Select{ + List: []Expr{Func{ + Name: "SAFE_CAST", + Args: []Expr{TypedExpr{Expr: IntegerLiteral(7), Type: Type{Base: Date}}}, + }}, + }, + }, + `SELECT SAFE_CAST(7 AS DATE)`, + reparseQuery, + }, { ComparisonOp{LHS: ID("X"), Op: NotBetween, RHS: ID("Y"), RHS2: ID("Z")}, `X NOT BETWEEN Y AND Z`, diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index 86b98f4d189..074e07cf0f1 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -638,6 +638,15 @@ type Func struct { func (Func) isBoolExpr() {} // possibly bool func (Func) isExpr() {} +// TypedExpr represents a typed expression in the form `expr AS type_name`, e.g. `'17' AS INT64`. +type TypedExpr struct { + Type Type + Expr Expr +} + +func (TypedExpr) isBoolExpr() {} // possibly bool +func (TypedExpr) isExpr() {} + // Paren represents a parenthesised expression. type Paren struct { Expr Expr