From 81b7c85a8993a36557ea4eb4ec0c47d1f93c4960 Mon Sep 17 00:00:00 2001 From: rahul2393 Date: Thu, 16 Dec 2021 10:32:27 +0530 Subject: [PATCH] feat(spanner/spansql): support EXTRACT (#5218) * feat(spanner/spansql): support EXTRACT * added separate Expr for Extract func and added unit and integration tests * add test for year * replace atTimeZone func with atTimeZone expression * fixing failing tests * added negative test, reduced the valid extract part values. * remove extra space Co-authored-by: Rahul Yadav --- spanner/spannertest/db_eval.go | 59 +++++++++++++++++++++++++ spanner/spannertest/funcs.go | 26 +++++++++++ spanner/spannertest/integration_test.go | 39 ++++++++++------ spanner/spansql/keywords.go | 3 +- spanner/spansql/parser.go | 49 ++++++++++++++++++++ spanner/spansql/parser_test.go | 16 ++++++- spanner/spansql/sql.go | 14 ++++++ spanner/spansql/types.go | 18 ++++++++ 8 files changed, 208 insertions(+), 16 deletions(-) diff --git a/spanner/spannertest/db_eval.go b/spanner/spannertest/db_eval.go index 41d5094fe82..3cf20b156ce 100644 --- a/spanner/spannertest/db_eval.go +++ b/spanner/spannertest/db_eval.go @@ -469,6 +469,10 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) { return ec.evalExpr(e.Expr) case spansql.TypedExpr: return ec.evalTypedExpr(e) + case spansql.ExtractExpr: + return ec.evalExtractExpr(e) + case spansql.AtTimeZoneExpr: + return ec.evalAtTimeZoneExpr(e) case spansql.Func: v, _, err := ec.evalFunc(e) if err != nil { @@ -675,6 +679,61 @@ func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (result interface{}, return convert(val, expr.Type) } +func (ec evalContext) evalExtractExpr(expr spansql.ExtractExpr) (result interface{}, err error) { + val, err := ec.evalExpr(expr.Expr) + if err != nil { + return nil, err + } + switch expr.Part { + case "DATE": + switch v := val.(type) { + case time.Time: + return civil.DateOf(v), nil + case civil.Date: + return v, nil + } + case "DAY": + switch v := val.(type) { + case 
time.Time: + return int64(v.Day()), nil + case civil.Date: + return int64(v.Day), nil + } + case "MONTH": + switch v := val.(type) { + case time.Time: + return int64(v.Month()), nil + case civil.Date: + return int64(v.Month), nil + } + case "YEAR": + switch v := val.(type) { + case time.Time: + return int64(v.Year()), nil + case civil.Date: + return int64(v.Year), nil + } + } + return nil, fmt.Errorf("Extract with part %v not supported", expr.Part) +} + +func (ec evalContext) evalAtTimeZoneExpr(expr spansql.AtTimeZoneExpr) (result interface{}, err error) { + val, err := ec.evalExpr(expr.Expr) + if err != nil { + return nil, err + } + switch v := val.(type) { + case time.Time: + loc, err := time.LoadLocation(expr.Zone) + if err != nil { + return nil, fmt.Errorf("AtTimeZone with %T not supported", v) + } + return v.In(loc), nil + default: + return nil, fmt.Errorf("AtTimeZone with %T not supported", val) + } +} + func evalLiteralOrParam(lop spansql.LiteralOrParam, params queryParams) (int64, error) { switch v := lop.(type) { case spansql.IntegerLiteral: diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go index 3c451a8b7cd..d92b130fa26 100644 --- a/spanner/spannertest/funcs.go +++ b/spanner/spannertest/funcs.go @@ -107,6 +107,32 @@ var functions = map[string]function{ return "", spansql.Type{Base: spansql.String}, nil }, }, + "EXTRACT": { + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { + date, okArg1 := values[0].(civil.Date) + part, okArg2 := values[0].(int64) + if !(okArg1 || okArg2) { + return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function EXTRACT for the given argument types") + } + if okArg1 { + return date, spansql.Type{Base: spansql.Date}, nil + } + return part, spansql.Type{Base: spansql.Int64}, nil + }, + }, + "TIMESTAMP": { + Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) { + t, okArg1 := 
values[0].(string) + if !(okArg1) { + return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types") + } + timestamp, err := time.Parse(time.RFC3339, t) + if err != nil { + return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types") + } + return timestamp, spansql.Type{Base: spansql.Timestamp}, nil + }, + }, } func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, spansql.Type, error) { diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index 4d962c5aaa0..52b87a03841 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -748,6 +748,13 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { } rows.Stop() + rows = client.Single().Query(ctx, spanner.NewStatement("SELECT EXTRACT(INVALID_PART FROM TIMESTAMP('2008-12-25T05:30:00Z')")) + _, err = rows.Next() + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Errorf("error code mismatch for invalid part from EXTRACT\n Got: %v\nWant: %v", g, w) + } + rows.Stop() + // Do some complex queries. 
tests := []struct { q string @@ -755,9 +762,10 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { want [][]interface{} }{ { - `SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64)`, + + `SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64), EXTRACT(DATE FROM TIMESTAMP('2008-12-25T05:30:00Z') AT TIME ZONE 'Europe/Amsterdam'), EXTRACT(YEAR FROM TIMESTAMP('2008-12-25T05:30:00Z'))`, nil, - [][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil}}, + [][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil, civil.Date{Year: 2008, Month: 12, Day: 25}, int64(2008)}}, }, // Check handling of NULL values for the IS operator. // There was a bug that returned errors for some of these cases. @@ -1277,13 +1285,16 @@ func TestIntegration_GeneratedColumns(t *testing.T) { defer cancel() tableName := "SongWriters" - err := updateDDL(t, adminClient, `CREATE TABLE `+tableName+` ( Name STRING(50) NOT NULL, NumSongs INT64, + CreatedAT TIMESTAMP, + CreatedDate DATE, EstimatedSales INT64 NOT NULL, CanonicalName STRING(50) AS (LOWER(Name)) STORED, + GeneratedCreatedDate DATE AS (EXTRACT(DATE FROM CreatedAT AT TIME ZONE "CET")) STORED, + GeneratedCreatedDay INT64 AS (EXTRACT(DAY FROM CreatedDate)) STORED, ) PRIMARY KEY (Name)`) if err != nil { t.Fatalf("Setting up fresh table: %v", err) @@ -1295,16 +1306,18 @@ func TestIntegration_GeneratedColumns(t *testing.T) { } // Insert some data. 
+ d1, _ := civil.ParseDate("2016-11-15") + t1, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z") _, err = client.Apply(ctx, []*spanner.Mutation{ spanner.Insert(tableName, - []string{"Name", "EstimatedSales", "NumSongs"}, - []interface{}{"Average Writer", 10, 10}), + []string{"Name", "EstimatedSales", "NumSongs", "CreatedAT", "CreatedDate"}, + []interface{}{"Average Writer", 10, 10, t1, d1}), spanner.Insert(tableName, - []string{"Name", "EstimatedSales"}, - []interface{}{"Great Writer", 100}), + []string{"Name", "EstimatedSales", "CreatedAT", "CreatedDate"}, + []interface{}{"Great Writer", 100, t1, d1}), spanner.Insert(tableName, - []string{"Name", "EstimatedSales", "NumSongs"}, - []interface{}{"Poor Writer", 1, 50}), + []string{"Name", "EstimatedSales", "NumSongs", "CreatedAT", "CreatedDate"}, + []interface{}{"Poor Writer", 1, 50, t1, d1}), }) if err != nil { t.Fatalf("Applying mutations: %v", err) @@ -1317,7 +1330,7 @@ func TestIntegration_GeneratedColumns(t *testing.T) { } ri := client.Single().Query(ctx, spanner.NewStatement( - `SELECT CanonicalName, TotalSales FROM `+tableName+` ORDER BY Name`, + `SELECT CanonicalName, TotalSales, GeneratedCreatedDate, GeneratedCreatedDay FROM `+tableName+` ORDER BY Name`, )) all, err := slurpRows(t, ri) if err != nil { @@ -1326,9 +1339,9 @@ func TestIntegration_GeneratedColumns(t *testing.T) { // Great writer has nil because NumSongs is nil want := [][]interface{}{ - {"average writer", int64(100)}, - {"great writer", nil}, - {"poor writer", int64(50)}, + {"average writer", int64(100), civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)}, + {"great writer", nil, civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)}, + {"poor writer", int64(50), civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)}, } if !reflect.DeepEqual(all, want) { t.Errorf("Expected values are wrong.\n got %v\nwant %v", all, want) diff --git a/spanner/spansql/keywords.go b/spanner/spansql/keywords.go index 0ddade0b1b3..d005f81dca7 
100644 --- a/spanner/spansql/keywords.go +++ b/spanner/spansql/keywords.go @@ -134,9 +134,10 @@ func init() { for _, f := range allFuncs { funcs[f] = true } - // Special case for CAST and SAFE_CAST + // Special case for CAST, SAFE_CAST and EXTRACT funcArgParsers["CAST"] = typedArgParser funcArgParsers["SAFE_CAST"] = typedArgParser + funcArgParsers["EXTRACT"] = extractArgParser } var allFuncs = []string{ diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index f4467a306ae..f540fa04800 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -1901,6 +1901,27 @@ func (p *parser) parseType() (Type, *parseError) { return p.parseBaseOrParameterizedType(true) } +var extractPartTypes = map[string]TypeBase{ + "DAY": Int64, + "MONTH": Int64, + "YEAR": Int64, + "DATE": Date, +} + +func (p *parser) parseExtractType() (Type, string, *parseError) { + var t Type + tok := p.next() + if tok.err != nil { + return Type{}, "", tok.err + } + base, ok := extractPartTypes[strings.ToUpper(tok.value)] // valid part types for EXTRACT is keyed by upper case strings. 
+ if !ok { + return Type{}, "", p.errorf("got %q, want valid EXTRACT types", tok.value) + } + t.Base = base + return t, strings.ToUpper(tok.value), nil +} + func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError) { debugf("parseBaseOrParameterizedType: %v", p) @@ -2482,6 +2503,34 @@ var typedArgParser = func(p *parser) (Expr, *parseError) { }, nil } +// Special argument parser for EXTRACT +var extractArgParser = func(p *parser) (Expr, *parseError) { + partType, part, err := p.parseExtractType() + if err != nil { + return nil, err + } + if err := p.expect("FROM"); err != nil { + return nil, err + } + e, err := p.parseExpr() + if err != nil { + return nil, err + } + // AT TIME ZONE is optional + if p.eat("AT", "TIME", "ZONE") { + tok := p.next() + if tok.err != nil { + return nil, err + } + return ExtractExpr{Part: part, Type: partType, Expr: AtTimeZoneExpr{Expr: e, Zone: tok.string, Type: Type{Base: Timestamp}}}, nil + } + return ExtractExpr{ + Part: part, + Expr: e, + Type: partType, + }, nil +} + /* Expressions diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index 429355dbec6..2e404f6d443 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -340,6 +340,8 @@ func TestParseExpr(t *testing.T) { {`STARTS_WITH(Bar, 'B')`, Func{Name: "STARTS_WITH", Args: []Expr{ID("Bar"), StringLiteral("B")}}}, {`CAST(Bar AS STRING)`, Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}}, {`SAFE_CAST(Bar AS INT64)`, Func{Name: "SAFE_CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: Int64}}}}}, + {`EXTRACT(DATE FROM TIMESTAMP AT TIME ZONE "America/Los_Angeles")`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DATE", Type: Type{Base: Date}, Expr: AtTimeZoneExpr{Expr: ID("TIMESTAMP"), Zone: "America/Los_Angeles", Type: Type{Base: Timestamp}}}}}}, + {`EXTRACT(DAY FROM DATE)`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DAY", Expr: ID("DATE"), 
Type: Type{Base: Int64}}}}}, // String literal: // Accept double quote and single quote. @@ -524,7 +526,9 @@ func TestParseDDL(t *testing.T) { CREATE TABLE users ( user_id STRING(36) NOT NULL, some_string STRING(16) NOT NULL, + some_time TIMESTAMP NOT NULL, number_key INT64 AS (SAFE_CAST(SUBSTR(some_string, 2) AS INT64)) STORED, + generated_date DATE AS (EXTRACT(DATE FROM some_time AT TIME ZONE "CET")) STORED, ) PRIMARY KEY(user_id); -- Trailing comment at end of file. @@ -744,12 +748,20 @@ func TestParseDDL(t *testing.T) { Columns: []ColumnDef{ {Name: "user_id", Type: Type{Base: String, Len: 36}, NotNull: true, Position: line(67)}, {Name: "some_string", Type: Type{Base: String, Len: 16}, NotNull: true, Position: line(68)}, + {Name: "some_time", Type: Type{Base: Timestamp}, NotNull: true, Position: line(69)}, { Name: "number_key", Type: Type{Base: Int64}, Generated: Func{Name: "SAFE_CAST", Args: []Expr{ TypedExpr{Expr: Func{Name: "SUBSTR", Args: []Expr{ID("some_string"), IntegerLiteral(2)}}, Type: Type{Base: Int64}}, }}, - Position: line(69), + Position: line(70), + }, + { + Name: "generated_date", Type: Type{Base: Date}, + Generated: Func{Name: "EXTRACT", Args: []Expr{ + ExtractExpr{Part: "DATE", Type: Type{Base: Date}, Expr: AtTimeZoneExpr{Expr: ID("some_time"), Zone: "CET", Type: Type{Base: Timestamp}}}, + }}, + Position: line(71), }, }, PrimaryKey: []KeyPart{{Column: "user_id"}}, @@ -777,7 +789,7 @@ func TestParseDDL(t *testing.T) { {Marker: "--", Isolated: true, Start: line(49), End: line(49), Text: []string{"Table with row deletion policy."}}, // Comment after everything else. 
- {Marker: "--", Isolated: true, Start: line(72), End: line(72), Text: []string{"Trailing comment at end of file."}}, + {Marker: "--", Isolated: true, Start: line(74), End: line(74), Text: []string{"Trailing comment at end of file."}}, }}}, // No trailing comma: {`ALTER TABLE T ADD COLUMN C2 INT64`, &DDL{Filename: "filename", List: []DDLStmt{ diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index 29c5a5dac54..ef877369e6f 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -589,6 +589,20 @@ func (te TypedExpr) addSQL(sb *strings.Builder) { sb.WriteString(te.Type.SQL()) } +func (ee ExtractExpr) SQL() string { return buildSQL(ee) } +func (ee ExtractExpr) addSQL(sb *strings.Builder) { + sb.WriteString(ee.Part) + sb.WriteString(" FROM ") + ee.Expr.addSQL(sb) +} + +func (aze AtTimeZoneExpr) SQL() string { return buildSQL(aze) } +func (aze AtTimeZoneExpr) addSQL(sb *strings.Builder) { + aze.Expr.addSQL(sb) + sb.WriteString(" AT TIME ZONE ") + sb.WriteString(aze.Zone) +} + func idList(l []ID, join string) string { var ss []string for _, s := range l { diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index 074e07cf0f1..9d099cc3505 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -647,6 +647,24 @@ type TypedExpr struct { func (TypedExpr) isBoolExpr() {} // possibly bool func (TypedExpr) isExpr() {} +type ExtractExpr struct { + Part string + Type Type + Expr Expr +} + +func (ExtractExpr) isBoolExpr() {} // possibly bool +func (ExtractExpr) isExpr() {} + +type AtTimeZoneExpr struct { + Expr Expr + Type Type + Zone string +} + +func (AtTimeZoneExpr) isBoolExpr() {} // possibly bool +func (AtTimeZoneExpr) isExpr() {} + // Paren represents a parenthesised expression. type Paren struct { Expr Expr