Skip to content

Commit

Permalink
feat(spanner/spansql): support EXTRACT (#5218)
Browse files Browse the repository at this point in the history
* feat(spanner/spansql): support EXTRACT

* added separate Expr for Extract func and added unit and integration tests

* add test for year

* replace atTimeZone func with atTimeZone expression

* fixing failing tests

* added negative test, reduced the valid extract part values.

* remove extra space

Co-authored-by: Rahul Yadav <irahul@google.com>
  • Loading branch information
rahul2393 and rahul2393 committed Dec 16, 2021
1 parent 2c664a6 commit 81b7c85
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 16 deletions.
59 changes: 59 additions & 0 deletions spanner/spannertest/db_eval.go
Expand Up @@ -469,6 +469,10 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) {
return ec.evalExpr(e.Expr)
case spansql.TypedExpr:
return ec.evalTypedExpr(e)
case spansql.ExtractExpr:
return ec.evalExtractExpr(e)
case spansql.AtTimeZoneExpr:
return ec.evalAtTimeZoneExpr(e)
case spansql.Func:
v, _, err := ec.evalFunc(e)
if err != nil {
Expand Down Expand Up @@ -675,6 +679,61 @@ func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (result interface{},
return convert(val, expr.Type)
}

// evalExtractExpr evaluates an EXTRACT(part FROM expr) expression.
//
// The operand expr.Expr is evaluated first. A NULL operand (nil) yields NULL.
// Otherwise the operand must be a time.Time or a civil.Date:
//   - DATE returns a civil.Date;
//   - DAY, MONTH and YEAR return the corresponding field as an int64.
//
// For a time.Time operand the fields are read in whatever location the value
// already carries; no time zone conversion is applied here.
//
// An error is returned if the operand fails to evaluate, if the part is not
// one of the four above, or if the operand has any other type.
func (ec evalContext) evalExtractExpr(expr spansql.ExtractExpr) (result interface{}, err error) {
	val, err := ec.evalExpr(expr.Expr)
	if err != nil {
		return nil, err
	}
	if val == nil {
		// NULL in, NULL out: extracting from a NULL operand is not an error.
		return nil, nil
	}
	switch expr.Part {
	case "DATE":
		switch v := val.(type) {
		case time.Time:
			return civil.DateOf(v), nil
		case civil.Date:
			return v, nil
		}
	case "DAY":
		switch v := val.(type) {
		case time.Time:
			return int64(v.Day()), nil
		case civil.Date:
			return int64(v.Day), nil
		}
	case "MONTH":
		switch v := val.(type) {
		case time.Time:
			return int64(v.Month()), nil
		case civil.Date:
			return int64(v.Month), nil
		}
	case "YEAR":
		switch v := val.(type) {
		case time.Time:
			return int64(v.Year()), nil
		case civil.Date:
			return int64(v.Year), nil
		}
	default:
		return nil, fmt.Errorf("Extract with part %v not supported", expr.Part)
	}
	// The part is known, so the operand type is what did not match; say so
	// rather than blaming the part.
	return nil, fmt.Errorf("Extract with part %v from %T not supported", expr.Part, val)
}

// evalAtTimeZoneExpr evaluates expr.Expr and re-expresses the resulting
// timestamp in the location named by expr.Zone.
//
// A NULL operand (nil) yields NULL. A time.Time operand is returned converted
// with Time.In; the instant is unchanged, only its location. An error is
// returned if the operand fails to evaluate, if expr.Zone cannot be loaded by
// time.LoadLocation, or if the operand is of any other type.
func (ec evalContext) evalAtTimeZoneExpr(expr spansql.AtTimeZoneExpr) (result interface{}, err error) {
	val, err := ec.evalExpr(expr.Expr)
	if err != nil {
		return nil, err
	}
	switch v := val.(type) {
	case nil:
		// NULL in, NULL out.
		return nil, nil
	case time.Time:
		loc, err := time.LoadLocation(expr.Zone)
		if err != nil {
			// Report the zone and the underlying cause; the operand type is
			// fine on this path.
			return nil, fmt.Errorf("AtTimeZone with zone %q not supported: %w", expr.Zone, err)
		}
		return v.In(loc), nil
	default:
		return nil, fmt.Errorf("AtTimeZone with %T not supported", val)
	}
}

func evalLiteralOrParam(lop spansql.LiteralOrParam, params queryParams) (int64, error) {
switch v := lop.(type) {
case spansql.IntegerLiteral:
Expand Down
26 changes: 26 additions & 0 deletions spanner/spannertest/funcs.go
Expand Up @@ -107,6 +107,32 @@ var functions = map[string]function{
return "", spansql.Type{Base: spansql.String}, nil
},
},
"EXTRACT": {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
date, okArg1 := values[0].(civil.Date)
part, okArg2 := values[0].(int64)
if !(okArg1 || okArg2) {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function EXTRACT for the given argument types")
}
if okArg1 {
return date, spansql.Type{Base: spansql.Date}, nil
}
return part, spansql.Type{Base: spansql.Int64}, nil
},
},
"TIMESTAMP": {
	// Parses the first argument, which must be a string in RFC 3339 form,
	// into a time.Time typed as TIMESTAMP. A non-string argument or an
	// unparsable string is rejected with InvalidArgument.
	Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
		text, isString := values[0].(string)
		if !isString {
			return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types")
		}
		if parsed, err := time.Parse(time.RFC3339, text); err == nil {
			return parsed, spansql.Type{Base: spansql.Timestamp}, nil
		}
		return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types")
	},
},
}

func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, spansql.Type, error) {
Expand Down
39 changes: 26 additions & 13 deletions spanner/spannertest/integration_test.go
Expand Up @@ -748,16 +748,24 @@ func TestIntegration_ReadsAndQueries(t *testing.T) {
}
rows.Stop()

rows = client.Single().Query(ctx, spanner.NewStatement("SELECT EXTRACT(INVALID_PART FROM TIMESTAMP('2008-12-25T05:30:00Z')"))
_, err = rows.Next()
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Errorf("error code mismatch for invalid part from EXTRACT\n Got: %v\nWant: %v", g, w)
}
rows.Stop()

// Do some complex queries.
tests := []struct {
q string
params map[string]interface{}
want [][]interface{}
}{
{
`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64)`,

`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64), EXTRACT(DATE FROM TIMESTAMP('2008-12-25T05:30:00Z') AT TIME ZONE 'Europe/Amsterdam'), EXTRACT(YEAR FROM TIMESTAMP('2008-12-25T05:30:00Z'))`,
nil,
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil}},
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil, civil.Date{Year: 2008, Month: 12, Day: 25}, int64(2008)}},
},
// Check handling of NULL values for the IS operator.
// There was a bug that returned errors for some of these cases.
Expand Down Expand Up @@ -1277,13 +1285,16 @@ func TestIntegration_GeneratedColumns(t *testing.T) {
defer cancel()

tableName := "SongWriters"

err := updateDDL(t, adminClient,
`CREATE TABLE `+tableName+` (
Name STRING(50) NOT NULL,
NumSongs INT64,
CreatedAT TIMESTAMP,
CreatedDate DATE,
EstimatedSales INT64 NOT NULL,
CanonicalName STRING(50) AS (LOWER(Name)) STORED,
GeneratedCreatedDate DATE AS (EXTRACT(DATE FROM CreatedAT AT TIME ZONE "CET")) STORED,
GeneratedCreatedDay INT64 AS (EXTRACT(DAY FROM CreatedDate)) STORED,
) PRIMARY KEY (Name)`)
if err != nil {
t.Fatalf("Setting up fresh table: %v", err)
Expand All @@ -1295,16 +1306,18 @@ func TestIntegration_GeneratedColumns(t *testing.T) {
}

// Insert some data.
d1, _ := civil.ParseDate("2016-11-15")
t1, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z")
_, err = client.Apply(ctx, []*spanner.Mutation{
spanner.Insert(tableName,
[]string{"Name", "EstimatedSales", "NumSongs"},
[]interface{}{"Average Writer", 10, 10}),
[]string{"Name", "EstimatedSales", "NumSongs", "CreatedAT", "CreatedDate"},
[]interface{}{"Average Writer", 10, 10, t1, d1}),
spanner.Insert(tableName,
[]string{"Name", "EstimatedSales"},
[]interface{}{"Great Writer", 100}),
[]string{"Name", "EstimatedSales", "CreatedAT", "CreatedDate"},
[]interface{}{"Great Writer", 100, t1, d1}),
spanner.Insert(tableName,
[]string{"Name", "EstimatedSales", "NumSongs"},
[]interface{}{"Poor Writer", 1, 50}),
[]string{"Name", "EstimatedSales", "NumSongs", "CreatedAT", "CreatedDate"},
[]interface{}{"Poor Writer", 1, 50, t1, d1}),
})
if err != nil {
t.Fatalf("Applying mutations: %v", err)
Expand All @@ -1317,7 +1330,7 @@ func TestIntegration_GeneratedColumns(t *testing.T) {
}

ri := client.Single().Query(ctx, spanner.NewStatement(
`SELECT CanonicalName, TotalSales FROM `+tableName+` ORDER BY Name`,
`SELECT CanonicalName, TotalSales, GeneratedCreatedDate, GeneratedCreatedDay FROM `+tableName+` ORDER BY Name`,
))
all, err := slurpRows(t, ri)
if err != nil {
Expand All @@ -1326,9 +1339,9 @@ func TestIntegration_GeneratedColumns(t *testing.T) {

// Great writer has nil because NumSongs is nil
want := [][]interface{}{
{"average writer", int64(100)},
{"great writer", nil},
{"poor writer", int64(50)},
{"average writer", int64(100), civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)},
{"great writer", nil, civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)},
{"poor writer", int64(50), civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)},
}
if !reflect.DeepEqual(all, want) {
t.Errorf("Expected values are wrong.\n got %v\nwant %v", all, want)
Expand Down
3 changes: 2 additions & 1 deletion spanner/spansql/keywords.go
Expand Up @@ -134,9 +134,10 @@ func init() {
for _, f := range allFuncs {
funcs[f] = true
}
// Special case for CAST and SAFE_CAST
// Special case for CAST, SAFE_CAST and EXTRACT
funcArgParsers["CAST"] = typedArgParser
funcArgParsers["SAFE_CAST"] = typedArgParser
funcArgParsers["EXTRACT"] = extractArgParser
}

var allFuncs = []string{
Expand Down
49 changes: 49 additions & 0 deletions spanner/spansql/parser.go
Expand Up @@ -1901,6 +1901,27 @@ func (p *parser) parseType() (Type, *parseError) {
return p.parseBaseOrParameterizedType(true)
}

// extractPartTypes lists the part names accepted by EXTRACT, keyed in upper
// case, together with the base type each part produces.
var extractPartTypes = map[string]TypeBase{
	// Parts that yield a whole date.
	"DATE": Date,

	// Parts that yield a single integer field.
	"YEAR":  Int64,
	"MONTH": Int64,
	"DAY":   Int64,
}

// parseExtractType consumes the next token and interprets it as the part name
// of an EXTRACT call. It returns the type that part produces and the part
// name normalized to upper case. A token error is passed through; a token
// that is not a key of extractPartTypes produces a parse error.
func (p *parser) parseExtractType() (Type, string, *parseError) {
	tok := p.next()
	if tok.err != nil {
		return Type{}, "", tok.err
	}
	// extractPartTypes is keyed by upper-case names, so normalize before the lookup.
	part := strings.ToUpper(tok.value)
	if base, known := extractPartTypes[part]; known {
		return Type{Base: base}, part, nil
	}
	return Type{}, "", p.errorf("got %q, want valid EXTRACT types", tok.value)
}

func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError) {
debugf("parseBaseOrParameterizedType: %v", p)

Expand Down Expand Up @@ -2482,6 +2503,34 @@ var typedArgParser = func(p *parser) (Expr, *parseError) {
}, nil
}

// Special argument parser for EXTRACT.
//
// It parses the argument list of EXTRACT, i.e. `part FROM expr`, optionally
// followed by `AT TIME ZONE zone`, and returns it as a single ExtractExpr.
// When a time zone is given, the operand is wrapped in an AtTimeZoneExpr of
// type TIMESTAMP whose Zone is the string value of the token after ZONE.
var extractArgParser = func(p *parser) (Expr, *parseError) {
	partType, part, err := p.parseExtractType()
	if err != nil {
		return nil, err
	}
	if err := p.expect("FROM"); err != nil {
		return nil, err
	}
	e, err := p.parseExpr()
	if err != nil {
		return nil, err
	}
	// AT TIME ZONE is optional
	if p.eat("AT", "TIME", "ZONE") {
		tok := p.next()
		if tok.err != nil {
			// Return the token's own error; err is nil at this point, and
			// returning it would report success with a nil Expr.
			return nil, tok.err
		}
		e = AtTimeZoneExpr{Expr: e, Zone: tok.string, Type: Type{Base: Timestamp}}
	}
	return ExtractExpr{
		Part: part,
		Expr: e,
		Type: partType,
	}, nil
}

/*
Expressions
Expand Down
16 changes: 14 additions & 2 deletions spanner/spansql/parser_test.go
Expand Up @@ -340,6 +340,8 @@ func TestParseExpr(t *testing.T) {
{`STARTS_WITH(Bar, 'B')`, Func{Name: "STARTS_WITH", Args: []Expr{ID("Bar"), StringLiteral("B")}}},
{`CAST(Bar AS STRING)`, Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}},
{`SAFE_CAST(Bar AS INT64)`, Func{Name: "SAFE_CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: Int64}}}}},
{`EXTRACT(DATE FROM TIMESTAMP AT TIME ZONE "America/Los_Angeles")`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DATE", Type: Type{Base: Date}, Expr: AtTimeZoneExpr{Expr: ID("TIMESTAMP"), Zone: "America/Los_Angeles", Type: Type{Base: Timestamp}}}}}},
{`EXTRACT(DAY FROM DATE)`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DAY", Expr: ID("DATE"), Type: Type{Base: Int64}}}}},

// String literal:
// Accept double quote and single quote.
Expand Down Expand Up @@ -524,7 +526,9 @@ func TestParseDDL(t *testing.T) {
CREATE TABLE users (
user_id STRING(36) NOT NULL,
some_string STRING(16) NOT NULL,
some_time TIMESTAMP NOT NULL,
number_key INT64 AS (SAFE_CAST(SUBSTR(some_string, 2) AS INT64)) STORED,
generated_date DATE AS (EXTRACT(DATE FROM some_time AT TIME ZONE "CET")) STORED,
) PRIMARY KEY(user_id);
-- Trailing comment at end of file.
Expand Down Expand Up @@ -744,12 +748,20 @@ func TestParseDDL(t *testing.T) {
Columns: []ColumnDef{
{Name: "user_id", Type: Type{Base: String, Len: 36}, NotNull: true, Position: line(67)},
{Name: "some_string", Type: Type{Base: String, Len: 16}, NotNull: true, Position: line(68)},
{Name: "some_time", Type: Type{Base: Timestamp}, NotNull: true, Position: line(69)},
{
Name: "number_key", Type: Type{Base: Int64},
Generated: Func{Name: "SAFE_CAST", Args: []Expr{
TypedExpr{Expr: Func{Name: "SUBSTR", Args: []Expr{ID("some_string"), IntegerLiteral(2)}}, Type: Type{Base: Int64}},
}},
Position: line(69),
Position: line(70),
},
{
Name: "generated_date", Type: Type{Base: Date},
Generated: Func{Name: "EXTRACT", Args: []Expr{
ExtractExpr{Part: "DATE", Type: Type{Base: Date}, Expr: AtTimeZoneExpr{Expr: ID("some_time"), Zone: "CET", Type: Type{Base: Timestamp}}},
}},
Position: line(71),
},
},
PrimaryKey: []KeyPart{{Column: "user_id"}},
Expand Down Expand Up @@ -777,7 +789,7 @@ func TestParseDDL(t *testing.T) {
{Marker: "--", Isolated: true, Start: line(49), End: line(49), Text: []string{"Table with row deletion policy."}},

// Comment after everything else.
{Marker: "--", Isolated: true, Start: line(72), End: line(72), Text: []string{"Trailing comment at end of file."}},
{Marker: "--", Isolated: true, Start: line(74), End: line(74), Text: []string{"Trailing comment at end of file."}},
}}},
// No trailing comma:
{`ALTER TABLE T ADD COLUMN C2 INT64`, &DDL{Filename: "filename", List: []DDLStmt{
Expand Down
14 changes: 14 additions & 0 deletions spanner/spansql/sql.go
Expand Up @@ -589,6 +589,20 @@ func (te TypedExpr) addSQL(sb *strings.Builder) {
sb.WriteString(te.Type.SQL())
}

// SQL returns the SQL form of the EXTRACT argument list, as built by addSQL.
func (ee ExtractExpr) SQL() string {
	return buildSQL(ee)
}

// addSQL writes `<part> FROM <expr>` to sb. The function name and the
// surrounding parentheses are not written here.
func (ee ExtractExpr) addSQL(sb *strings.Builder) {
	sb.WriteString(ee.Part + " FROM ")
	ee.Expr.addSQL(sb)
}

// SQL returns the SQL form of the expression, as built by addSQL.
func (aze AtTimeZoneExpr) SQL() string { return buildSQL(aze) }

// addSQL writes `<expr> AT TIME ZONE <zone>` to sb.
//
// Zone holds the unquoted value of the string literal that the parser read,
// so it is written back as a string literal. Writing it bare would produce
// SQL such as `AT TIME ZONE America/Los_Angeles`, which cannot be parsed again.
func (aze AtTimeZoneExpr) addSQL(sb *strings.Builder) {
	aze.Expr.addSQL(sb)
	sb.WriteString(" AT TIME ZONE ")
	StringLiteral(aze.Zone).addSQL(sb)
}

func idList(l []ID, join string) string {
var ss []string
for _, s := range l {
Expand Down
18 changes: 18 additions & 0 deletions spanner/spansql/types.go
Expand Up @@ -647,6 +647,24 @@ type TypedExpr struct {
func (TypedExpr) isBoolExpr() {} // possibly bool
func (TypedExpr) isExpr() {}

// ExtractExpr is the argument of an EXTRACT call: the part to extract, the
// type that part produces, and the expression it is extracted from.
type ExtractExpr struct {
// Part is the name of the part to extract, e.g. "DAY" or "DATE".
Part string
// Type is the type produced by extracting Part.
Type Type
// Expr is the operand the part is taken from.
Expr Expr
}

// NOTE(review): copied from TypedExpr; the parts shown in this change yield
// INT64 or DATE, so "possibly bool" may not apply here — confirm.
func (ExtractExpr) isBoolExpr() {} // possibly bool
func (ExtractExpr) isExpr() {}

// AtTimeZoneExpr is an expression followed by an AT TIME ZONE clause.
type AtTimeZoneExpr struct {
// Expr is the operand the time zone is applied to.
Expr Expr
// Type is the type of the expression.
Type Type
// Zone is the time zone name, without surrounding quotes.
Zone string
}

// NOTE(review): copied from TypedExpr; it is unclear that this expression can
// ever be boolean — confirm.
func (AtTimeZoneExpr) isBoolExpr() {} // possibly bool
func (AtTimeZoneExpr) isExpr() {}

// Paren represents a parenthesised expression.
type Paren struct {
Expr Expr
Expand Down

0 comments on commit 81b7c85

Please sign in to comment.