Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spanner/spansql): add support for STARTS_WITH function #4670

Merged
merged 2 commits into from Aug 25, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 30 additions & 1 deletion spanner/spannertest/db_eval.go
Expand Up @@ -28,6 +28,8 @@ import (

"cloud.google.com/go/civil"
"cloud.google.com/go/spanner/spansql"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// evalContext represents the context for evaluating an expression.
Expand Down Expand Up @@ -72,7 +74,7 @@ func (ec evalContext) evalBoolExpr(be spansql.BoolExpr) (*bool, error) {
case spansql.BoolLiteral:
b := bool(be)
return &b, nil
case spansql.ID, spansql.Param, spansql.Paren, spansql.InOp: // InOp is a bit weird.
case spansql.ID, spansql.Param, spansql.Paren, spansql.Func, spansql.InOp: // InOp is a bit weird.
e, err := ec.evalExpr(be)
if err != nil {
return nil, err
Expand Down Expand Up @@ -363,6 +365,21 @@ func (ec evalContext) evalArithOp(e spansql.ArithOp) (interface{}, error) {
return nil, fmt.Errorf("TODO: evalArithOp(%s %v)", e.SQL(), e.Op)
}

func (ec evalContext) evalFunc(e spansql.Func) (interface{}, spansql.Type, error) {
if f, ok := functions[e.Name]; ok {
args := make([]interface{}, len(e.Args))
for i, arg := range e.Args {
val, err := ec.evalExpr(arg)
if err != nil {
return nil, spansql.Type{}, err
}
args[i] = val
}
return f.Eval(args)
}
return nil, spansql.Type{}, status.Errorf(codes.Unimplemented, "function %q is not implemented", e.Name)
}

// evalFloat64 evaluates an expression and returns its FLOAT64 value.
// If the expression does not yield a FLOAT64 or INT64 it returns an error.
func (ec evalContext) evalFloat64(e spansql.Expr) (float64, error) {
Expand Down Expand Up @@ -428,6 +445,12 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) {
return bool(e), nil
case spansql.Paren:
return ec.evalExpr(e.Expr)
case spansql.Func:
v, _, err := ec.evalFunc(e)
if err != nil {
return nil, err
}
return v, nil
case spansql.Array:
var arr []interface{}
for _, elt := range e {
Expand Down Expand Up @@ -785,6 +808,12 @@ func (ec evalContext) colInfo(e spansql.Expr) (colInfo, error) {
return colInfo{Type: qp.Type}, nil
case spansql.Paren:
return ec.colInfo(e.Expr)
case spansql.Func:
_, t, err := ec.evalFunc(e)
if err != nil {
return colInfo{}, err
}
return colInfo{Type: t}, nil
case spansql.Array:
// Assume all element of an array literal have the same type.
if len(e) == 0 {
Expand Down
29 changes: 29 additions & 0 deletions spanner/spannertest/funcs.go
Expand Up @@ -19,12 +19,41 @@ package spannertest
import (
"fmt"
"math"
"strings"

"cloud.google.com/go/spanner/spansql"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// This file contains implementations of query functions.

type function struct {
// Eval evaluates the result of the function using the given input.
Eval func(values []interface{}) (interface{}, spansql.Type, error)
}

var functions = map[string]function{
"STARTS_WITH": {
Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
// TODO: Refine error messages to exactly match Spanner.
// Check input values first.
if len(values) != 2 {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function STARTS_WITH for the given argument types")
}
for _, v := range values {
// TODO: STARTS_WITH also supports BYTES as input parameters.
if _, ok := v.(string); !ok {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function STARTS_WITH for the given argument types")
}
}
s := values[0].(string)
prefix := values[1].(string)
return strings.HasPrefix(s, prefix), spansql.Type{Base: spansql.Bool}, nil
},
},
}

type aggregateFunc struct {
// Whether the function can take a * arg (only COUNT).
AcceptStar bool
Expand Down
11 changes: 9 additions & 2 deletions spanner/spannertest/integration_test.go
Expand Up @@ -748,9 +748,9 @@ func TestIntegration_ReadsAndQueries(t *testing.T) {
want [][]interface{}
}{
{
`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello"`,
`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B')`,
nil,
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello")}},
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true}},
},
// Check handling of NULL values for the IS operator.
// There was a bug that returned errors for some of these cases.
Expand Down Expand Up @@ -824,6 +824,13 @@ func TestIntegration_ReadsAndQueries(t *testing.T) {
{"Sam"},
},
},
{
`SELECT Name FROM Staff WHERE STARTS_WITH(Name, 'Ja')`,
nil,
[][]interface{}{
{"Jack"},
},
},
{
`SELECT Name, Height FROM Staff WHERE Height BETWEEN @min AND @max ORDER BY Height DESC`,
map[string]interface{}{"min": 1.75, "max": 1.85},
Expand Down
12 changes: 12 additions & 0 deletions spanner/spansql/parser_test.go
Expand Up @@ -112,6 +112,15 @@ func TestParseQuery(t *testing.T) {
},
},
},
{`SELECT * FROM Foo WHERE STARTS_WITH(Bar, 'B')`,
Query{
Select: Select{
List: []Expr{Star},
From: []SelectFrom{SelectFromTable{Table: "Foo"}},
Where: Func{Name: "STARTS_WITH", Args: []Expr{ID("Bar"), StringLiteral("B")}},
},
},
},
{`SELECT SUM(PointsScored) AS total_points, FirstName, LastName AS surname FROM PlayerStats GROUP BY FirstName, LastName`,
Query{
Select: Select{
Expand Down Expand Up @@ -314,6 +323,9 @@ func TestParseExpr(t *testing.T) {
{`@needle IN UNNEST(@haystack)`, InOp{LHS: Param("needle"), RHS: []Expr{Param("haystack")}, Unnest: true}},
{`@needle NOT IN UNNEST(@haystack)`, InOp{LHS: Param("needle"), Neg: true, RHS: []Expr{Param("haystack")}, Unnest: true}},

// Functions
{`STARTS_WITH(Bar, 'B')`, Func{Name: "STARTS_WITH", Args: []Expr{ID("Bar"), StringLiteral("B")}}},

// String literal:
// Accept double quote and single quote.
{`"hello"`, StringLiteral("hello")},
Expand Down