From 6e5d779c0b91577dfdf4c1c8482e2daa63dc90fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 1 Nov 2021 12:48:18 +0100 Subject: [PATCH 1/3] feat(spanner/spansql): support CAST and SAFE_CAST Adds support for the CAST and SAFE_CAST functions. This change also includes a small refactor of the function evaluation to allow the function evaluation to receive both parsing errors and type information for the arguments. This makes it easier to implement the SAFE versions of the functions. --- spanner/spannertest/db_eval.go | 21 ++- spanner/spannertest/funcs.go | 166 +++++++++++++++++++++++- spanner/spannertest/integration_test.go | 11 +- spanner/spansql/keywords.go | 8 ++ spanner/spansql/parser.go | 48 ++++++- spanner/spansql/parser_test.go | 45 ++++++- spanner/spansql/sql.go | 9 +- spanner/spansql/sql_test.go | 24 ++++ spanner/spansql/types.go | 9 ++ 9 files changed, 325 insertions(+), 16 deletions(-) diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go index b97e41a52d8..aa0b6863420 100644 --- a/spanner/spannertest/db_eval.go +++ b/spanner/spannertest/db_eval.go @@ -387,14 +387,15 @@ func (ec evalContext) evalArithOp(e spansql.ArithOp) (interface{}, error) { func (ec evalContext) evalFunc(e spansql.Func) (interface{}, spansql.Type, error) { if f, ok := functions[e.Name]; ok { args := make([]interface{}, len(e.Args)) + types := make([]spansql.Type, len(e.Args)) + errs := make([]error, len(e.Args)) for i, arg := range e.Args { - val, err := ec.evalExpr(arg) - if err != nil { - return nil, spansql.Type{}, err + args[i], errs[i] = ec.evalExpr(arg) + if te, ok := arg.(spansql.TypedExpr); ok { + types[i] = te.Type } - args[i] = val } - return f.Eval(args) + return f.Eval(args, types, errs) } return nil, spansql.Type{}, status.Errorf(codes.Unimplemented, "function %q is not implemented", e.Name) } @@ -464,6 +465,8 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) { return bool(e), nil case spansql.Paren: return ec.evalExpr(e.Expr) + case spansql.TypedExpr: + return ec.evalTypedExpr(e) case spansql.Func: v, _, err := ec.evalFunc(e) if err != nil { @@ -662,6 +665,14 @@ func (ec evalContext) coerceString(target spansql.Expr, slit spansql.StringLiter return nil, fmt.Errorf("unable to coerce string literal %q to match %v", slit, ci.Type) } +func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (interface{}, error) { + val, err := ec.evalExpr(expr.Expr) + if err != nil { + return nil, err + } + return convert(val, expr.Type) +} + func evalLiteralOrParam(lop spansql.LiteralOrParam, params queryParams) (int64, error) { switch v := lop.(type) { case spansql.IntegerLiteral: diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go index 6a11718cafd..8a3f3c5e899 100644 --- a/spanner/spannertest/funcs.go +++ b/spanner/spannertest/funcs.go @@ -19,8 +19,11 @@ package spannertest import ( "fmt" "math" + "strconv" "strings" + "time" + "cloud.google.com/go/civil" "cloud.google.com/go/spanner/spansql" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -30,12 +33,24 @@ import ( type function struct { // Eval evaluates the result of the function using the given input. - Eval func(values []interface{}) (interface{}, spansql.Type, error) + Eval func(values []interface{}, types []spansql.Type, errors []error) (interface{}, spansql.Type, error) +} + +func firstErr(errors []error) error { + for _, err := range errors { + if err != nil { + return err + } + } + return nil } var functions = map[string]function{ "STARTS_WITH": { - Eval: func(values []interface{}) (interface{}, spansql.Type, error) { + Eval: func(values []interface{}, types []spansql.Type, errors []error) (interface{}, spansql.Type, error) { + if err := firstErr(errors); err != nil { + return nil, spansql.Type{}, err + } // TODO: Refine error messages to exactly match Spanner. // Check input values first. if len(values) != 2 { @@ -53,7 +68,10 @@ var functions = map[string]function{ }, }, "LOWER": { - Eval: func(values []interface{}) (interface{}, spansql.Type, error) { + Eval: func(values []interface{}, types []spansql.Type, errors []error) (interface{}, spansql.Type, error) { + if err := firstErr(errors); err != nil { + return nil, spansql.Type{}, err + } if len(values) != 1 { return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function LOWER for the given argument types") } @@ -66,6 +84,148 @@ var functions = map[string]function{ return strings.ToLower(values[0].(string)), spansql.Type{Base: spansql.String}, nil }, }, + "CAST": { + Eval: func(values []interface{}, types []spansql.Type, errors []error) (interface{}, spansql.Type, error) { + return cast(values, types, errors, false) + }, + }, + "SAFE_CAST": { + Eval: func(values []interface{}, types []spansql.Type, errors []error) (interface{}, spansql.Type, error) { + return cast(values, types, errors, true) + }, + }, +} + +func cast(values []interface{}, types []spansql.Type, errors []error, safe bool) (interface{}, spansql.Type, error) { + name := "CAST" + if safe { + name = "SAFE_CAST" + } + if len(types) != 1 { + return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No type information for function %s for the given arguments", name) + } + if len(values) != 1 { + return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No matching signature for function %s for the given arguments", name) + } + if err := firstErr(errors); err != nil { + if safe { + return nil, types[0], nil + } else { + return nil, types[0], err + } + } + return values[0], types[0], nil +} + +func convert(val interface{}, tp spansql.Type) (interface{}, error) { + // TODO: Implement more conversions. + if tp.Array { + return nil, status.Errorf(codes.Unimplemented, "conversion to ARRAY types is not implemented") + } + switch tp.Base { + case spansql.Int64: + return convertToInt64(val) + case spansql.Float64: + return convertToFloat64(val) + case spansql.String: + return convertToString(val) + case spansql.Bool: + return convertToBool(val) + case spansql.Date: + return convertToDate(val) + case spansql.Timestamp: + return convertToTimestamp(val) + case spansql.Numeric: + case spansql.JSON: + } + + return nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to %v", val, tp.Base.SQL()) +} + +func convertToInt64(val interface{}) (int64, error) { + switch v := val.(type) { + case int64: + return v, nil + case string: + res, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, status.Errorf(codes.InvalidArgument, "invalid value for INT64: %q", v) + } + return res, nil + } + return 0, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to INT64", val) +} + +func convertToFloat64(val interface{}) (float64, error) { + switch v := val.(type) { + case int64: + return float64(v), nil + case float64: + return v, nil + case string: + res, err := strconv.ParseFloat(v, 64) + if err != nil { + return 0, status.Errorf(codes.InvalidArgument, "invalid value for FLOAT64: %q", v) + } + return res, nil + } + return 0, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to FLOAT64", val) +} + +func convertToString(val interface{}) (string, error) { + switch v := val.(type) { + case string: + return v, nil + case bool, int64, float64: + return fmt.Sprintf("%v", v), nil + case civil.Date: + return v.String(), nil + case time.Time: + return v.UTC().Format(time.RFC3339Nano), nil + } + return "", status.Errorf(codes.Unimplemented, "unsupported conversion for %v to STRING", val) +} + +func convertToBool(val interface{}) (bool, error) { + switch v := val.(type) { + case bool: + return v, nil + case string: + res, err := strconv.ParseBool(v) + if err != nil { + return false, status.Errorf(codes.InvalidArgument, "invalid value for BOOL: %q", v) + } + return res, nil + } + return false, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to BOOL", val) +} + +func convertToDate(val interface{}) (civil.Date, error) { + switch v := val.(type) { + case civil.Date: + return v, nil + case string: + res, err := civil.ParseDate(v) + if err != nil { + return civil.Date{}, status.Errorf(codes.InvalidArgument, "invalid value for DATE: %q", v) + } + return res, nil + } + return civil.Date{}, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to DATE", val) +} + +func convertToTimestamp(val interface{}) (time.Time, error) { + switch v := val.(type) { + case time.Time: + return v, nil + case string: + res, err := time.Parse(time.RFC3339Nano, v) + if err != nil { + return time.Time{}, status.Errorf(codes.InvalidArgument, "invalid value for TIMESTAMP: %q", v) + } + return res, nil + } + return time.Time{}, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to TIMESTAMP", val) } type aggregateFunc struct { diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index 21f8d27ecb1..c5d5c81a328 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -741,6 +741,13 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { t.Errorf("Updating with DML affected %d rows, want 3", n) } + rows := client.Single().Query(ctx, spanner.NewStatement("SELECT CAST('Foo' AS INT64)")) + _, err = rows.Next() + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Errorf("error code mismatch for invalid CAST\n Got: %v\nWant: %v", g, w) + } + rows.Stop() + // Do some complex queries. tests := []struct { q string @@ -748,9 +755,9 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { want [][]interface{} }{ { - `SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B')`, + `SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64)`, nil, - [][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true}}, + [][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil}}, }, // Check handling of NULL values for the IS operator. // There was a bug that returned errors for some of these cases. diff --git a/spanner/spansql/keywords.go b/spanner/spansql/keywords.go index 72f79cda933..61ac0657c49 100644 --- a/spanner/spansql/keywords.go +++ b/spanner/spansql/keywords.go @@ -128,11 +128,15 @@ var keywords = map[string]bool{ // funcs is the set of reserved keywords that are functions. // https://cloud.google.com/spanner/docs/functions-and-operators var funcs = make(map[string]bool) +var funcArgParsers = make(map[string]func(*parser) (Expr, *parseError)) func init() { for _, f := range allFuncs { funcs[f] = true } + // Special case for CAST and SAFE_CAST + funcArgParsers["CAST"] = typedArgParser + funcArgParsers["SAFE_CAST"] = typedArgParser } var allFuncs = []string{ @@ -148,6 +152,10 @@ var allFuncs = []string{ "MIN", "SUM", + // Cast functions. + "CAST", + "SAFE_CAST", + // Mathematical functions. "ABS", diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index f69ee620cc3..f4467a306ae 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -1893,8 +1893,16 @@ var baseTypes = map[string]TypeBase{ "JSON": JSON, } +func (p *parser) parseBaseType() (Type, *parseError) { + return p.parseBaseOrParameterizedType(false) +} + func (p *parser) parseType() (Type, *parseError) { - debugf("parseType: %v", p) + return p.parseBaseOrParameterizedType(true) +} + +func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError) { + debugf("parseBaseOrParameterizedType: %v", p) /* array_type: @@ -1928,7 +1936,7 @@ func (p *parser) parseType() (Type, *parseError) { } t.Base = base - if t.Base == String || t.Base == Bytes { + if withParam && (t.Base == String || t.Base == Bytes) { if err := p.expect("("); err != nil { return Type{}, err } @@ -2436,9 +2444,15 @@ func (p *parser) parseExprList() ([]Expr, *parseError) { } func (p *parser) parseParenExprList() ([]Expr, *parseError) { + return p.parseParenExprListWithParseFunc(func(p *parser) (Expr, *parseError) { + return p.parseExpr() + }) +} + +func (p *parser) parseParenExprListWithParseFunc(f func(*parser) (Expr, *parseError)) ([]Expr, *parseError) { var list []Expr err := p.parseCommaList("(", ")", func(p *parser) *parseError { - e, err := p.parseExpr() + e, err := f(p) if err != nil { return err } @@ -2448,6 +2462,26 @@ func (p *parser) parseParenExprList() ([]Expr, *parseError) { return list, err } +// Special argument parser for CAST and SAFE_CAST +var typedArgParser = func(p *parser) (Expr, *parseError) { + e, err := p.parseExpr() + if err != nil { + return nil, err + } + if err := p.expect("AS"); err != nil { + return nil, err + } + // typename in cast function must not be parameterized types + toType, err := p.parseBaseType() + if err != nil { + return nil, err + } + return TypedExpr{ + Expr: e, + Type: toType, + }, nil +} + /* Expressions @@ -2800,7 +2834,13 @@ func (p *parser) parseLit() (Expr, *parseError) { // this is a function invocation. // The `funcs` map is keyed by upper case strings. if name := strings.ToUpper(tok.value); funcs[name] && p.sniff("(") { - list, err := p.parseParenExprList() + var list []Expr + var err *parseError + if f, ok := funcArgParsers[name]; ok { + list, err = p.parseParenExprListWithParseFunc(f) + } else { + list, err = p.parseParenExprList() + } if err != nil { return nil, err } diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index 6aca49eff8e..966e4c5d1f0 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -121,6 +121,19 @@ func TestParseQuery(t *testing.T) { }, }, }, + {`SELECT * FROM Foo WHERE CAST(Bar AS STRING)='Bar'`, + Query{ + Select: Select{ + List: []Expr{Star}, + From: []SelectFrom{SelectFromTable{Table: "Foo"}}, + Where: ComparisonOp{ + Op: Eq, + LHS: Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}, + RHS: StringLiteral("Bar"), + }, + }, + }, + }, {`SELECT SUM(PointsScored) AS total_points, FirstName, LastName AS surname FROM PlayerStats GROUP BY FirstName, LastName`, Query{ Select: Select{ @@ -325,6 +338,8 @@ func TestParseExpr(t *testing.T) { // Functions {`STARTS_WITH(Bar, 'B')`, Func{Name: "STARTS_WITH", Args: []Expr{ID("Bar"), StringLiteral("B")}}}, + {`CAST(Bar AS STRING)`, Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}}, + {`SAFE_CAST(Bar AS INT64)`, Func{Name: "SAFE_CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: Int64}}}}}, // String literal: // Accept double quote and single quote. @@ -506,6 +521,12 @@ func TestParseDDL(t *testing.T) { FROM Singers ORDER BY LastName, FirstName; + CREATE TABLE users ( + user_id STRING(36) NOT NULL, + some_string STRING(16) NOT NULL, + number_key INT64 AS (SAFE_CAST(SUBSTR(some_string, 2) AS INT64)) STORED, + ) PRIMARY KEY(user_id); + -- Trailing comment at end of file. `, &DDL{Filename: "filename", List: []DDLStmt{ &CreateTable{ @@ -712,6 +733,28 @@ func TestParseDDL(t *testing.T) { }, Position: line(60), }, + + // CREATE TABLE users ( + // user_id STRING(36) NOT NULL, + // some_string STRING(16) NOT NULL, + // number_key INT64 AS (SAFE_CAST(SUBSTR(some_string, 2) AS INT64)) STORED, + //) PRIMARY KEY(user_id); + &CreateTable{ + Name: "users", + Columns: []ColumnDef{ + {Name: "user_id", Type: Type{Base: String, Len: 36}, NotNull: true, Position: line(67)}, + {Name: "some_string", Type: Type{Base: String, Len: 16}, NotNull: true, Position: line(68)}, + { + Name: "number_key", Type: Type{Base: Int64}, + Generated: Func{Name: "SAFE_CAST", Args: []Expr{ + TypedExpr{Expr: Func{Name: "SUBSTR", Args: []Expr{ID("some_string"), IntegerLiteral(2)}}, Type: Type{Base: Int64}}, + }}, + Position: line(69), + }, + }, + PrimaryKey: []KeyPart{{Column: "user_id"}}, + Position: line(66), + }, }, Comments: []*Comment{ {Marker: "#", Start: line(2), End: line(2), Text: []string{"This is a comment."}}, @@ -734,7 +777,7 @@ func TestParseDDL(t *testing.T) { {Marker: "--", Isolated: true, Start: line(49), End: line(49), Text: []string{"Table with row deletion policy."}}, // Comment after everything else. - {Marker: "--", Isolated: true, Start: line(66), End: line(66), Text: []string{"Trailing comment at end of file."}}, + {Marker: "--", Isolated: true, Start: line(72), End: line(72), Text: []string{"Trailing comment at end of file."}}, }}}, // No trailing comma: {`ALTER TABLE T ADD COLUMN C2 INT64`, &DDL{Filename: "filename", List: []DDLStmt{ diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index 5032b57ad99..29c5a5dac54 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -289,7 +289,7 @@ func (c Check) SQL() string { func (t Type) SQL() string { str := t.Base.SQL() - if t.Base == String || t.Base == Bytes { + if t.Len > 0 && (t.Base == String || t.Base == Bytes) { str += "(" if t.Len == MaxLen { str += "MAX" @@ -582,6 +582,13 @@ func (f Func) addSQL(sb *strings.Builder) { sb.WriteString(")") } +func (te TypedExpr) SQL() string { return buildSQL(te) } +func (te TypedExpr) addSQL(sb *strings.Builder) { + te.Expr.addSQL(sb) + sb.WriteString(" AS ") + sb.WriteString(te.Type.SQL()) +} + func idList(l []ID, join string) string { var ss []string for _, s := range l { diff --git a/spanner/spansql/sql_test.go b/spanner/spansql/sql_test.go index 9f3b13d45b4..b442506c719 100644 --- a/spanner/spansql/sql_test.go +++ b/spanner/spansql/sql_test.go @@ -465,6 +465,30 @@ func TestSQL(t *testing.T) { `SELECT 7`, reparseQuery, }, + { + Query{ + Select: Select{ + List: []Expr{Func{ + Name: "CAST", + Args: []Expr{TypedExpr{Expr: IntegerLiteral(7), Type: Type{Base: String}}}, + }}, + }, + }, + `SELECT CAST(7 AS STRING)`, + reparseQuery, + }, + { + Query{ + Select: Select{ + List: []Expr{Func{ + Name: "SAFE_CAST", + Args: []Expr{TypedExpr{Expr: IntegerLiteral(7), Type: Type{Base: Date}}}, + }}, + }, + }, + `SELECT SAFE_CAST(7 AS DATE)`, + reparseQuery, + }, { ComparisonOp{LHS: ID("X"), Op: NotBetween, RHS: ID("Y"), RHS2: ID("Z")}, `X NOT BETWEEN Y AND Z`, diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index 86b98f4d189..074e07cf0f1 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -638,6 +638,15 @@ type Func struct { func (Func) isBoolExpr() {} // possibly bool func (Func) isExpr() {} +// TypedExpr represents a typed expression in the form `expr AS type_name`, e.g. `'17' AS INT64`. +type TypedExpr struct { + Type Type + Expr Expr +} + +func (TypedExpr) isBoolExpr() {} // possibly bool +func (TypedExpr) isExpr() {} + // Paren represents a parenthesised expression. type Paren struct { Expr Expr From e714617d25d6d007507b85c06d663cead048784a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 1 Nov 2021 13:30:29 +0100 Subject: [PATCH 2/3] fix: only ignore conversion errors in SAFE_CAST --- spanner/spannertest/db_eval.go | 10 +-- spanner/spannertest/funcs.go | 112 +++++++++++++++++---------------- 2 files changed, 65 insertions(+), 57 deletions(-) diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go index aa0b6863420..41d5094fe82 100644 --- a/spanner/spannertest/db_eval.go +++ b/spanner/spannertest/db_eval.go @@ -385,17 +385,19 @@ func (ec evalContext) evalArithOp(e spansql.ArithOp) (interface{}, error) { } func (ec evalContext) evalFunc(e spansql.Func) (interface{}, spansql.Type, error) { + var err error if f, ok := functions[e.Name]; ok { args := make([]interface{}, len(e.Args)) types := make([]spansql.Type, len(e.Args)) - errs := make([]error, len(e.Args)) for i, arg := range e.Args { - args[i], errs[i] = ec.evalExpr(arg) + if args[i], err = ec.evalExpr(arg); err != nil { + return nil, spansql.Type{}, err + } if te, ok := arg.(spansql.TypedExpr); ok { types[i] = te.Type } } - return f.Eval(args, types, errs) + return f.Eval(args, types) } return nil, spansql.Type{}, status.Errorf(codes.Unimplemented, "function %q is not implemented", e.Name) } @@ -665,7 +667,7 @@ func (ec evalContext) coerceString(target spansql.Expr, slit spansql.StringLiter return nil, fmt.Errorf("unable to coerce string literal %q to match %v", slit, ci.Type) } -func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (interface{}, error) { +func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (result interface{}, err error) { val, err := ec.evalExpr(expr.Expr) if err != nil { return nil, err diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go index 8a3f3c5e899..e2f6889d454 100644 --- a/spanner/spannertest/funcs.go +++ b/spanner/spannertest/funcs.go @@ -33,7 +33,7 @@ import ( type function struct { // Eval evaluates the result of the function using the given input. - Eval func(values []interface{}, types []spansql.Type, errors []error) (interface{}, spansql.Type, error) + Eval func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) } func firstErr(errors []error) error { @@ -47,10 +47,7 @@ func firstErr(errors []error) error { var functions = map[string]function{ "STARTS_WITH": { - Eval: func(values []interface{}, types []spansql.Type, errors []error) (interface{}, spansql.Type, error) { - if err := firstErr(errors); err != nil { - return nil, spansql.Type{}, err - } + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { // TODO: Refine error messages to exactly match Spanner. // Check input values first. if len(values) != 2 { @@ -68,10 +65,7 @@ var functions = map[string]function{ }, }, "LOWER": { - Eval: func(values []interface{}, types []spansql.Type, errors []error) (interface{}, spansql.Type, error) { - if err := firstErr(errors); err != nil { - return nil, spansql.Type{}, err - } + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { if len(values) != 1 { return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function LOWER for the given argument types") } @@ -85,18 +79,18 @@ var functions = map[string]function{ }, }, "CAST": { - Eval: func(values []interface{}, types []spansql.Type, errors []error) (interface{}, spansql.Type, error) { - return cast(values, types, errors, false) + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { + return cast(values, types, false) }, }, "SAFE_CAST": { - Eval: func(values []interface{}, types []spansql.Type, errors []error) (interface{}, spansql.Type, error) { - return cast(values, types, errors, true) + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { + return cast(values, types, true) }, }, } -func cast(values []interface{}, types []spansql.Type, errors []error, safe bool) (interface{}, spansql.Type, error) { +func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, spansql.Type, error) { name := "CAST" if safe { name = "SAFE_CAST" @@ -107,7 +101,8 @@ func cast(values []interface{}, types []spansql.Type, errors []error, safe bool) if len(values) != 1 { return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No matching signature for function %s for the given arguments", name) } - if err := firstErr(errors); err != nil { + // If the input type is an error, then the conversion itself failed. + if err, ok := values[0].(error); ok { if safe { return nil, types[0], nil } else { @@ -122,110 +117,121 @@ func convert(val interface{}, tp spansql.Type) (interface{}, error) { if tp.Array { return nil, status.Errorf(codes.Unimplemented, "conversion to ARRAY types is not implemented") } + var res interface{} + var convertErr, err error switch tp.Base { case spansql.Int64: - return convertToInt64(val) + res, convertErr, err = convertToInt64(val) case spansql.Float64: - return convertToFloat64(val) + res, convertErr, err = convertToFloat64(val) case spansql.String: - return convertToString(val) + res, convertErr, err = convertToString(val) case spansql.Bool: - return convertToBool(val) + res, convertErr, err = convertToBool(val) case spansql.Date: - return convertToDate(val) + res, convertErr, err = convertToDate(val) case spansql.Timestamp: - return convertToTimestamp(val) + res, convertErr, err = convertToTimestamp(val) case spansql.Numeric: case spansql.JSON: } + if err != nil { + return nil, err + } + if convertErr != nil { + res = convertErr + } + if res != nil { + return res, nil + } return nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to %v", val, tp.Base.SQL()) } -func convertToInt64(val interface{}) (int64, error) { +func convertToInt64(val interface{}) (res int64, convertErr error, err error) { switch v := val.(type) { case int64: - return v, nil + return v, nil, nil case string: res, err := strconv.ParseInt(v, 10, 64) if err != nil { - return 0, status.Errorf(codes.InvalidArgument, "invalid value for INT64: %q", v) + return 0, status.Errorf(codes.InvalidArgument, "invalid value for INT64: %q", v), nil } - return res, nil + return res, nil, nil } - return 0, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to INT64", val) + return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to INT64", val) } -func convertToFloat64(val interface{}) (float64, error) { +func convertToFloat64(val interface{}) (res float64, convertErr error, err error) { switch v := val.(type) { case int64: - return float64(v), nil + return float64(v), nil, nil case float64: - return v, nil + return v, nil, nil case string: res, err := strconv.ParseFloat(v, 64) if err != nil { - return 0, status.Errorf(codes.InvalidArgument, "invalid value for FLOAT64: %q", v) + return 0, status.Errorf(codes.InvalidArgument, "invalid value for FLOAT64: %q", v), nil } - return res, nil + return res, nil, nil } - return 0, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to FLOAT64", val) + return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to FLOAT64", val) } -func convertToString(val interface{}) (string, error) { +func convertToString(val interface{}) (res string, convertErr error, err error) { switch v := val.(type) { case string: - return v, nil + return v, nil, nil case bool, int64, float64: - return fmt.Sprintf("%v", v), nil + return fmt.Sprintf("%v", v), nil, nil case civil.Date: - return v.String(), nil + return v.String(), nil, nil case time.Time: - return v.UTC().Format(time.RFC3339Nano), nil + return v.UTC().Format(time.RFC3339Nano), nil, nil } - return "", status.Errorf(codes.Unimplemented, "unsupported conversion for %v to STRING", val) + return "", nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to STRING", val) } -func convertToBool(val interface{}) (bool, error) { +func convertToBool(val interface{}) (res bool, convertErr error, err error) { switch v := val.(type) { case bool: - return v, nil + return v, nil, nil case string: res, err := strconv.ParseBool(v) if err != nil { - return false, status.Errorf(codes.InvalidArgument, "invalid value for BOOL: %q", v) + return false, status.Errorf(codes.InvalidArgument, "invalid value for BOOL: %q", v), nil } - return res, nil + return res, nil, nil } - return false, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to BOOL", val) + return false, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to BOOL", val) } -func convertToDate(val interface{}) (civil.Date, error) { +func convertToDate(val interface{}) (res civil.Date, convertErr error, err error) { switch v := val.(type) { case civil.Date: - return v, nil + return v, nil, nil case string: res, err := civil.ParseDate(v) if err != nil { - return civil.Date{}, status.Errorf(codes.InvalidArgument, "invalid value for DATE: %q", v) + return civil.Date{}, status.Errorf(codes.InvalidArgument, "invalid value for DATE: %q", v), nil } - return res, nil + return res, nil, nil } - return civil.Date{}, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to DATE", val) + return civil.Date{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to DATE", val) } -func convertToTimestamp(val interface{}) (time.Time, error) { +func convertToTimestamp(val interface{}) (res time.Time, convertErr error, err error) { switch v := val.(type) { case time.Time: - return v, nil + return v, nil, nil case string: res, err := time.Parse(time.RFC3339Nano, v) if err != nil { - return time.Time{}, status.Errorf(codes.InvalidArgument, "invalid value for TIMESTAMP: %q", v) + return time.Time{}, status.Errorf(codes.InvalidArgument, "invalid value for TIMESTAMP: %q", v), nil } - return res, nil + return res, nil, nil } - return time.Time{}, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to TIMESTAMP", val) + return time.Time{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to TIMESTAMP", val) } type aggregateFunc struct { From ba388b85edbb3ca6a825b53d8d761b6b54007fc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 1 Nov 2021 14:20:54 +0100 Subject: [PATCH 3/3] fix: outdent else statement --- spanner/spannertest/funcs.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go index e2f6889d454..a27eab5e7be 100644 --- a/spanner/spannertest/funcs.go +++ b/spanner/spannertest/funcs.go @@ -105,9 +105,8 @@ func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, s if err, ok := values[0].(error); ok { if safe { return nil, types[0], nil - } else { - return nil, types[0], err } + return nil, types[0], err } return values[0], types[0], nil }