Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spanner/spansql): support EXTRACT #5218

Merged
merged 7 commits into from Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
59 changes: 59 additions & 0 deletions spanner/spannertest/db_eval.go
Expand Up @@ -469,6 +469,10 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) {
return ec.evalExpr(e.Expr)
case spansql.TypedExpr:
return ec.evalTypedExpr(e)
case spansql.ExtractExpr:
return ec.evalExtractExpr(e)
case spansql.AtTimeZoneExpr:
return ec.evalAtTimeZoneExpr(e)
case spansql.Func:
v, _, err := ec.evalFunc(e)
if err != nil {
Expand Down Expand Up @@ -675,6 +679,61 @@ func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (result interface{},
return convert(val, expr.Type)
}

func (ec evalContext) evalExtractExpr(expr spansql.ExtractExpr) (result interface{}, err error) {
val, err := ec.evalExpr(expr.Expr)
if err != nil {
return nil, err
}
switch expr.Part {
case "DATE":
switch v := val.(type) {
case time.Time:
return civil.DateOf(v), nil
case civil.Date:
return v, nil
}
case "DAY":
switch v := val.(type) {
case time.Time:
return int64(v.Day()), nil
case civil.Date:
return int64(v.Day), nil
}
case "MONTH":
switch v := val.(type) {
case time.Time:
return int64(v.Month()), nil
case civil.Date:
return int64(v.Month), nil
}
case "YEAR":
switch v := val.(type) {
case time.Time:
return int64(v.Year()), nil
case civil.Date:
return int64(v.Year), nil
}
}
return nil, fmt.Errorf("Extract with part %v not supported", expr.Part)
}

func (ec evalContext) evalAtTimeZoneExpr(expr spansql.AtTimeZoneExpr) (result interface{}, err error) {
val, err := ec.evalExpr(expr.Expr)
if err != nil {
return nil, err
}
switch v := val.(type) {
case time.Time:
loc, err := time.LoadLocation(expr.Zone)
if err != nil {
return nil, fmt.Errorf("AtTimeZone with %T not supported", v)
}
return v.In(loc), nil
default:
return nil, fmt.Errorf("AtTimeZone with %T not supported", val)
}
}

func evalLiteralOrParam(lop spansql.LiteralOrParam, params queryParams) (int64, error) {
switch v := lop.(type) {
case spansql.IntegerLiteral:
Expand Down
26 changes: 26 additions & 0 deletions spanner/spannertest/funcs.go
Expand Up @@ -107,6 +107,32 @@ var functions = map[string]function{
return "", spansql.Type{Base: spansql.String}, nil
},
},
"EXTRACT": {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
date, okArg1 := values[0].(civil.Date)
part, okArg2 := values[0].(int64)
if !(okArg1 || okArg2) {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function EXTRACT for the given argument types")
}
if okArg1 {
return date, spansql.Type{Base: spansql.Date}, nil
}
return part, spansql.Type{Base: spansql.Int64}, nil
},
},
"TIMESTAMP": {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
t, okArg1 := values[0].(string)
if !(okArg1) {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types")
}
timestamp, err := time.Parse(time.RFC3339, t)
if err != nil {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types")
}
return timestamp, spansql.Type{Base: spansql.Timestamp}, nil
},
},
}

func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, spansql.Type, error) {
Expand Down
39 changes: 26 additions & 13 deletions spanner/spannertest/integration_test.go
Expand Up @@ -748,16 +748,24 @@ func TestIntegration_ReadsAndQueries(t *testing.T) {
}
rows.Stop()

rows = client.Single().Query(ctx, spanner.NewStatement("SELECT EXTRACT(INVALID_PART FROM TIMESTAMP('2008-12-25T05:30:00Z')"))
_, err = rows.Next()
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Errorf("error code mismatch for invalid part from EXTRACT\n Got: %v\nWant: %v", g, w)
}
rows.Stop()

// Do some complex queries.
tests := []struct {
q string
params map[string]interface{}
want [][]interface{}
}{
{
`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64)`,

`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64), EXTRACT(DATE FROM TIMESTAMP('2008-12-25T05:30:00Z') AT TIME ZONE 'Europe/Amsterdam'), EXTRACT(YEAR FROM TIMESTAMP('2008-12-25T05:30:00Z'))`,
nil,
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil}},
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil, civil.Date{Year: 2008, Month: 12, Day: 25}, int64(2008)}},
},
// Check handling of NULL values for the IS operator.
// There was a bug that returned errors for some of these cases.
Expand Down Expand Up @@ -1277,13 +1285,16 @@ func TestIntegration_GeneratedColumns(t *testing.T) {
defer cancel()

tableName := "SongWriters"

err := updateDDL(t, adminClient,
`CREATE TABLE `+tableName+` (
Name STRING(50) NOT NULL,
NumSongs INT64,
CreatedAT TIMESTAMP,
CreatedDate DATE,
EstimatedSales INT64 NOT NULL,
CanonicalName STRING(50) AS (LOWER(Name)) STORED,
GeneratedCreatedDate DATE AS (EXTRACT(DATE FROM CreatedAT AT TIME ZONE "CET")) STORED,
GeneratedCreatedDay INT64 AS (EXTRACT(DAY FROM CreatedDate)) STORED,
) PRIMARY KEY (Name)`)
if err != nil {
t.Fatalf("Setting up fresh table: %v", err)
Expand All @@ -1295,16 +1306,18 @@ func TestIntegration_GeneratedColumns(t *testing.T) {
}

// Insert some data.
d1, _ := civil.ParseDate("2016-11-15")
t1, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z")
_, err = client.Apply(ctx, []*spanner.Mutation{
spanner.Insert(tableName,
[]string{"Name", "EstimatedSales", "NumSongs"},
[]interface{}{"Average Writer", 10, 10}),
[]string{"Name", "EstimatedSales", "NumSongs", "CreatedAT", "CreatedDate"},
[]interface{}{"Average Writer", 10, 10, t1, d1}),
spanner.Insert(tableName,
[]string{"Name", "EstimatedSales"},
[]interface{}{"Great Writer", 100}),
[]string{"Name", "EstimatedSales", "CreatedAT", "CreatedDate"},
[]interface{}{"Great Writer", 100, t1, d1}),
spanner.Insert(tableName,
[]string{"Name", "EstimatedSales", "NumSongs"},
[]interface{}{"Poor Writer", 1, 50}),
[]string{"Name", "EstimatedSales", "NumSongs", "CreatedAT", "CreatedDate"},
[]interface{}{"Poor Writer", 1, 50, t1, d1}),
})
if err != nil {
t.Fatalf("Applying mutations: %v", err)
Expand All @@ -1317,7 +1330,7 @@ func TestIntegration_GeneratedColumns(t *testing.T) {
}

ri := client.Single().Query(ctx, spanner.NewStatement(
`SELECT CanonicalName, TotalSales FROM `+tableName+` ORDER BY Name`,
`SELECT CanonicalName, TotalSales, GeneratedCreatedDate, GeneratedCreatedDay FROM `+tableName+` ORDER BY Name`,
))
all, err := slurpRows(t, ri)
if err != nil {
Expand All @@ -1326,9 +1339,9 @@ func TestIntegration_GeneratedColumns(t *testing.T) {

// Great writer has nil because NumSongs is nil
want := [][]interface{}{
{"average writer", int64(100)},
{"great writer", nil},
{"poor writer", int64(50)},
{"average writer", int64(100), civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)},
{"great writer", nil, civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)},
{"poor writer", int64(50), civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)},
}
if !reflect.DeepEqual(all, want) {
t.Errorf("Expected values are wrong.\n got %v\nwant %v", all, want)
Expand Down
3 changes: 2 additions & 1 deletion spanner/spansql/keywords.go
Expand Up @@ -134,9 +134,10 @@ func init() {
for _, f := range allFuncs {
funcs[f] = true
}
// Special case for CAST and SAFE_CAST
// Special case for CAST, SAFE_CAST and EXTRACT
funcArgParsers["CAST"] = typedArgParser
funcArgParsers["SAFE_CAST"] = typedArgParser
funcArgParsers["EXTRACT"] = extractArgParser
}

var allFuncs = []string{
Expand Down
49 changes: 49 additions & 0 deletions spanner/spansql/parser.go
Expand Up @@ -1901,6 +1901,27 @@ func (p *parser) parseType() (Type, *parseError) {
return p.parseBaseOrParameterizedType(true)
}

var extractPartTypes = map[string]TypeBase{
"DAY": Int64,
"MONTH": Int64,
"YEAR": Int64,
"DATE": Date,
}

func (p *parser) parseExtractType() (Type, string, *parseError) {
var t Type
tok := p.next()
if tok.err != nil {
return Type{}, "", tok.err
}
base, ok := extractPartTypes[strings.ToUpper(tok.value)] // valid part types for EXTRACT is keyed by upper case strings.
if !ok {
return Type{}, "", p.errorf("got %q, want valid EXTRACT types", tok.value)
}
t.Base = base
return t, strings.ToUpper(tok.value), nil
}

func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError) {
debugf("parseBaseOrParameterizedType: %v", p)

Expand Down Expand Up @@ -2482,6 +2503,34 @@ var typedArgParser = func(p *parser) (Expr, *parseError) {
}, nil
}

// Special argument parser for EXTRACT
var extractArgParser = func(p *parser) (Expr, *parseError) {
partType, part, err := p.parseExtractType()
if err != nil {
return nil, err
}
if err := p.expect("FROM"); err != nil {
return nil, err
}
e, err := p.parseExpr()
if err != nil {
return nil, err
}
// AT TIME ZONE is optional
if p.eat("AT", "TIME", "ZONE") {
tok := p.next()
if tok.err != nil {
return nil, err
}
return ExtractExpr{Part: part, Type: partType, Expr: AtTimeZoneExpr{Expr: e, Zone: tok.string, Type: Type{Base: Timestamp}}}, nil
}
return ExtractExpr{
Part: part,
Expr: e,
Type: partType,
}, nil
}

/*
Expressions

Expand Down
16 changes: 14 additions & 2 deletions spanner/spansql/parser_test.go
Expand Up @@ -340,6 +340,8 @@ func TestParseExpr(t *testing.T) {
{`STARTS_WITH(Bar, 'B')`, Func{Name: "STARTS_WITH", Args: []Expr{ID("Bar"), StringLiteral("B")}}},
{`CAST(Bar AS STRING)`, Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}},
{`SAFE_CAST(Bar AS INT64)`, Func{Name: "SAFE_CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: Int64}}}}},
{`EXTRACT(DATE FROM TIMESTAMP AT TIME ZONE "America/Los_Angeles")`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DATE", Type: Type{Base: Date}, Expr: AtTimeZoneExpr{Expr: ID("TIMESTAMP"), Zone: "America/Los_Angeles", Type: Type{Base: Timestamp}}}}}},
{`EXTRACT(DAY FROM DATE)`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DAY", Expr: ID("DATE"), Type: Type{Base: Int64}}}}},

// String literal:
// Accept double quote and single quote.
Expand Down Expand Up @@ -524,7 +526,9 @@ func TestParseDDL(t *testing.T) {
CREATE TABLE users (
user_id STRING(36) NOT NULL,
some_string STRING(16) NOT NULL,
some_time TIMESTAMP NOT NULL,
number_key INT64 AS (SAFE_CAST(SUBSTR(some_string, 2) AS INT64)) STORED,
generated_date DATE AS (EXTRACT(DATE FROM some_time AT TIME ZONE "CET")) STORED,
) PRIMARY KEY(user_id);

-- Trailing comment at end of file.
Expand Down Expand Up @@ -744,12 +748,20 @@ func TestParseDDL(t *testing.T) {
Columns: []ColumnDef{
{Name: "user_id", Type: Type{Base: String, Len: 36}, NotNull: true, Position: line(67)},
{Name: "some_string", Type: Type{Base: String, Len: 16}, NotNull: true, Position: line(68)},
{Name: "some_time", Type: Type{Base: Timestamp}, NotNull: true, Position: line(69)},
{
Name: "number_key", Type: Type{Base: Int64},
Generated: Func{Name: "SAFE_CAST", Args: []Expr{
TypedExpr{Expr: Func{Name: "SUBSTR", Args: []Expr{ID("some_string"), IntegerLiteral(2)}}, Type: Type{Base: Int64}},
}},
Position: line(69),
Position: line(70),
},
{
Name: "generated_date", Type: Type{Base: Date},
Generated: Func{Name: "EXTRACT", Args: []Expr{
ExtractExpr{Part: "DATE", Type: Type{Base: Date}, Expr: AtTimeZoneExpr{Expr: ID("some_time"), Zone: "CET", Type: Type{Base: Timestamp}}},
}},
Position: line(71),
},
},
PrimaryKey: []KeyPart{{Column: "user_id"}},
Expand Down Expand Up @@ -777,7 +789,7 @@ func TestParseDDL(t *testing.T) {
{Marker: "--", Isolated: true, Start: line(49), End: line(49), Text: []string{"Table with row deletion policy."}},

// Comment after everything else.
{Marker: "--", Isolated: true, Start: line(72), End: line(72), Text: []string{"Trailing comment at end of file."}},
{Marker: "--", Isolated: true, Start: line(74), End: line(74), Text: []string{"Trailing comment at end of file."}},
}}},
// No trailing comma:
{`ALTER TABLE T ADD COLUMN C2 INT64`, &DDL{Filename: "filename", List: []DDLStmt{
Expand Down
14 changes: 14 additions & 0 deletions spanner/spansql/sql.go
Expand Up @@ -589,6 +589,20 @@ func (te TypedExpr) addSQL(sb *strings.Builder) {
sb.WriteString(te.Type.SQL())
}

func (ee ExtractExpr) SQL() string { return buildSQL(ee) }
func (ee ExtractExpr) addSQL(sb *strings.Builder) {
sb.WriteString(ee.Part)
sb.WriteString(" FROM ")
ee.Expr.addSQL(sb)
}

func (aze AtTimeZoneExpr) SQL() string { return buildSQL(aze) }
func (aze AtTimeZoneExpr) addSQL(sb *strings.Builder) {
aze.Expr.addSQL(sb)
sb.WriteString(" AT TIME ZONE ")
sb.WriteString(aze.Zone)
}

func idList(l []ID, join string) string {
var ss []string
for _, s := range l {
Expand Down
18 changes: 18 additions & 0 deletions spanner/spansql/types.go
Expand Up @@ -647,6 +647,24 @@ type TypedExpr struct {
func (TypedExpr) isBoolExpr() {} // possibly bool
func (TypedExpr) isExpr() {}

type ExtractExpr struct {
Part string
Type Type
Expr Expr
}

func (ExtractExpr) isBoolExpr() {} // possibly bool
func (ExtractExpr) isExpr() {}

type AtTimeZoneExpr struct {
Expr Expr
Type Type
Zone string
}

func (AtTimeZoneExpr) isBoolExpr() {} // possibly bool
func (AtTimeZoneExpr) isExpr() {}

// Paren represents a parenthesised expression.
type Paren struct {
Expr Expr
Expand Down