From 54cbf4c0a0305e680b213f84487110dfeaf8e7e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 3 Nov 2021 08:49:16 +0100 Subject: [PATCH] feat(spanner/spansql): support CAST and SAFE_CAST (#5057) * feat(spanner/spansql): support CAST and SAFE_CAST Adds support for the CAST and SAFE_CAST functions. This change also includes a small refactor of the function evaluation to allow the function evaluation to receive both parsing errors and type information for the arguments. This makes it easier to implement the SAFE versions of the functions. * fix: only ignore conversion errors in SAFE_CAST * fix: outdent else statement Co-authored-by: rahul2393 Co-authored-by: Hengfeng Li --- spanner/spannertest/db_eval.go | 21 ++- spanner/spannertest/funcs.go | 171 +++++++++++++++++++++++- spanner/spannertest/integration_test.go | 11 +- spanner/spansql/keywords.go | 8 ++ spanner/spansql/parser.go | 48 ++++++- spanner/spansql/parser_test.go | 45 ++++++- spanner/spansql/sql.go | 9 +- spanner/spansql/sql_test.go | 24 ++++ spanner/spansql/types.go | 9 ++ 9 files changed, 331 insertions(+), 15 deletions(-) diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go index b97e41a52d8..41d5094fe82 100644 --- a/spanner/spannertest/db_eval.go +++ b/spanner/spannertest/db_eval.go @@ -385,16 +385,19 @@ func (ec evalContext) evalArithOp(e spansql.ArithOp) (interface{}, error) { } func (ec evalContext) evalFunc(e spansql.Func) (interface{}, spansql.Type, error) { + var err error if f, ok := functions[e.Name]; ok { args := make([]interface{}, len(e.Args)) + types := make([]spansql.Type, len(e.Args)) for i, arg := range e.Args { - val, err := ec.evalExpr(arg) - if err != nil { + if args[i], err = ec.evalExpr(arg); err != nil { return nil, spansql.Type{}, err } - args[i] = val + if te, ok := arg.(spansql.TypedExpr); ok { + types[i] = te.Type + } } - return f.Eval(args) + return f.Eval(args, types) } return nil, spansql.Type{}, status.Errorf(codes.Unimplemented, "function %q is not implemented", e.Name) } @@ -464,6 +467,8 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) { return bool(e), nil case spansql.Paren: return ec.evalExpr(e.Expr) + case spansql.TypedExpr: + return ec.evalTypedExpr(e) case spansql.Func: v, _, err := ec.evalFunc(e) if err != nil { @@ -662,6 +667,14 @@ func (ec evalContext) coerceString(target spansql.Expr, slit spansql.StringLiter return nil, fmt.Errorf("unable to coerce string literal %q to match %v", slit, ci.Type) } +func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (result interface{}, err error) { + val, err := ec.evalExpr(expr.Expr) + if err != nil { + return nil, err + } + return convert(val, expr.Type) +} + func evalLiteralOrParam(lop spansql.LiteralOrParam, params queryParams) (int64, error) { switch v := lop.(type) { case spansql.IntegerLiteral: diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go index 6a11718cafd..a27eab5e7be 100644 --- a/spanner/spannertest/funcs.go +++ b/spanner/spannertest/funcs.go @@ -19,8 +19,11 @@ package spannertest import ( "fmt" "math" + "strconv" "strings" + "time" + "cloud.google.com/go/civil" "cloud.google.com/go/spanner/spansql" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -30,12 +33,21 @@ import ( type function struct { // Eval evaluates the result of the function using the given input. - Eval func(values []interface{}) (interface{}, spansql.Type, error) + Eval func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) +} + +func firstErr(errors []error) error { + for _, err := range errors { + if err != nil { + return err + } + } + return nil } var functions = map[string]function{ "STARTS_WITH": { - Eval: func(values []interface{}) (interface{}, spansql.Type, error) { + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { // TODO: Refine error messages to exactly match Spanner. // Check input values first. if len(values) != 2 { @@ -53,7 +65,7 @@ var functions = map[string]function{ }, }, "LOWER": { - Eval: func(values []interface{}) (interface{}, spansql.Type, error) { + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { if len(values) != 1 { return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function LOWER for the given argument types") } @@ -66,6 +78,159 @@ var functions = map[string]function{ return strings.ToLower(values[0].(string)), spansql.Type{Base: spansql.String}, nil }, }, + "CAST": { + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { + return cast(values, types, false) + }, + }, + "SAFE_CAST": { + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { + return cast(values, types, true) + }, + }, +} + +func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, spansql.Type, error) { + name := "CAST" + if safe { + name = "SAFE_CAST" + } + if len(types) != 1 { + return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No type information for function %s for the given arguments", name) + } + if len(values) != 1 { + return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No matching signature for function %s for the given arguments", name) + } + // If the input type is an error, then the conversion itself failed. + if err, ok := values[0].(error); ok { + if safe { + return nil, types[0], nil + } + return nil, types[0], err + } + return values[0], types[0], nil +} + +func convert(val interface{}, tp spansql.Type) (interface{}, error) { + // TODO: Implement more conversions. + if tp.Array { + return nil, status.Errorf(codes.Unimplemented, "conversion to ARRAY types is not implemented") + } + var res interface{} + var convertErr, err error + switch tp.Base { + case spansql.Int64: + res, convertErr, err = convertToInt64(val) + case spansql.Float64: + res, convertErr, err = convertToFloat64(val) + case spansql.String: + res, convertErr, err = convertToString(val) + case spansql.Bool: + res, convertErr, err = convertToBool(val) + case spansql.Date: + res, convertErr, err = convertToDate(val) + case spansql.Timestamp: + res, convertErr, err = convertToTimestamp(val) + case spansql.Numeric: + case spansql.JSON: + } + if err != nil { + return nil, err + } + if convertErr != nil { + res = convertErr + } + if res != nil { + return res, nil + } + + return nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to %v", val, tp.Base.SQL()) +} + +func convertToInt64(val interface{}) (res int64, convertErr error, err error) { + switch v := val.(type) { + case int64: + return v, nil, nil + case string: + res, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, status.Errorf(codes.InvalidArgument, "invalid value for INT64: %q", v), nil + } + return res, nil, nil + } + return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to INT64", val) +} + +func convertToFloat64(val interface{}) (res float64, convertErr error, err error) { + switch v := val.(type) { + case int64: + return float64(v), nil, nil + case float64: + return v, nil, nil + case string: + res, err := strconv.ParseFloat(v, 64) + if err != nil { + return 0, status.Errorf(codes.InvalidArgument, "invalid value for FLOAT64: %q", v), nil + } + return res, nil, nil + } + return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to FLOAT64", val) +} + +func convertToString(val interface{}) (res string, convertErr error, err error) { + switch v := val.(type) { + case string: + return v, nil, nil + case bool, int64, float64: + return fmt.Sprintf("%v", v), nil, nil + case civil.Date: + return v.String(), nil, nil + case time.Time: + return v.UTC().Format(time.RFC3339Nano), nil, nil + } + return "", nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to STRING", val) +} + +func convertToBool(val interface{}) (res bool, convertErr error, err error) { + switch v := val.(type) { + case bool: + return v, nil, nil + case string: + res, err := strconv.ParseBool(v) + if err != nil { + return false, status.Errorf(codes.InvalidArgument, "invalid value for BOOL: %q", v), nil + } + return res, nil, nil + } + return false, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to BOOL", val) +} + +func convertToDate(val interface{}) (res civil.Date, convertErr error, err error) { + switch v := val.(type) { + case civil.Date: + return v, nil, nil + case string: + res, err := civil.ParseDate(v) + if err != nil { + return civil.Date{}, status.Errorf(codes.InvalidArgument, "invalid value for DATE: %q", v), nil + } + return res, nil, nil + } + return civil.Date{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to DATE", val) +} + +func convertToTimestamp(val interface{}) (res time.Time, convertErr error, err error) { + switch v := val.(type) { + case time.Time: + return v, nil, nil + case string: + res, err := time.Parse(time.RFC3339Nano, v) + if err != nil { + return time.Time{}, status.Errorf(codes.InvalidArgument, "invalid value for TIMESTAMP: %q", v), nil + } + return res, nil, nil + } + return time.Time{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to TIMESTAMP", val) } type aggregateFunc struct { diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index 21f8d27ecb1..c5d5c81a328 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -741,6 +741,13 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { t.Errorf("Updating with DML affected %d rows, want 3", n) } + rows := client.Single().Query(ctx, spanner.NewStatement("SELECT CAST('Foo' AS INT64)")) + _, err = rows.Next() + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Errorf("error code mismatch for invalid CAST\n Got: %v\nWant: %v", g, w) + } + rows.Stop() + // Do some complex queries. tests := []struct { q string @@ -748,9 +755,9 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { want [][]interface{} }{ { - `SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B')`, + `SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64)`, nil, - [][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true}}, + [][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil}}, }, // Check handling of NULL values for the IS operator. // There was a bug that returned errors for some of these cases. diff --git a/spanner/spansql/keywords.go b/spanner/spansql/keywords.go index 72f79cda933..61ac0657c49 100644 --- a/spanner/spansql/keywords.go +++ b/spanner/spansql/keywords.go @@ -128,11 +128,15 @@ var keywords = map[string]bool{ // funcs is the set of reserved keywords that are functions. // https://cloud.google.com/spanner/docs/functions-and-operators var funcs = make(map[string]bool) +var funcArgParsers = make(map[string]func(*parser) (Expr, *parseError)) func init() { for _, f := range allFuncs { funcs[f] = true } + // Special case for CAST and SAFE_CAST + funcArgParsers["CAST"] = typedArgParser + funcArgParsers["SAFE_CAST"] = typedArgParser } var allFuncs = []string{ @@ -148,6 +152,10 @@ var allFuncs = []string{ "MIN", "SUM", + // Cast functions. + "CAST", + "SAFE_CAST", + // Mathematical functions. "ABS", diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index f69ee620cc3..f4467a306ae 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -1893,8 +1893,16 @@ var baseTypes = map[string]TypeBase{ "JSON": JSON, } +func (p *parser) parseBaseType() (Type, *parseError) { + return p.parseBaseOrParameterizedType(false) +} + func (p *parser) parseType() (Type, *parseError) { - debugf("parseType: %v", p) + return p.parseBaseOrParameterizedType(true) +} + +func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError) { + debugf("parseBaseOrParameterizedType: %v", p) /* array_type: @@ -1928,7 +1936,7 @@ func (p *parser) parseType() (Type, *parseError) { } t.Base = base - if t.Base == String || t.Base == Bytes { + if withParam && (t.Base == String || t.Base == Bytes) { if err := p.expect("("); err != nil { return Type{}, err } @@ -2436,9 +2444,15 @@ func (p *parser) parseExprList() ([]Expr, *parseError) { } func (p *parser) parseParenExprList() ([]Expr, *parseError) { + return p.parseParenExprListWithParseFunc(func(p *parser) (Expr, *parseError) { + return p.parseExpr() + }) +} + +func (p *parser) parseParenExprListWithParseFunc(f func(*parser) (Expr, *parseError)) ([]Expr, *parseError) { var list []Expr err := p.parseCommaList("(", ")", func(p *parser) *parseError { - e, err := p.parseExpr() + e, err := f(p) if err != nil { return err } @@ -2448,6 +2462,26 @@ func (p *parser) parseParenExprList() ([]Expr, *parseError) { return list, err } +// Special argument parser for CAST and SAFE_CAST +var typedArgParser = func(p *parser) (Expr, *parseError) { + e, err := p.parseExpr() + if err != nil { + return nil, err + } + if err := p.expect("AS"); err != nil { + return nil, err + } + // typename in cast function must not be parameterized types + toType, err := p.parseBaseType() + if err != nil { + return nil, err + } + return TypedExpr{ + Expr: e, + Type: toType, + }, nil +} + /* Expressions @@ -2800,7 +2834,13 @@ func (p *parser) parseLit() (Expr, *parseError) { // this is a function invocation. // The `funcs` map is keyed by upper case strings. if name := strings.ToUpper(tok.value); funcs[name] && p.sniff("(") { - list, err := p.parseParenExprList() + var list []Expr + var err *parseError + if f, ok := funcArgParsers[name]; ok { + list, err = p.parseParenExprListWithParseFunc(f) + } else { + list, err = p.parseParenExprList() + } if err != nil { return nil, err } diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index 6aca49eff8e..966e4c5d1f0 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -121,6 +121,19 @@ func TestParseQuery(t *testing.T) { }, }, }, + {`SELECT * FROM Foo WHERE CAST(Bar AS STRING)='Bar'`, + Query{ + Select: Select{ + List: []Expr{Star}, + From: []SelectFrom{SelectFromTable{Table: "Foo"}}, + Where: ComparisonOp{ + Op: Eq, + LHS: Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}, + RHS: StringLiteral("Bar"), + }, + }, + }, + }, {`SELECT SUM(PointsScored) AS total_points, FirstName, LastName AS surname FROM PlayerStats GROUP BY FirstName, LastName`, Query{ Select: Select{ @@ -325,6 +338,8 @@ func TestParseExpr(t *testing.T) { // Functions {`STARTS_WITH(Bar, 'B')`, Func{Name: "STARTS_WITH", Args: []Expr{ID("Bar"), StringLiteral("B")}}}, + {`CAST(Bar AS STRING)`, Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}}, + {`SAFE_CAST(Bar AS INT64)`, Func{Name: "SAFE_CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: Int64}}}}}, // String literal: // Accept double quote and single quote. @@ -506,6 +521,12 @@ func TestParseDDL(t *testing.T) { FROM Singers ORDER BY LastName, FirstName; + CREATE TABLE users ( + user_id STRING(36) NOT NULL, + some_string STRING(16) NOT NULL, + number_key INT64 AS (SAFE_CAST(SUBSTR(some_string, 2) AS INT64)) STORED, + ) PRIMARY KEY(user_id); + -- Trailing comment at end of file. `, &DDL{Filename: "filename", List: []DDLStmt{ &CreateTable{ @@ -712,6 +733,28 @@ func TestParseDDL(t *testing.T) { }, Position: line(60), }, + + // CREATE TABLE users ( + // user_id STRING(36) NOT NULL, + // some_string STRING(16) NOT NULL, + // number_key INT64 AS (SAFE_CAST(SUBSTR(some_string, 2) AS INT64)) STORED, + //) PRIMARY KEY(user_id); + &CreateTable{ + Name: "users", + Columns: []ColumnDef{ + {Name: "user_id", Type: Type{Base: String, Len: 36}, NotNull: true, Position: line(67)}, + {Name: "some_string", Type: Type{Base: String, Len: 16}, NotNull: true, Position: line(68)}, + { + Name: "number_key", Type: Type{Base: Int64}, + Generated: Func{Name: "SAFE_CAST", Args: []Expr{ + TypedExpr{Expr: Func{Name: "SUBSTR", Args: []Expr{ID("some_string"), IntegerLiteral(2)}}, Type: Type{Base: Int64}}, + }}, + Position: line(69), + }, + }, + PrimaryKey: []KeyPart{{Column: "user_id"}}, + Position: line(66), + }, }, Comments: []*Comment{ {Marker: "#", Start: line(2), End: line(2), Text: []string{"This is a comment."}}, @@ -734,7 +777,7 @@ func TestParseDDL(t *testing.T) { {Marker: "--", Isolated: true, Start: line(49), End: line(49), Text: []string{"Table with row deletion policy."}}, // Comment after everything else. - {Marker: "--", Isolated: true, Start: line(66), End: line(66), Text: []string{"Trailing comment at end of file."}}, + {Marker: "--", Isolated: true, Start: line(72), End: line(72), Text: []string{"Trailing comment at end of file."}}, }}}, // No trailing comma: {`ALTER TABLE T ADD COLUMN C2 INT64`, &DDL{Filename: "filename", List: []DDLStmt{ diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index 5032b57ad99..29c5a5dac54 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -289,7 +289,7 @@ func (c Check) SQL() string { func (t Type) SQL() string { str := t.Base.SQL() - if t.Base == String || t.Base == Bytes { + if t.Len > 0 && (t.Base == String || t.Base == Bytes) { str += "(" if t.Len == MaxLen { str += "MAX" @@ -582,6 +582,13 @@ func (f Func) addSQL(sb *strings.Builder) { sb.WriteString(")") } +func (te TypedExpr) SQL() string { return buildSQL(te) } +func (te TypedExpr) addSQL(sb *strings.Builder) { + te.Expr.addSQL(sb) + sb.WriteString(" AS ") + sb.WriteString(te.Type.SQL()) +} + func idList(l []ID, join string) string { var ss []string for _, s := range l { diff --git a/spanner/spansql/sql_test.go b/spanner/spansql/sql_test.go index 9f3b13d45b4..b442506c719 100644 --- a/spanner/spansql/sql_test.go +++ b/spanner/spansql/sql_test.go @@ -465,6 +465,30 @@ func TestSQL(t *testing.T) { `SELECT 7`, reparseQuery, }, + { + Query{ + Select: Select{ + List: []Expr{Func{ + Name: "CAST", + Args: []Expr{TypedExpr{Expr: IntegerLiteral(7), Type: Type{Base: String}}}, + }}, + }, + }, + `SELECT CAST(7 AS STRING)`, + reparseQuery, + }, + { + Query{ + Select: Select{ + List: []Expr{Func{ + Name: "SAFE_CAST", + Args: []Expr{TypedExpr{Expr: IntegerLiteral(7), Type: Type{Base: Date}}}, + }}, + }, + }, + `SELECT SAFE_CAST(7 AS DATE)`, + reparseQuery, + }, { ComparisonOp{LHS: ID("X"), Op: NotBetween, RHS: ID("Y"), RHS2: ID("Z")}, `X NOT BETWEEN Y AND Z`, diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index 86b98f4d189..074e07cf0f1 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -638,6 +638,15 @@ type Func struct { func (Func) isBoolExpr() {} // possibly bool func (Func) isExpr() {} +// TypedExpr represents a typed expression in the form `expr AS type_name`, e.g. `'17' AS INT64`. +type TypedExpr struct { + Type Type + Expr Expr +} + +func (TypedExpr) isBoolExpr() {} // possibly bool +func (TypedExpr) isExpr() {} + // Paren represents a parenthesised expression. type Paren struct { Expr Expr