Skip to content

Commit

Permalink
feat(spanner/spansql): support CAST and SAFE_CAST (#5057)
Browse files Browse the repository at this point in the history
* feat(spanner/spansql): support CAST and SAFE_CAST

Adds support for the CAST and SAFE_CAST functions. This change also includes
a small refactor of function evaluation so that each function receives both
the parsing errors and the type information for its arguments. This makes
it easier to implement the SAFE versions of the functions.

* fix: only ignore conversion errors in SAFE_CAST

* fix: outdent else statement

Co-authored-by: rahul2393 <rahulyadavsep92@gmail.com>
Co-authored-by: Hengfeng Li <hengfeng@google.com>
  • Loading branch information
3 people committed Nov 3, 2021
1 parent cbd5c8c commit 54cbf4c
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 15 deletions.
21 changes: 17 additions & 4 deletions spanner/spannertest/db_eval.go
Expand Up @@ -385,16 +385,19 @@ func (ec evalContext) evalArithOp(e spansql.ArithOp) (interface{}, error) {
}

func (ec evalContext) evalFunc(e spansql.Func) (interface{}, spansql.Type, error) {
var err error
if f, ok := functions[e.Name]; ok {
args := make([]interface{}, len(e.Args))
types := make([]spansql.Type, len(e.Args))
for i, arg := range e.Args {
val, err := ec.evalExpr(arg)
if err != nil {
if args[i], err = ec.evalExpr(arg); err != nil {
return nil, spansql.Type{}, err
}
args[i] = val
if te, ok := arg.(spansql.TypedExpr); ok {
types[i] = te.Type
}
}
return f.Eval(args)
return f.Eval(args, types)
}
return nil, spansql.Type{}, status.Errorf(codes.Unimplemented, "function %q is not implemented", e.Name)
}
Expand Down Expand Up @@ -464,6 +467,8 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) {
return bool(e), nil
case spansql.Paren:
return ec.evalExpr(e.Expr)
case spansql.TypedExpr:
return ec.evalTypedExpr(e)
case spansql.Func:
v, _, err := ec.evalFunc(e)
if err != nil {
Expand Down Expand Up @@ -662,6 +667,14 @@ func (ec evalContext) coerceString(target spansql.Expr, slit spansql.StringLiter
return nil, fmt.Errorf("unable to coerce string literal %q to match %v", slit, ci.Type)
}

// evalTypedExpr evaluates the expression wrapped by expr and then converts
// the resulting value to expr.Type using convert.
//
// An evaluation error is returned as-is, without attempting the conversion.
func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (result interface{}, err error) {
	if result, err = ec.evalExpr(expr.Expr); err != nil {
		return nil, err
	}
	return convert(result, expr.Type)
}

func evalLiteralOrParam(lop spansql.LiteralOrParam, params queryParams) (int64, error) {
switch v := lop.(type) {
case spansql.IntegerLiteral:
Expand Down
171 changes: 168 additions & 3 deletions spanner/spannertest/funcs.go
Expand Up @@ -19,8 +19,11 @@ package spannertest
import (
"fmt"
"math"
"strconv"
"strings"
"time"

"cloud.google.com/go/civil"
"cloud.google.com/go/spanner/spansql"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand All @@ -30,12 +33,21 @@ import (

type function struct {
// Eval evaluates the result of the function using the given input.
Eval func(values []interface{}) (interface{}, spansql.Type, error)
Eval func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error)
}

// firstErr returns the first non-nil error in errors, scanning from the
// front. It returns nil when errors is empty or holds only nil entries.
func firstErr(errors []error) error {
	for i := range errors {
		if e := errors[i]; e != nil {
			return e
		}
	}
	return nil
}

var functions = map[string]function{
"STARTS_WITH": {
Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
// TODO: Refine error messages to exactly match Spanner.
// Check input values first.
if len(values) != 2 {
Expand All @@ -53,7 +65,7 @@ var functions = map[string]function{
},
},
"LOWER": {
Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
if len(values) != 1 {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function LOWER for the given argument types")
}
Expand All @@ -66,6 +78,159 @@ var functions = map[string]function{
return strings.ToLower(values[0].(string)), spansql.Type{Base: spansql.String}, nil
},
},
"CAST": {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
return cast(values, types, false)
},
},
"SAFE_CAST": {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
return cast(values, types, true)
},
},
}

// cast implements the shared evaluation of the CAST and SAFE_CAST functions.
//
// values and types must each hold exactly one element; otherwise an
// InvalidArgument status error is returned. The single value is expected to
// be already converted to the target type types[0]. A failed conversion is
// signalled by the value itself being an error: CAST (safe == false) returns
// that error, while SAFE_CAST (safe == true) swallows it and yields a nil
// value instead.
func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, spansql.Type, error) {
	var name string
	switch {
	case safe:
		name = "SAFE_CAST"
	default:
		name = "CAST"
	}

	switch {
	case len(types) != 1:
		return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No type information for function %s for the given arguments", name)
	case len(values) != 1:
		return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No matching signature for function %s for the given arguments", name)
	}

	target := types[0]
	convErr, failed := values[0].(error)
	if !failed {
		return values[0], target, nil
	}
	// The argument carries an error, which means the conversion itself
	// failed. Only the SAFE variant is allowed to ignore it.
	if safe {
		return nil, target, nil
	}
	return nil, target, convErr
}

// convert converts val to the SQL type tp.
//
// A nil val (SQL NULL) converts to nil for every target type. Conversions to
// ARRAY types, NUMERIC and JSON are not implemented and yield an
// Unimplemented status error.
//
// The result distinguishes two kinds of failure:
//   - A value whose type is convertible but whose content is invalid (for
//     example the string "Foo" for INT64) is reported by returning the
//     conversion error as the result value with a nil error, so the caller
//     can decide whether the failure is fatal.
//   - An unsupported source/target combination is returned as the error.
func convert(val interface{}, tp spansql.Type) (interface{}, error) {
	// Casting NULL yields NULL regardless of the target type. Without this
	// check a nil value matches none of the per-type converters and would be
	// rejected as an unsupported conversion.
	if val == nil {
		return nil, nil
	}
	// TODO: Implement more conversions.
	if tp.Array {
		return nil, status.Errorf(codes.Unimplemented, "conversion to ARRAY types is not implemented")
	}
	var (
		res             interface{}
		convertErr, err error
	)
	switch tp.Base {
	case spansql.Int64:
		res, convertErr, err = convertToInt64(val)
	case spansql.Float64:
		res, convertErr, err = convertToFloat64(val)
	case spansql.String:
		res, convertErr, err = convertToString(val)
	case spansql.Bool:
		res, convertErr, err = convertToBool(val)
	case spansql.Date:
		res, convertErr, err = convertToDate(val)
	case spansql.Timestamp:
		res, convertErr, err = convertToTimestamp(val)
	case spansql.Numeric, spansql.JSON:
		// Not implemented: res stays nil, which leads to the Unimplemented
		// error at the end of the function.
	}
	if err != nil {
		return nil, err
	}
	if convertErr != nil {
		// Hand the conversion failure back as the value rather than as the
		// error.
		return convertErr, nil
	}
	if res != nil {
		return res, nil
	}

	return nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to %v", val, tp.Base.SQL())
}

// convertToInt64 converts val to an INT64 value.
//
// An int64 is returned unchanged and a string is parsed as a base-10 64-bit
// integer. A string that does not parse is reported through convertErr
// (InvalidArgument) with a nil err; any other input type is reported through
// err (Unimplemented).
func convertToInt64(val interface{}) (res int64, convertErr error, err error) {
	if n, ok := val.(int64); ok {
		return n, nil, nil
	}
	text, ok := val.(string)
	if !ok {
		return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to INT64", val)
	}
	parsed, parseErr := strconv.ParseInt(text, 10, 64)
	if parseErr != nil {
		return 0, status.Errorf(codes.InvalidArgument, "invalid value for INT64: %q", text), nil
	}
	return parsed, nil, nil
}

// convertToFloat64 converts val to a FLOAT64 value.
//
// A float64 is returned unchanged, an int64 is widened, and a string is
// parsed with strconv.ParseFloat. A string that does not parse is reported
// through convertErr (InvalidArgument) with a nil err; any other input type
// is reported through err (Unimplemented).
func convertToFloat64(val interface{}) (res float64, convertErr error, err error) {
	if text, isString := val.(string); isString {
		parsed, parseErr := strconv.ParseFloat(text, 64)
		if parseErr != nil {
			return 0, status.Errorf(codes.InvalidArgument, "invalid value for FLOAT64: %q", text), nil
		}
		return parsed, nil, nil
	}
	switch num := val.(type) {
	case float64:
		return num, nil, nil
	case int64:
		return float64(num), nil, nil
	}
	return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to FLOAT64", val)
}

// convertToString converts val to a STRING value.
//
// Strings are returned unchanged; bool, int64 and float64 values use their
// default fmt representation; a civil.Date uses its String form; and a
// time.Time is rendered in UTC using the RFC 3339 layout with nanoseconds.
// Any other input type is reported through err (Unimplemented). This
// conversion never produces a convertErr.
func convertToString(val interface{}) (res string, convertErr error, err error) {
	switch v := val.(type) {
	case time.Time:
		res = v.UTC().Format(time.RFC3339Nano)
	case civil.Date:
		res = v.String()
	case bool, int64, float64:
		res = fmt.Sprintf("%v", v)
	case string:
		res = v
	default:
		return "", nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to STRING", val)
	}
	return res, nil, nil
}

func convertToBool(val interface{}) (res bool, convertErr error, err error) {
switch v := val.(type) {
case bool:
return v, nil, nil
case string:
res, err := strconv.ParseBool(v)
if err != nil {
return false, status.Errorf(codes.InvalidArgument, "invalid value for BOOL: %q", v), nil
}
return res, nil, nil
}
return false, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to BOOL", val)
}

// convertToDate converts val to a DATE value.
//
// A civil.Date is returned unchanged and a string is parsed with
// civil.ParseDate. A string that does not parse is reported through
// convertErr (InvalidArgument) with a nil err; any other input type is
// reported through err (Unimplemented).
func convertToDate(val interface{}) (res civil.Date, convertErr error, err error) {
	if date, ok := val.(civil.Date); ok {
		return date, nil, nil
	}
	text, ok := val.(string)
	if !ok {
		return civil.Date{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to DATE", val)
	}
	parsed, parseErr := civil.ParseDate(text)
	if parseErr != nil {
		return civil.Date{}, status.Errorf(codes.InvalidArgument, "invalid value for DATE: %q", text), nil
	}
	return parsed, nil, nil
}

// convertToTimestamp converts val to a TIMESTAMP value.
//
// A time.Time is returned unchanged and a string is parsed using the
// RFC 3339 layout with optional nanoseconds. A string that does not parse is
// reported through convertErr (InvalidArgument) with a nil err; any other
// input type is reported through err (Unimplemented).
func convertToTimestamp(val interface{}) (res time.Time, convertErr error, err error) {
	if ts, ok := val.(time.Time); ok {
		return ts, nil, nil
	}
	text, ok := val.(string)
	if !ok {
		return time.Time{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to TIMESTAMP", val)
	}
	parsed, parseErr := time.Parse(time.RFC3339Nano, text)
	if parseErr != nil {
		return time.Time{}, status.Errorf(codes.InvalidArgument, "invalid value for TIMESTAMP: %q", text), nil
	}
	return parsed, nil, nil
}

type aggregateFunc struct {
Expand Down
11 changes: 9 additions & 2 deletions spanner/spannertest/integration_test.go
Expand Up @@ -741,16 +741,23 @@ func TestIntegration_ReadsAndQueries(t *testing.T) {
t.Errorf("Updating with DML affected %d rows, want 3", n)
}

rows := client.Single().Query(ctx, spanner.NewStatement("SELECT CAST('Foo' AS INT64)"))
_, err = rows.Next()
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Errorf("error code mismatch for invalid CAST\n Got: %v\nWant: %v", g, w)
}
rows.Stop()

// Do some complex queries.
tests := []struct {
q string
params map[string]interface{}
want [][]interface{}
}{
{
`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B')`,
`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64)`,
nil,
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true}},
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil}},
},
// Check handling of NULL values for the IS operator.
// There was a bug that returned errors for some of these cases.
Expand Down
8 changes: 8 additions & 0 deletions spanner/spansql/keywords.go
Expand Up @@ -128,11 +128,15 @@ var keywords = map[string]bool{
// funcs is the set of reserved keywords that are functions.
// https://cloud.google.com/spanner/docs/functions-and-operators
var funcs = make(map[string]bool)
var funcArgParsers = make(map[string]func(*parser) (Expr, *parseError))

// init populates the funcs lookup set from allFuncs and registers the
// dedicated argument parser used by the cast functions.
func init() {
	for i := range allFuncs {
		funcs[allFuncs[i]] = true
	}
	// CAST and SAFE_CAST take an `expr AS type` argument instead of a plain
	// expression list, so they need a special argument parser.
	for _, name := range []string{"CAST", "SAFE_CAST"} {
		funcArgParsers[name] = typedArgParser
	}
}

var allFuncs = []string{
Expand All @@ -148,6 +152,10 @@ var allFuncs = []string{
"MIN",
"SUM",

// Cast functions.
"CAST",
"SAFE_CAST",

// Mathematical functions.
"ABS",

Expand Down
48 changes: 44 additions & 4 deletions spanner/spansql/parser.go
Expand Up @@ -1893,8 +1893,16 @@ var baseTypes = map[string]TypeBase{
"JSON": JSON,
}

// parseBaseType parses a type via parseBaseOrParameterizedType with
// parameter parsing turned off.
func (p *parser) parseBaseType() (Type, *parseError) {
	const withParam = false
	return p.parseBaseOrParameterizedType(withParam)
}

func (p *parser) parseType() (Type, *parseError) {
debugf("parseType: %v", p)
return p.parseBaseOrParameterizedType(true)
}

func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError) {
debugf("parseBaseOrParameterizedType: %v", p)

/*
array_type:
Expand Down Expand Up @@ -1928,7 +1936,7 @@ func (p *parser) parseType() (Type, *parseError) {
}
t.Base = base

if t.Base == String || t.Base == Bytes {
if withParam && (t.Base == String || t.Base == Bytes) {
if err := p.expect("("); err != nil {
return Type{}, err
}
Expand Down Expand Up @@ -2436,9 +2444,15 @@ func (p *parser) parseExprList() ([]Expr, *parseError) {
}

// parseParenExprList parses a parenthesized, comma-separated list in which
// every element is parsed as an ordinary expression.
func (p *parser) parseParenExprList() ([]Expr, *parseError) {
	parseElem := func(q *parser) (Expr, *parseError) {
		return q.parseExpr()
	}
	return p.parseParenExprListWithParseFunc(parseElem)
}

func (p *parser) parseParenExprListWithParseFunc(f func(*parser) (Expr, *parseError)) ([]Expr, *parseError) {
var list []Expr
err := p.parseCommaList("(", ")", func(p *parser) *parseError {
e, err := p.parseExpr()
e, err := f(p)
if err != nil {
return err
}
Expand All @@ -2448,6 +2462,26 @@ func (p *parser) parseParenExprList() ([]Expr, *parseError) {
return list, err
}

// typedArgParser is the argument parser for CAST and SAFE_CAST. It parses an
// argument of the form `expr AS type` and returns it as a TypedExpr holding
// the parsed expression and the target type.
var typedArgParser = func(p *parser) (Expr, *parseError) {
	operand, perr := p.parseExpr()
	if perr != nil {
		return nil, perr
	}
	if perr = p.expect("AS"); perr != nil {
		return nil, perr
	}
	// The type name in a cast function must not be a parameterized type.
	target, perr := p.parseBaseType()
	if perr != nil {
		return nil, perr
	}
	return TypedExpr{Expr: operand, Type: target}, nil
}

/*
Expressions
Expand Down Expand Up @@ -2800,7 +2834,13 @@ func (p *parser) parseLit() (Expr, *parseError) {
// this is a function invocation.
// The `funcs` map is keyed by upper case strings.
if name := strings.ToUpper(tok.value); funcs[name] && p.sniff("(") {
list, err := p.parseParenExprList()
var list []Expr
var err *parseError
if f, ok := funcArgParsers[name]; ok {
list, err = p.parseParenExprListWithParseFunc(f)
} else {
list, err = p.parseParenExprList()
}
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 54cbf4c

Please sign in to comment.