From 4788415c908f58c1cc08c951f1a7f17cdaf35aa2 Mon Sep 17 00:00:00 2001 From: Ryoya Sekino Date: Wed, 2 Dec 2020 13:30:03 +0900 Subject: [PATCH] feat(spanner/spannertest): support AVG aggregation function (#3286) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #3285 Sorry for creating PR without waiting for the response to the above issue 😅 --- spanner/spannertest/funcs.go | 37 +++++++++++++++++++++++++ spanner/spannertest/integration_test.go | 7 +++++ spanner/spansql/keywords.go | 1 + 3 files changed, 45 insertions(+) diff --git a/spanner/spannertest/funcs.go b/spanner/spannertest/funcs.go index cd8aafc253e..3cfa53211e4 100644 --- a/spanner/spannertest/funcs.go +++ b/spanner/spannertest/funcs.go @@ -109,6 +109,43 @@ var aggregateFuncs = map[string]aggregateFunc{ return sum, typ, nil }, }, + "AVG": { + Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) { + if typ.Array || !(typ.Base == spansql.Int64 || typ.Base == spansql.Float64) { + return nil, spansql.Type{}, fmt.Errorf("AVG only supports arguments of INT64 or FLOAT64 type, not %s", typ.SQL()) + } + if typ.Base == spansql.Int64 { + var sum int64 + var n float64 + for _, v := range values { + if v == nil { + continue + } + sum += v.(int64) + n++ + } + if n == 0 { + // "Returns NULL if the input contains only NULLs". + return nil, typ, nil + } + return (float64(sum) / n), float64Type, nil + } + var sum float64 + var n float64 + for _, v := range values { + if v == nil { + continue + } + sum += v.(float64) + n++ + } + if n == 0 { + // "Returns NULL if the input contains only NULLs". + return nil, typ, nil + } + return (sum / n), typ, nil + }, + }, } func evalMinMax(name string, isMin bool, values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) { diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index 365dd78d0ca..7f452464d5c 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -889,6 +889,13 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { {int64(1), int64(25)}, // Jack(ID=1, Tenure=10), Sam(ID=3, Tenure=9), George(ID=5, Tenure=6) }, }, + { + `SELECT AVG(Height) FROM Staff WHERE ID <= 2`, + nil, + [][]interface{}{ + {float64(1.84)}, + }, + }, { `SELECT MAX(Name) FROM Staff WHERE Name < @lim`, map[string]interface{}{"lim": "Teal'c"}, diff --git a/spanner/spansql/keywords.go b/spanner/spansql/keywords.go index 64672d39568..7964385d2da 100644 --- a/spanner/spansql/keywords.go +++ b/spanner/spansql/keywords.go @@ -130,6 +130,7 @@ var keywords = map[string]bool{ var funcs = map[string]bool{ // Aggregate functions. "ARRAY_AGG": true, + "AVG": true, "BIT_XOR": true, "COUNT": true, "MAX": true,