From a781a3ad0c626fc0a7aff0ce33b1ef0830ee2259 Mon Sep 17 00:00:00 2001 From: tmdiep Date: Fri, 20 Nov 2020 08:08:27 +1100 Subject: [PATCH 1/6] fix(pubsub): Remove NewMessage and NewPublishResult (#3232) Rolls back functions that we have decided not to expose to users. --- pubsub/message.go | 6 ------ pubsub/topic.go | 7 ------- 2 files changed, 13 deletions(-) diff --git a/pubsub/message.go b/pubsub/message.go index 54be363eefa..c09a7e51c40 100644 --- a/pubsub/message.go +++ b/pubsub/message.go @@ -61,12 +61,6 @@ type Message struct { ackh ackHandler } -// NewMessage creates a message with a custom ack/nack handler, which should not -// be nil. -func NewMessage(ackh ackHandler) *Message { - return &Message{ackh: ackh} -} - func toMessage(resp *pb.ReceivedMessage) (*Message, error) { if resp.Message == nil { return &Message{ackh: &psAckHandler{ackID: resp.AckId}}, nil diff --git a/pubsub/topic.go b/pubsub/topic.go index 1dd50c00291..07e392e0869 100644 --- a/pubsub/topic.go +++ b/pubsub/topic.go @@ -470,13 +470,6 @@ type PublishResult struct { err error } -// NewPublishResult returns the set() function to enable callers from outside -// this package to store and call it (e.g. unit tests). -func NewPublishResult() (*PublishResult, func(string, error)) { - result := &PublishResult{ready: make(chan struct{})} - return result, result.set -} - // Ready returns a channel that is closed when the result is ready. // When the Ready channel is closed, Get is guaranteed not to block. func (r *PublishResult) Ready() <-chan struct{} { return r.ready } From ea3cde55ad3d8d843bce8d023747cf69552850b5 Mon Sep 17 00:00:00 2001 From: shollyman Date: Thu, 19 Nov 2020 14:18:03 -0800 Subject: [PATCH 2/6] feat(bigquery): add support for bignumeric (#2779) feat(bigquery): add support for bignumeric This PR adds basic support for the BIGNUMERIC type in BigQuery. This library has several cases where it tries to infer the appropriate BigQuery type from a native Go type. 
For big.Rat types, we continue the existing behavior of mapping them to NUMERIC, as the big.Rat doesn't provide a general way of indicating desired precision or scale to determine whether BIGNUMERIC is a more appropriate mapping. --- bigquery/integration_test.go | 28 ++++---- bigquery/params.go | 41 +++++++----- bigquery/schema.go | 32 +++++---- bigquery/schema_test.go | 6 +- bigquery/value.go | 34 +++++++++- bigquery/value_test.go | 124 +++++++++++++++++++++-------------- go.sum | 2 + 7 files changed, 172 insertions(+), 95 deletions(-) diff --git a/bigquery/integration_test.go b/bigquery/integration_test.go index 14573e2bb3a..60959d0a0ca 100644 --- a/bigquery/integration_test.go +++ b/bigquery/integration_test.go @@ -1434,6 +1434,7 @@ func TestIntegration_InsertAndReadNullable(t *testing.T) { ctm := civil.Time{Hour: 15, Minute: 4, Second: 5, Nanosecond: 6000} cdt := civil.DateTime{Date: testDate, Time: ctm} rat := big.NewRat(33, 100) + rat2 := big.NewRat(66, 100) geo := "POINT(-122.198939 47.669865)" // Nil fields in the struct. @@ -1455,20 +1456,21 @@ func TestIntegration_InsertAndReadNullable(t *testing.T) { // Populate the struct with values. 
testInsertAndReadNullable(t, testStructNullable{ - String: NullString{"x", true}, - Bytes: []byte{1, 2, 3}, - Integer: NullInt64{1, true}, - Float: NullFloat64{2.3, true}, - Boolean: NullBool{true, true}, - Timestamp: NullTimestamp{testTimestamp, true}, - Date: NullDate{testDate, true}, - Time: NullTime{ctm, true}, - DateTime: NullDateTime{cdt, true}, - Numeric: rat, - Geography: NullGeography{geo, true}, - Record: &subNullable{X: NullInt64{4, true}}, + String: NullString{"x", true}, + Bytes: []byte{1, 2, 3}, + Integer: NullInt64{1, true}, + Float: NullFloat64{2.3, true}, + Boolean: NullBool{true, true}, + Timestamp: NullTimestamp{testTimestamp, true}, + Date: NullDate{testDate, true}, + Time: NullTime{ctm, true}, + DateTime: NullDateTime{cdt, true}, + Numeric: rat, + BigNumeric: rat2, + Geography: NullGeography{geo, true}, + Record: &subNullable{X: NullInt64{4, true}}, }, - []Value{"x", []byte{1, 2, 3}, int64(1), 2.3, true, testTimestamp, testDate, ctm, cdt, rat, geo, []Value{int64(4)}}) + []Value{"x", []byte{1, 2, 3}, int64(1), 2.3, true, testTimestamp, testDate, ctm, cdt, rat, rat2, geo, []Value{int64(4)}}) } func testInsertAndReadNullable(t *testing.T, ts testStructNullable, wantRow []Value) { diff --git a/bigquery/params.go b/bigquery/params.go index 59570226189..68bb3fa01ac 100644 --- a/bigquery/params.go +++ b/bigquery/params.go @@ -65,16 +65,17 @@ func (e invalidFieldNameError) Error() string { var fieldCache = fields.NewCache(bqTagParser, nil, nil) var ( - int64ParamType = &bq.QueryParameterType{Type: "INT64"} - float64ParamType = &bq.QueryParameterType{Type: "FLOAT64"} - boolParamType = &bq.QueryParameterType{Type: "BOOL"} - stringParamType = &bq.QueryParameterType{Type: "STRING"} - bytesParamType = &bq.QueryParameterType{Type: "BYTES"} - dateParamType = &bq.QueryParameterType{Type: "DATE"} - timeParamType = &bq.QueryParameterType{Type: "TIME"} - dateTimeParamType = &bq.QueryParameterType{Type: "DATETIME"} - timestampParamType = 
&bq.QueryParameterType{Type: "TIMESTAMP"} - numericParamType = &bq.QueryParameterType{Type: "NUMERIC"} + int64ParamType = &bq.QueryParameterType{Type: "INT64"} + float64ParamType = &bq.QueryParameterType{Type: "FLOAT64"} + boolParamType = &bq.QueryParameterType{Type: "BOOL"} + stringParamType = &bq.QueryParameterType{Type: "STRING"} + bytesParamType = &bq.QueryParameterType{Type: "BYTES"} + dateParamType = &bq.QueryParameterType{Type: "DATE"} + timeParamType = &bq.QueryParameterType{Type: "TIME"} + dateTimeParamType = &bq.QueryParameterType{Type: "DATETIME"} + timestampParamType = &bq.QueryParameterType{Type: "TIMESTAMP"} + numericParamType = &bq.QueryParameterType{Type: "NUMERIC"} + bigNumericParamType = &bq.QueryParameterType{Type: "BIGNUMERIC"} ) var ( @@ -233,6 +234,9 @@ func paramValue(v reflect.Value) (bq.QueryParameterValue, error) { return res, nil case typeOfRat: + // big.Rat types don't communicate scale or precision, so we cannot + // disambiguate between NUMERIC and BIGNUMERIC. For now, we'll continue + // to honor previous behavior and send as Numeric type. 
res.Value = NumericString(v.Interface().(*big.Rat)) return res, nil } @@ -304,14 +308,15 @@ func bqToQueryParameter(q *bq.QueryParameter) (QueryParameter, error) { } var paramTypeToFieldType = map[string]FieldType{ - int64ParamType.Type: IntegerFieldType, - float64ParamType.Type: FloatFieldType, - boolParamType.Type: BooleanFieldType, - stringParamType.Type: StringFieldType, - bytesParamType.Type: BytesFieldType, - dateParamType.Type: DateFieldType, - timeParamType.Type: TimeFieldType, - numericParamType.Type: NumericFieldType, + int64ParamType.Type: IntegerFieldType, + float64ParamType.Type: FloatFieldType, + boolParamType.Type: BooleanFieldType, + stringParamType.Type: StringFieldType, + bytesParamType.Type: BytesFieldType, + dateParamType.Type: DateFieldType, + timeParamType.Type: TimeFieldType, + numericParamType.Type: NumericFieldType, + bigNumericParamType.Type: BigNumericFieldType, } // Convert a parameter value from the service to a Go value. This is similar to, but diff --git a/bigquery/schema.go b/bigquery/schema.go index a9a50ef0006..dcb5c6f9b77 100644 --- a/bigquery/schema.go +++ b/bigquery/schema.go @@ -182,23 +182,27 @@ const ( // GeographyFieldType is a string field type. Geography types represent a set of points // on the Earth's surface, represented in Well Known Text (WKT) format. GeographyFieldType FieldType = "GEOGRAPHY" + // BigNumericFieldType is a numeric field type that supports values of larger precision + // and scale than the NumericFieldType. 
+ BigNumericFieldType FieldType = "BIGNUMERIC" ) var ( errEmptyJSONSchema = errors.New("bigquery: empty JSON schema") fieldTypes = map[FieldType]bool{ - StringFieldType: true, - BytesFieldType: true, - IntegerFieldType: true, - FloatFieldType: true, - BooleanFieldType: true, - TimestampFieldType: true, - RecordFieldType: true, - DateFieldType: true, - TimeFieldType: true, - DateTimeFieldType: true, - NumericFieldType: true, - GeographyFieldType: true, + StringFieldType: true, + BytesFieldType: true, + IntegerFieldType: true, + FloatFieldType: true, + BooleanFieldType: true, + TimestampFieldType: true, + RecordFieldType: true, + DateFieldType: true, + TimeFieldType: true, + DateTimeFieldType: true, + NumericFieldType: true, + GeographyFieldType: true, + BigNumericFieldType: true, } // The API will accept alias names for the types based on the Standard SQL type names. fieldAliases = map[FieldType]FieldType{ @@ -346,6 +350,10 @@ func inferFieldSchema(fieldName string, rt reflect.Type, nullable bool) (*FieldS case typeOfDateTime: return &FieldSchema{Required: true, Type: DateTimeFieldType}, nil case typeOfRat: + // We automatically infer big.Rat values as NUMERIC as we cannot + // determine precision/scale from the type. Users who want the + // larger precision of BIGNUMERIC need to manipulate the inferred + // schema. 
return &FieldSchema{Required: !nullable, Type: NumericFieldType}, nil } if ft := nullableFieldType(rt); ft != "" { diff --git a/bigquery/schema_test.go b/bigquery/schema_test.go index 5bd10849ac3..ae936980c78 100644 --- a/bigquery/schema_test.go +++ b/bigquery/schema_test.go @@ -1041,7 +1041,8 @@ func TestSchemaFromJSON(t *testing.T) { {"name":"flat_date","type":"DATE","mode":"NULLABLE","description":"Flat required DATE"}, {"name":"flat_time","type":"TIME","mode":"REQUIRED","description":"Flat nullable TIME"}, {"name":"flat_datetime","type":"DATETIME","mode":"NULLABLE","description":"Flat required DATETIME"}, - {"name":"flat_numeric","type":"NUMERIC","mode":"REQUIRED","description":"Flat nullable NUMERIC"}, + {"name":"flat_numeric","type":"NUMERIC","mode":"REQUIRED","description":"Flat required NUMERIC"}, + {"name":"flat_bignumeric","type":"BIGNUMERIC","mode":"NULLABLE","description":"Flat nullable BIGNUMERIC"}, {"name":"flat_geography","type":"GEOGRAPHY","mode":"REQUIRED","description":"Flat required GEOGRAPHY"}, {"name":"aliased_integer","type":"INT64","mode":"REQUIRED","description":"Aliased required integer"}, {"name":"aliased_boolean","type":"BOOL","mode":"NULLABLE","description":"Aliased nullable boolean"}, @@ -1058,7 +1059,8 @@ func TestSchemaFromJSON(t *testing.T) { fieldSchema("Flat required DATE", "flat_date", "DATE", false, false, nil), fieldSchema("Flat nullable TIME", "flat_time", "TIME", false, true, nil), fieldSchema("Flat required DATETIME", "flat_datetime", "DATETIME", false, false, nil), - fieldSchema("Flat nullable NUMERIC", "flat_numeric", "NUMERIC", false, true, nil), + fieldSchema("Flat required NUMERIC", "flat_numeric", "NUMERIC", false, true, nil), + fieldSchema("Flat nullable BIGNUMERIC", "flat_bignumeric", "BIGNUMERIC", false, false, nil), fieldSchema("Flat required GEOGRAPHY", "flat_geography", "GEOGRAPHY", false, true, nil), fieldSchema("Aliased required integer", "aliased_integer", "INTEGER", false, true, nil), fieldSchema("Aliased 
nullable boolean", "aliased_boolean", "BOOLEAN", false, false, nil), diff --git a/bigquery/value.go b/bigquery/value.go index 361341c2bb8..cae28be0464 100644 --- a/bigquery/value.go +++ b/bigquery/value.go @@ -407,6 +407,13 @@ func determineSetFunc(ftype reflect.Type, stype FieldType) setFunc { return setNull(v, x, func() interface{} { return x.(*big.Rat) }) } } + + case BigNumericFieldType: + if ftype == typeOfRat { + return func(v reflect.Value, x interface{}) error { + return setNull(v, x, func() interface{} { return x.(*big.Rat) }) + } + } } return nil } @@ -692,7 +699,7 @@ func structFieldToUploadValue(vfield reflect.Value, schemaField *FieldSchema) (i } func toUploadValue(val interface{}, fs *FieldSchema) interface{} { - if fs.Type == TimeFieldType || fs.Type == DateTimeFieldType || fs.Type == NumericFieldType { + if fs.Type == TimeFieldType || fs.Type == DateTimeFieldType || fs.Type == NumericFieldType || fs.Type == BigNumericFieldType { return toUploadValueReflect(reflect.ValueOf(val), fs) } return val @@ -721,6 +728,13 @@ func toUploadValueReflect(v reflect.Value, fs *FieldSchema) interface{} { return formatUploadValue(v, fs, func(v reflect.Value) string { return NumericString(v.Interface().(*big.Rat)) }) + case BigNumericFieldType: + if r, ok := v.Interface().(*big.Rat); ok && r == nil { + return nil + } + return formatUploadValue(v, fs, func(v reflect.Value) string { + return BigNumericString(v.Interface().(*big.Rat)) + }) default: if !fs.Repeated || v.Len() > 0 { return v.Interface() @@ -786,6 +800,12 @@ const ( // NumericScaleDigits is the maximum number of digits after the decimal point in a NUMERIC value. NumericScaleDigits = 9 + + // BigNumericPrecisionDigits is the maximum number of full digits in a BIGNUMERIC value. + BigNumericPrecisionDigits = 76 + + // BigNumericScaleDigits is the maximum number of full digits in a BIGNUMERIC value. 
+ BigNumericScaleDigits = 38 ) // NumericString returns a string representing a *big.Rat in a format compatible @@ -795,6 +815,12 @@ func NumericString(r *big.Rat) string { return r.FloatString(NumericScaleDigits) } +// BigNumericString returns a string representing a *big.Rat in a format compatible with BigQuery +// SQL. It returns a floating point literal with 38 digits after the decimal point. +func BigNumericString(r *big.Rat) string { + return r.FloatString(BigNumericScaleDigits) +} + // convertRows converts a series of TableRows into a series of Value slices. // schema is used to interpret the data from rows; its length must match the // length of each row. @@ -913,6 +939,12 @@ func convertBasicType(val string, typ FieldType) (Value, error) { return nil, fmt.Errorf("bigquery: invalid NUMERIC value %q", val) } return Value(r), nil + case BigNumericFieldType: + r, ok := (&big.Rat{}).SetString(val) + if !ok { + return nil, fmt.Errorf("bigquery: invalid BIGNUMERIC value %q", val) + } + return Value(r), nil case GeographyFieldType: return val, nil default: diff --git a/bigquery/value_test.go b/bigquery/value_test.go index 1e5bb2d37fc..4d08336ed02 100644 --- a/bigquery/value_test.go +++ b/bigquery/value_test.go @@ -36,6 +36,7 @@ func TestConvertBasicValues(t *testing.T) { {Type: BooleanFieldType}, {Type: BytesFieldType}, {Type: NumericFieldType}, + {Type: BigNumericFieldType}, {Type: GeographyFieldType}, } row := &bq.TableRow{ @@ -46,6 +47,7 @@ func TestConvertBasicValues(t *testing.T) { {V: "true"}, {V: base64.StdEncoding.EncodeToString([]byte("foo"))}, {V: "123.123456789"}, + {V: "99999999999999999999999999999999999999.99999999999999999999999999999999999999"}, {V: testGeography}, }, } @@ -54,7 +56,9 @@ func TestConvertBasicValues(t *testing.T) { t.Fatalf("error converting: %v", err) } - want := []Value{"a", int64(1), 1.2, true, []byte("foo"), big.NewRat(123123456789, 1e9), testGeography} + bigRatVal := new(big.Rat) + 
bigRatVal.SetString("99999999999999999999999999999999999999.99999999999999999999999999999999999999") + want := []Value{"a", int64(1), 1.2, true, []byte("foo"), big.NewRat(123123456789, 1e9), bigRatVal, testGeography} if !testutil.Equal(got, want) { t.Errorf("converting basic values: got:\n%v\nwant:\n%v", got, want) } @@ -444,6 +448,7 @@ func TestValuesSaverConvertsToMap(t *testing.T) { {Name: "strField", Type: StringFieldType}, {Name: "dtField", Type: DateTimeFieldType}, {Name: "nField", Type: NumericFieldType}, + {Name: "bigNumField", Type: BigNumericFieldType}, {Name: "geoField", Type: GeographyFieldType}, }, InsertID: "iid", @@ -452,16 +457,18 @@ func TestValuesSaverConvertsToMap(t *testing.T) { Date: civil.Date{Year: 1, Month: 2, Day: 3}, Time: civil.Time{Hour: 4, Minute: 5, Second: 6, Nanosecond: 7000}}, big.NewRat(123456789000, 1e9), + big.NewRat(1, 3), testGeography, }, }, wantInsertID: "iid", wantRow: map[string]Value{ - "intField": 1, - "strField": "a", - "dtField": "0001-02-03 04:05:06.000007", - "nField": "123.456789000", - "geoField": testGeography, + "intField": 1, + "strField": "a", + "dtField": "0001-02-03 04:05:06.000007", + "nField": "123.456789000", + "bigNumField": "0.33333333333333333333333333333333333333", + "geoField": testGeography, }, }, { @@ -603,6 +610,8 @@ func TestStructSaver(t *testing.T) { {Name: "p", Type: IntegerFieldType, Required: false}, {Name: "n", Type: NumericFieldType, Required: false}, {Name: "nr", Type: NumericFieldType, Repeated: true}, + {Name: "bn", Type: BigNumericFieldType, Required: false}, + {Name: "bnr", Type: BigNumericFieldType, Repeated: true}, {Name: "g", Type: GeographyFieldType, Required: false}, {Name: "gr", Type: GeographyFieldType, Repeated: true}, } @@ -619,6 +628,8 @@ func TestStructSaver(t *testing.T) { P NullInt64 N *big.Rat NR []*big.Rat + BN *big.Rat + BNR []*big.Rat G NullGeography GR []string // Repeated Geography } @@ -654,6 +665,8 @@ func TestStructSaver(t *testing.T) { P: NullInt64{Valid: true, 
Int64: 17}, N: big.NewRat(123456, 1000), NR: []*big.Rat{big.NewRat(3, 1), big.NewRat(56789, 1e5)}, + BN: big.NewRat(1, 3), + BNR: []*big.Rat{big.NewRat(1, 3), big.NewRat(1, 2)}, G: NullGeography{Valid: true, GeographyVal: "POINT(-122.350220 47.649154)"}, GR: []string{"POINT(-122.350220 47.649154)", "POINT(-122.198939 47.669865)"}, } @@ -667,6 +680,8 @@ func TestStructSaver(t *testing.T) { "p": NullInt64{Valid: true, Int64: 17}, "n": "123.456000000", "nr": []string{"3.000000000", "0.567890000"}, + "bn": "0.33333333333333333333333333333333333333", + "bnr": []string{"0.33333333333333333333333333333333333333", "0.50000000000000000000000000000000000000"}, "g": NullGeography{Valid: true, GeographyVal: "POINT(-122.350220 47.649154)"}, "gr": []string{"POINT(-122.350220 47.649154)", "POINT(-122.198939 47.669865)"}, } @@ -729,20 +744,24 @@ func TestStructSaverErrors(t *testing.T) { } } -func TestNumericString(t *testing.T) { +func TestNumericStrings(t *testing.T) { for _, test := range []struct { - in *big.Rat - want string + description string + in *big.Rat + wantNumeric string + wantBigNumeric string }{ - {big.NewRat(2, 3), "0.666666667"}, // round to 9 places - {big.NewRat(1, 2), "0.500000000"}, - {big.NewRat(1, 2*1e8), "0.000000005"}, - {big.NewRat(5, 1e10), "0.000000001"}, // round up the 5 in the 10th decimal place - {big.NewRat(-5, 1e10), "-0.000000001"}, // round half away from zero + {"repeating with rounding", big.NewRat(2, 3), "0.666666667", "0.66666666666666666666666666666666666667"}, + {"all zero padding", big.NewRat(1, 2), "0.500000000", "0.50000000000000000000000000000000000000"}, + {"zero pad with digit", big.NewRat(1, 2*1e8), "0.000000005", "0.00000000500000000000000000000000000000"}, + {"smaller rounding case 1", big.NewRat(5, 1e10), "0.000000001", "0.00000000050000000000000000000000000000"}, + {"smaller rounding case 2", big.NewRat(-5, 1e10), "-0.000000001", "-0.00000000050000000000000000000000000000"}, } { - got := NumericString(test.in) - if got != 
test.want { - t.Errorf("%v: got %q, want %q", test.in, got, test.want) + if got := NumericString(test.in); got != test.wantNumeric { + t.Errorf("case %q, val %v as numeric: got %q, want %q", test.description, test.in, got, test.wantNumeric) + } + if got := BigNumericString(test.in); got != test.wantBigNumeric { + t.Errorf("case %q, val %v as bignumeric: got %q, want %q", test.description, test.in, got, test.wantBigNumeric) } } } @@ -887,6 +906,7 @@ var ( {Name: "T", Type: TimeFieldType}, {Name: "DT", Type: DateTimeFieldType}, {Name: "N", Type: NumericFieldType}, + {Name: "BN", Type: BigNumericFieldType}, {Name: "G", Type: GeographyFieldType}, {Name: "nested", Type: RecordFieldType, Schema: Schema{ {Name: "nestS", Type: StringFieldType}, @@ -895,16 +915,17 @@ var ( {Name: "t", Type: StringFieldType}, } - testTimestamp = time.Date(2016, 11, 5, 7, 50, 22, 8, time.UTC) - testDate = civil.Date{Year: 2016, Month: 11, Day: 5} - testTime = civil.Time{Hour: 7, Minute: 50, Second: 22, Nanosecond: 8} - testDateTime = civil.DateTime{Date: testDate, Time: testTime} - testNumeric = big.NewRat(123, 456) + testTimestamp = time.Date(2016, 11, 5, 7, 50, 22, 8, time.UTC) + testDate = civil.Date{Year: 2016, Month: 11, Day: 5} + testTime = civil.Time{Hour: 7, Minute: 50, Second: 22, Nanosecond: 8} + testDateTime = civil.DateTime{Date: testDate, Time: testTime} + testNumeric = big.NewRat(123, 456) + testBigNumeric = big.NewRat(456, 789) // testGeography is a WKT string representing a single point. 
testGeography = "POINT(-122.350220 47.649154)" testValues = []Value{"x", "y", []byte{1, 2, 3}, int64(7), int64(8), 3.14, true, - testTimestamp, testDate, testTime, testDateTime, testNumeric, testGeography, + testTimestamp, testDate, testTime, testDateTime, testNumeric, testBigNumeric, testGeography, []Value{"nested", int64(17)}, "z"} ) @@ -918,6 +939,7 @@ type testStruct1 struct { By []byte F float64 N *big.Rat + BN *big.Rat G string Nested nested Tagged string `bigquery:"t"` @@ -952,6 +974,7 @@ func TestStructLoader(t *testing.T) { S2: "y", By: []byte{1, 2, 3}, N: big.NewRat(123, 456), + BN: big.NewRat(456, 789), G: testGeography, Nested: nested{NestS: "nested", NestI: 17}, Tagged: "z", @@ -1043,18 +1066,19 @@ func TestStructLoaderRepeated(t *testing.T) { } type testStructNullable struct { - String NullString - Bytes []byte - Integer NullInt64 - Float NullFloat64 - Boolean NullBool - Timestamp NullTimestamp - Date NullDate - Time NullTime - DateTime NullDateTime - Numeric *big.Rat - Geography NullGeography - Record *subNullable + String NullString + Bytes []byte + Integer NullInt64 + Float NullFloat64 + Boolean NullBool + Timestamp NullTimestamp + Date NullDate + Time NullTime + DateTime NullDateTime + Numeric *big.Rat + BigNumeric *big.Rat + Geography NullGeography + Record *subNullable } type subNullable struct { @@ -1072,6 +1096,7 @@ var testStructNullableSchema = Schema{ {Name: "Time", Type: TimeFieldType, Required: false}, {Name: "DateTime", Type: DateTimeFieldType, Required: false}, {Name: "Numeric", Type: NumericFieldType, Required: false}, + {Name: "BigNumeric", Type: BigNumericFieldType, Required: false}, {Name: "Geography", Type: GeographyFieldType, Required: false}, {Name: "Record", Type: RecordFieldType, Required: false, Schema: Schema{ {Name: "X", Type: IntegerFieldType, Required: false}, @@ -1088,24 +1113,25 @@ func TestStructLoaderNullable(t *testing.T) { } nonnilVals := []Value{"x", []byte{1, 2, 3}, int64(1), 2.3, true, testTimestamp, testDate, 
testTime, - testDateTime, big.NewRat(1, 2), testGeography, []Value{int64(4)}} + testDateTime, big.NewRat(1, 2), big.NewRat(3, 4), testGeography, []Value{int64(4)}} // All ts fields are nil. Loading non-nil values will cause them all to // be allocated. mustLoad(t, &ts, testStructNullableSchema, nonnilVals) want = testStructNullable{ - String: NullString{StringVal: "x", Valid: true}, - Bytes: []byte{1, 2, 3}, - Integer: NullInt64{Int64: 1, Valid: true}, - Float: NullFloat64{Float64: 2.3, Valid: true}, - Boolean: NullBool{Bool: true, Valid: true}, - Timestamp: NullTimestamp{Timestamp: testTimestamp, Valid: true}, - Date: NullDate{Date: testDate, Valid: true}, - Time: NullTime{Time: testTime, Valid: true}, - DateTime: NullDateTime{DateTime: testDateTime, Valid: true}, - Numeric: big.NewRat(1, 2), - Geography: NullGeography{GeographyVal: testGeography, Valid: true}, - Record: &subNullable{X: NullInt64{Int64: 4, Valid: true}}, + String: NullString{StringVal: "x", Valid: true}, + Bytes: []byte{1, 2, 3}, + Integer: NullInt64{Int64: 1, Valid: true}, + Float: NullFloat64{Float64: 2.3, Valid: true}, + Boolean: NullBool{Bool: true, Valid: true}, + Timestamp: NullTimestamp{Timestamp: testTimestamp, Valid: true}, + Date: NullDate{Date: testDate, Valid: true}, + Time: NullTime{Time: testTime, Valid: true}, + DateTime: NullDateTime{DateTime: testDateTime, Valid: true}, + Numeric: big.NewRat(1, 2), + BigNumeric: big.NewRat(3, 4), + Geography: NullGeography{GeographyVal: testGeography, Valid: true}, + Record: &subNullable{X: NullInt64{Int64: 4, Valid: true}}, } if diff := testutil.Diff(ts, want); diff != "" { t.Error(diff) @@ -1114,7 +1140,7 @@ func TestStructLoaderNullable(t *testing.T) { // Struct pointers are reused, byte slices are not. 
want = ts want.Bytes = []byte{17} - vals2 := []Value{nil, []byte{17}, nil, nil, nil, nil, nil, nil, nil, nil, nil, []Value{int64(7)}} + vals2 := []Value{nil, []byte{17}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, []Value{int64(7)}} mustLoad(t, &ts, testStructNullableSchema, vals2) if ts.Record != want.Record { t.Error("record pointers not identical") diff --git a/go.sum b/go.sum index ea567ca6a70..b75391c1aae 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,7 @@ cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvf cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= cloud.google.com/go/bigquery v1.7.0 h1:a/O/bK/vWrYGOTFtH8di4rBxMZnmkjy+Y5LxpDwo+dA= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= +cloud.google.com/go/bigquery v1.8.0 h1:PQcPefKFdaIzjQFbiyOgAqyx8q5djaE7x9Sqe712DPA= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/datastore v1.0.0 h1:Kt+gOPPp2LEPWp8CSfxhsM8ik9CcyE/gYu+0r+RnZvM= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= @@ -348,6 +349,7 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c h1:fqgJT0MGcGpPgpWU7VRdRjuA golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= From fa77efa1a1880ff89307d54cc7e9e8c09430e4e2 Mon Sep 17 00:00:00 2001 From: shollyman Date: Thu, 19 Nov 2020 16:20:55 -0800 Subject: [PATCH 3/6] feat(bigquery): expose hive partitioning options (#3240) * feat(bigquery): add HivePartitioningOptions This resuscitates https://code-review.googlesource.com/c/gocloud/+/43270 and updates it now that the service is properly advertising the fields such as RequirePartitionFilter. --- bigquery/external.go | 72 ++++++++++++++++++++++++++++++++++++ bigquery/external_test.go | 8 ++++ bigquery/integration_test.go | 53 ++++++++++++++++++++++++++ 3 files changed, 133 insertions(+) diff --git a/bigquery/external.go b/bigquery/external.go index e8d1fdee550..3f8647e568a 100644 --- a/bigquery/external.go +++ b/bigquery/external.go @@ -92,6 +92,10 @@ type ExternalDataConfig struct { // Additional options for CSV, GoogleSheets and Bigtable formats. Options ExternalDataConfigOptions + + // HivePartitioningOptions allows use of Hive partitioning based on the + // layout of objects in Google Cloud Storage. + HivePartitioningOptions *HivePartitioningOptions } func (e *ExternalDataConfig) toBQ() bq.ExternalDataConfiguration { @@ -106,6 +110,9 @@ func (e *ExternalDataConfig) toBQ() bq.ExternalDataConfiguration { if e.Schema != nil { q.Schema = e.Schema.toBQ() } + if e.HivePartitioningOptions != nil { + q.HivePartitioningOptions = e.HivePartitioningOptions.toBQ() + } if e.Options != nil { e.Options.populateExternalDataConfig(&q) } @@ -134,6 +141,9 @@ func bqToExternalDataConfig(q *bq.ExternalDataConfiguration) (*ExternalDataConfi return nil, err } } + if q.HivePartitioningOptions != nil { + e.HivePartitioningOptions = bqToHivePartitioningOptions(q.HivePartitioningOptions) + } return e, nil } @@ -409,3 +419,65 @@ func bqToBigtableColumn(q *bq.BigtableColumn) (*BigtableColumn, error) { } return b, nil } + +// HivePartitioningMode is used in conjunction with HivePartitioningOptions. 
+type HivePartitioningMode string + +const ( + // AutoHivePartitioningMode automatically infers partitioning key and types. + AutoHivePartitioningMode HivePartitioningMode = "AUTO" + // StringHivePartitioningMode automatically infers partitioning keys and treats values as string. + StringHivePartitioningMode HivePartitioningMode = "STRINGS" + // CustomHivePartitioningMode allows custom definition of the external partitioning. + CustomHivePartitioningMode HivePartitioningMode = "CUSTOM" +) + +// HivePartitioningOptions defines the behavior of Hive partitioning +// when working with external data. +type HivePartitioningOptions struct { + + // Mode defines which hive partitioning mode to use when reading data. + Mode HivePartitioningMode + + // When hive partition detection is requested, a common prefix for + // all source uris should be supplied. The prefix must end immediately + // before the partition key encoding begins. + // + // For example, consider files following this data layout. + // gs://bucket/path_to_table/dt=2019-01-01/country=BR/id=7/file.avro + // gs://bucket/path_to_table/dt=2018-12-31/country=CA/id=3/file.avro + // + // When hive partitioning is requested with either AUTO or STRINGS + // detection, the common prefix can be either of + // gs://bucket/path_to_table or gs://bucket/path_to_table/ (trailing + // slash does not matter). + SourceURIPrefix string + + // If set to true, queries against this external table require + // a partition filter to be present that can perform partition + // elimination. Hive-partitioned load jobs with this field + // set to true will fail. 
+ RequirePartitionFilter bool +} + +func (o *HivePartitioningOptions) toBQ() *bq.HivePartitioningOptions { + if o == nil { + return nil + } + return &bq.HivePartitioningOptions{ + Mode: string(o.Mode), + SourceUriPrefix: o.SourceURIPrefix, + RequirePartitionFilter: o.RequirePartitionFilter, + } +} + +func bqToHivePartitioningOptions(q *bq.HivePartitioningOptions) *HivePartitioningOptions { + if q == nil { + return nil + } + return &HivePartitioningOptions{ + Mode: HivePartitioningMode(q.Mode), + SourceURIPrefix: q.SourceUriPrefix, + RequirePartitionFilter: q.RequirePartitionFilter, + } +} diff --git a/bigquery/external_test.go b/bigquery/external_test.go index f64297fd1a0..222cf274999 100644 --- a/bigquery/external_test.go +++ b/bigquery/external_test.go @@ -48,6 +48,14 @@ func TestExternalDataConfig(t *testing.T) { Range: "sheet1!A1:Z10", }, }, + { + SourceFormat: Avro, + HivePartitioningOptions: &HivePartitioningOptions{ + Mode: AutoHivePartitioningMode, + SourceURIPrefix: "gs://somebucket/a/b/c", + RequirePartitionFilter: true, + }, + }, { SourceFormat: Bigtable, Options: &BigtableOptions{ diff --git a/bigquery/integration_test.go b/bigquery/integration_test.go index 60959d0a0ca..e450e207ec5 100644 --- a/bigquery/integration_test.go +++ b/bigquery/integration_test.go @@ -1859,6 +1859,59 @@ func TestIntegration_LegacyQuery(t *testing.T) { } } +func TestIntegration_QueryExternalHivePartitioning(t *testing.T) { + if client == nil { + t.Skip("Integration tests skipped") + } + ctx := context.Background() + + autoTable := dataset.Table(tableIDs.New()) + customTable := dataset.Table(tableIDs.New()) + + err := autoTable.Create(ctx, &TableMetadata{ + ExternalDataConfig: &ExternalDataConfig{ + SourceFormat: Parquet, + SourceURIs: []string{"gs://cloud-samples-data/bigquery/hive-partitioning-samples/autolayout/*"}, + AutoDetect: true, + HivePartitioningOptions: &HivePartitioningOptions{ + Mode: AutoHivePartitioningMode, + SourceURIPrefix: 
"gs://cloud-samples-data/bigquery/hive-partitioning-samples/autolayout/", + RequirePartitionFilter: true, + }, + }, + }) + if err != nil { + t.Fatalf("table.Create(auto): %v", err) + } + defer autoTable.Delete(ctx) + + err = customTable.Create(ctx, &TableMetadata{ + ExternalDataConfig: &ExternalDataConfig{ + SourceFormat: Parquet, + SourceURIs: []string{"gs://cloud-samples-data/bigquery/hive-partitioning-samples/customlayout/*"}, + AutoDetect: true, + HivePartitioningOptions: &HivePartitioningOptions{ + Mode: CustomHivePartitioningMode, + SourceURIPrefix: "gs://cloud-samples-data/bigquery/hive-partitioning-samples/customlayout/{pkey:STRING}/", + RequirePartitionFilter: true, + }, + }, + }) + if err != nil { + t.Fatalf("table.Create(custom): %v", err) + } + defer customTable.Delete(ctx) + + // Issue a test query that prunes based on the custom hive partitioning key, and verify the result is as expected. + sql := fmt.Sprintf("SELECT COUNT(*) as ct FROM `%s`.%s.%s WHERE pkey=\"foo\"", customTable.ProjectID, customTable.DatasetID, customTable.TableID) + q := client.Query(sql) + it, err := q.Read(ctx) + if err != nil { + t.Fatalf("Error querying: %v", err) + } + checkReadAndTotalRows(t, "HiveQuery", it, [][]Value{{int64(50)}}) +} + func TestIntegration_QueryParameters(t *testing.T) { if client == nil { t.Skip("Integration tests skipped") From cf308308a3da134163a19ab78d6c605d6a2c2ef8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 20 Nov 2020 13:30:04 +0100 Subject: [PATCH 4/6] test(spanner): fix flaky TestMaintainer test (#3237) * test(spanner): fix flaky TestMaintainer test Fixes #3230 * docs: clearify comment --- spanner/session_test.go | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/spanner/session_test.go b/spanner/session_test.go index b868ee839ce..4f4fdfc1a1e 100644 --- a/spanner/session_test.go +++ b/spanner/session_test.go @@ -1557,7 +1557,7 @@ func TestMaintainer(t *testing.T) { 
waitFor(t, func() error { sp.mu.Lock() defer sp.mu.Unlock() - if sp.numOpened != 5 { + if sp.numOpened != minOpened { return fmt.Errorf("Replenish. Expect %d open, got %d", sp.MinOpened, sp.numOpened) } return nil @@ -1574,12 +1574,22 @@ func TestMaintainer(t *testing.T) { t.Fatalf("cannot get session from session pool: %v", err) } } - sp.mu.Lock() - g, w := sp.numOpened, sp.MinOpened+sp.incStep - sp.mu.Unlock() - if g != w { - t.Fatalf("numOpened sessions mismatch\nGot: %d\nWant: %d", g, w) - } + // Wait for all sessions to be added to the pool. + // The pool already contained 5 sessions (MinOpened=5). + // The test took 20 sessions from the pool. That initiated the creation of + // additional sessions, and that is done in batches of 25 sessions, so the + // pool should contain 30 sessions (with 20 currently checked out). It + // could take a couple of milliseconds before all sessions have been + // created and added to the pool. + waitFor(t, func() error { + sp.mu.Lock() + defer sp.mu.Unlock() + g, w := sp.numOpened, sp.MinOpened+sp.incStep + if g != w { + return fmt.Errorf("numOpened sessions mismatch\nGot: %d\nWant: %d", g, w) + } + return nil + }) // Return 14 sessions to the pool. There are still 6 sessions checked out. for _, sh := range shs[:14] { From b17f9f00d1e8091bae98ab81291afb46438d6aa9 Mon Sep 17 00:00:00 2001 From: tmdiep Date: Tue, 24 Nov 2020 10:36:11 +1100 Subject: [PATCH 5/6] refactor(pubsublite): message routers set partition count when created (#3252) Removes messageRouter.SetPartitionCount and introduces messageRouterFactory to create new instances when the partition count changes. 
--- pubsublite/internal/wire/message_router.go | 49 ++++++++-------- .../internal/wire/message_router_test.go | 57 ++++++++----------- 2 files changed, 50 insertions(+), 56 deletions(-) diff --git a/pubsublite/internal/wire/message_router.go b/pubsublite/internal/wire/message_router.go index 23edebb9af8..58027c9c678 100644 --- a/pubsublite/internal/wire/message_router.go +++ b/pubsublite/internal/wire/message_router.go @@ -19,14 +19,9 @@ import ( "math/rand" ) -// messageRouter outputs a partition number, given an ordering key. Results are -// undefined when: -// - setPartitionCount() is called with count <= 0. -// - route() is called before setPartitionCount() to initialize the router. -// -// Message routers need to accommodate topic partition resizing. +// messageRouter outputs a partition number, given an ordering key (which it may +// ignore). type messageRouter interface { - SetPartitionCount(count int) Route(orderingKey []byte) int } @@ -38,12 +33,15 @@ type roundRobinMsgRouter struct { nextPartition int } -func (r *roundRobinMsgRouter) SetPartitionCount(count int) { - r.partitionCount = count - r.nextPartition = int(r.rng.Int63n(int64(count))) +func newRoundRobinMsgRouter(rng *rand.Rand, count int) *roundRobinMsgRouter { + return &roundRobinMsgRouter{ + rng: rng, + partitionCount: count, + nextPartition: int(rng.Int63n(int64(count))), + } } -func (r *roundRobinMsgRouter) Route(orderingKey []byte) (partition int) { +func (r *roundRobinMsgRouter) Route(_ []byte) (partition int) { partition = r.nextPartition r.nextPartition = (partition + 1) % r.partitionCount return @@ -58,8 +56,10 @@ type hashingMsgRouter struct { partitionCount *big.Int } -func (r *hashingMsgRouter) SetPartitionCount(count int) { - r.partitionCount = big.NewInt(int64(count)) +func newHashingMsgRouter(count int) *hashingMsgRouter { + return &hashingMsgRouter{ + partitionCount: big.NewInt(int64(count)), + } } func (r *hashingMsgRouter) Route(orderingKey []byte) int { @@ -79,11 +79,6 @@ type 
compositeMsgRouter struct { keylessRouter messageRouter } -func (r *compositeMsgRouter) SetPartitionCount(count int) { - r.keyedRouter.SetPartitionCount(count) - r.keylessRouter.SetPartitionCount(count) -} - func (r *compositeMsgRouter) Route(orderingKey []byte) int { if len(orderingKey) > 0 { return r.keyedRouter.Route(orderingKey) @@ -91,11 +86,19 @@ func (r *compositeMsgRouter) Route(orderingKey []byte) int { return r.keylessRouter.Route(orderingKey) } -// defaultMessageRouter returns a compositeMsgRouter that uses hashingMsgRouter -// for messages with ordering key and roundRobinMsgRouter for messages without. -func newDefaultMessageRouter(rng *rand.Rand) messageRouter { +type messageRouterFactory struct { + rng *rand.Rand +} + +func newMessageRouterFactory(rng *rand.Rand) *messageRouterFactory { + return &messageRouterFactory{rng: rng} +} + +// New returns a compositeMsgRouter that uses hashingMsgRouter for messages with +// ordering key and roundRobinMsgRouter for messages without. 
+func (f *messageRouterFactory) New(partitionCount int) messageRouter { return &compositeMsgRouter{ - keyedRouter: &hashingMsgRouter{}, - keylessRouter: &roundRobinMsgRouter{rng: rng}, + keyedRouter: newHashingMsgRouter(partitionCount), + keylessRouter: newRoundRobinMsgRouter(f.rng, partitionCount), } } diff --git a/pubsublite/internal/wire/message_router_test.go b/pubsublite/internal/wire/message_router_test.go index 0f847074afa..79d2159bec5 100644 --- a/pubsublite/internal/wire/message_router_test.go +++ b/pubsublite/internal/wire/message_router_test.go @@ -21,25 +21,7 @@ import ( "cloud.google.com/go/pubsublite/internal/test" ) -type fakeMsgRouter struct { - multiplier int - partitionCount int -} - -func (f *fakeMsgRouter) SetPartitionCount(count int) { - f.partitionCount = count -} - -func (f *fakeMsgRouter) Route(orderingKey []byte) int { - return f.partitionCount * f.multiplier -} - func TestRoundRobinMsgRouter(t *testing.T) { - // Using the same msgRouter for each test run ensures that it reinitializes - // when the partition count changes. - source := &test.FakeSource{} - msgRouter := &roundRobinMsgRouter{rng: rand.New(source)} - for _, tc := range []struct { partitionCount int source int64 @@ -57,8 +39,9 @@ func TestRoundRobinMsgRouter(t *testing.T) { }, } { t.Run(fmt.Sprintf("partitionCount=%d", tc.partitionCount), func(t *testing.T) { - source.Ret = tc.source - msgRouter.SetPartitionCount(tc.partitionCount) + source := &test.FakeSource{Ret: tc.source} + msgRouter := newRoundRobinMsgRouter(rand.New(source), tc.partitionCount) + for i, want := range tc.want { got := msgRouter.Route([]byte("IGNORED")) if got != want { @@ -70,10 +53,6 @@ func TestRoundRobinMsgRouter(t *testing.T) { } func TestHashingMsgRouter(t *testing.T) { - // Using the same msgRouter for each test run ensures that it reinitializes - // when the partition count changes. 
- msgRouter := &hashingMsgRouter{} - keys := [][]byte{ []byte("foo1"), []byte("foo2"), @@ -89,7 +68,7 @@ func TestHashingMsgRouter(t *testing.T) { {partitionCount: 5}, } { t.Run(fmt.Sprintf("partitionCount=%d", tc.partitionCount), func(t *testing.T) { - msgRouter.SetPartitionCount(tc.partitionCount) + msgRouter := newHashingMsgRouter(tc.partitionCount) for _, key := range keys { p1 := msgRouter.Route(key) p2 := msgRouter.Route(key) @@ -104,14 +83,16 @@ func TestHashingMsgRouter(t *testing.T) { } } -func TestCompositeMsgRouter(t *testing.T) { - keyedRouter := &fakeMsgRouter{multiplier: 10} - keylessRouter := &fakeMsgRouter{multiplier: 100} - msgRouter := &compositeMsgRouter{ - keyedRouter: keyedRouter, - keylessRouter: keylessRouter, - } +type fakeMsgRouter struct { + multiplier int + partitionCount int +} +func (f *fakeMsgRouter) Route(orderingKey []byte) int { + return f.partitionCount * f.multiplier +} + +func TestCompositeMsgRouter(t *testing.T) { for _, tc := range []struct { desc string partitionCount int @@ -138,7 +119,17 @@ func TestCompositeMsgRouter(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - msgRouter.SetPartitionCount(tc.partitionCount) + msgRouter := &compositeMsgRouter{ + keyedRouter: &fakeMsgRouter{ + multiplier: 10, + partitionCount: tc.partitionCount, + }, + keylessRouter: &fakeMsgRouter{ + multiplier: 100, + partitionCount: tc.partitionCount, + }, + } + if got := msgRouter.Route(tc.key); got != tc.want { t.Errorf("Route() = %d, want = %d", got, tc.want) } From 200e261dc28ab818ac7ebb08a38881aa663a4d9f Mon Sep 17 00:00:00 2001 From: tmdiep Date: Tue, 24 Nov 2020 11:28:44 +1100 Subject: [PATCH 6/6] test(pubsublite): integration test utils (#3254) Utils to track messages, validate ordering and detect duplicates. 
--- pubsublite/internal/test/msg_tracker.go | 90 +++++++++ pubsublite/internal/test/msg_tracker_test.go | 56 ++++++ pubsublite/internal/test/msg_validators.go | 181 ++++++++++++++++++ .../internal/test/msg_validators_test.go | 140 ++++++++++++++ 4 files changed, 467 insertions(+) create mode 100644 pubsublite/internal/test/msg_tracker.go create mode 100644 pubsublite/internal/test/msg_tracker_test.go create mode 100644 pubsublite/internal/test/msg_validators.go create mode 100644 pubsublite/internal/test/msg_validators_test.go diff --git a/pubsublite/internal/test/msg_tracker.go b/pubsublite/internal/test/msg_tracker.go new file mode 100644 index 00000000000..f89adeb67f2 --- /dev/null +++ b/pubsublite/internal/test/msg_tracker.go @@ -0,0 +1,90 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +package test + +import ( + "fmt" + "sync" + "time" +) + +// MsgTracker is a helper for checking whether a set of messages make a full +// round trip from publisher to subscriber. +// +// Add() registers published messages. Remove() should be called when messages +// are received by subscribers. Call Wait() to block until all tracked messages +// are received. The same MsgTracker instance can be reused to repeat this +// sequence for multiple test cycles. +// +// Add() and Remove() calls should not be interleaved. +type MsgTracker struct { + msgMap map[string]bool + done chan struct{} + mu sync.Mutex +} + +// NewMsgTracker creates a new message tracker. 
+func NewMsgTracker() *MsgTracker { + return &MsgTracker{ + msgMap: make(map[string]bool), + done: make(chan struct{}, 1), + } +} + +// Add a set of tracked messages. +func (mt *MsgTracker) Add(msgs ...string) { + mt.mu.Lock() + defer mt.mu.Unlock() + + for _, msg := range msgs { + mt.msgMap[msg] = true + } +} + +// Remove and return true if `msg` is tracked. Signals the `done` channel once +// all messages have been received. +func (mt *MsgTracker) Remove(msg string) bool { + mt.mu.Lock() + defer mt.mu.Unlock() + + if _, exists := mt.msgMap[msg]; exists { + delete(mt.msgMap, msg) + if len(mt.msgMap) == 0 { + var void struct{} + mt.done <- void + } + return true + } + return false +} + +// Wait up to `timeout` to receive all tracked messages. +func (mt *MsgTracker) Wait(timeout time.Duration) error { + mt.mu.Lock() + totalCount := len(mt.msgMap) + mt.mu.Unlock() + + select { + case <-time.After(timeout): + mt.mu.Lock() + receivedCount := totalCount - len(mt.msgMap) + err := fmt.Errorf("received %d of %d messages", receivedCount, totalCount) + mt.msgMap = make(map[string]bool) + mt.mu.Unlock() + return err + + case <-mt.done: + return nil + } +} diff --git a/pubsublite/internal/test/msg_tracker_test.go b/pubsublite/internal/test/msg_tracker_test.go new file mode 100644 index 00000000000..642bfd3c18d --- /dev/null +++ b/pubsublite/internal/test/msg_tracker_test.go @@ -0,0 +1,56 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and + +package test + +import ( + "testing" + "time" +) + +func TestMsgTrackerWaitSuccess(t *testing.T) { + msgs := []string{"a", "b", "c"} + + msgTracker := NewMsgTracker() + msgTracker.Add(msgs...) + + for _, msg := range msgs { + if got, want := msgTracker.Remove(msg), true; got != want { + t.Errorf("MsgTracker.Remove(%q) got %v, want %v", msg, got, want) + } + } + for _, msg := range []string{"d", "e"} { + if got, want := msgTracker.Remove(msg), false; got != want { + t.Errorf("MsgTracker.Remove(%q) got %v, want %v", msg, got, want) + } + } + if gotErr := msgTracker.Wait(time.Millisecond); gotErr != nil { + t.Errorf("MsgTracker.Wait() got err: %v", gotErr) + } +} + +func TestMsgTrackerWaitTimeout(t *testing.T) { + msgs := []string{"a", "b", "c"} + + msgTracker := NewMsgTracker() + msgTracker.Add(msgs...) + + for _, msg := range []string{"a", "c"} { + if got, want := msgTracker.Remove(msg), true; got != want { + t.Errorf("MsgTracker.Remove(%q) got %v, want %v", msg, got, want) + } + } + if gotErr, wantMsg := msgTracker.Wait(time.Millisecond), "received 2 of 3 messages"; !ErrorHasMsg(gotErr, wantMsg) { + t.Errorf("MsgTracker.Wait() got err: %v, want msg: %q", gotErr, wantMsg) + } +} diff --git a/pubsublite/internal/test/msg_validators.go b/pubsublite/internal/test/msg_validators.go new file mode 100644 index 00000000000..323e314d269 --- /dev/null +++ b/pubsublite/internal/test/msg_validators.go @@ -0,0 +1,181 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and + +package test + +import ( + "fmt" + "strconv" + "strings" + "sync" +) + +// OrderingSender generates strings containing a message index to use for +// verifying message ordering. It is used in conjunction with Publishers. +type OrderingSender struct { + TotalMsgCount int64 +} + +// NewOrderingSender creates a new OrderingSender. +func NewOrderingSender() *OrderingSender { + return new(OrderingSender) +} + +// Next generates the next string to publish. +func (os *OrderingSender) Next(prefix string) string { + os.TotalMsgCount++ + return fmt.Sprintf("%s/%d", prefix, os.TotalMsgCount) +} + +// OrderingReceiver consumes a message string generated by OrderingSender and +// verifies that messages in a partition are ordered. It is used in conjunction +// with Subscribers. +type OrderingReceiver struct { + mu sync.Mutex + // Map of key and last received message index. Messages are only guaranteed to + // be received in order within a partition. + received map[string]int64 +} + +// NewOrderingReceiver creates a new OrderingReceiver. +func NewOrderingReceiver() *OrderingReceiver { + return &OrderingReceiver{ + received: make(map[string]int64), + } +} + +func parseMsgIndex(msg string) int64 { + pos := strings.LastIndex(msg, "/") + if pos >= 0 { + if n, err := strconv.ParseInt(msg[pos+1:], 10, 64); err == nil { + return n + } + } + return -1 +} + +// Receive checks the given message data and key and returns an error if +// unordered messages are detected. +// +// Note: a normal scenario resulting in unordered messages is when the Publish +// stream breaks while there are in-flight batches, which are resent upon +// stream reconnect. +func (or *OrderingReceiver) Receive(data, key string) error { + or.mu.Lock() + defer or.mu.Unlock() + + idx := parseMsgIndex(data) + if idx < 0 { + return fmt.Errorf("failed to parse index from message: %q", data) + } + + // Verify non-decreasing ordering. 
Allow duplicates, which can be verified + // with DuplicateMsgDetector. + lastIdx, exists := or.received[key] + if exists && idx < lastIdx { + return fmt.Errorf("message ordering failed for key %s, expected message idx >= %d, got %d", key, lastIdx, idx) + } + or.received[key] = idx + return nil +} + +var void struct{} + +type msgMetadata struct { + offsets map[int64]struct{} +} + +func newMsgMetadata() *msgMetadata { + return &msgMetadata{ + offsets: make(map[int64]struct{}), + } +} + +func (mm *msgMetadata) ContainsOffset(offset int64) bool { + _, exists := mm.offsets[offset] + return exists +} + +func (mm *msgMetadata) AddOffset(offset int64) { + mm.offsets[offset] = void +} + +// DuplicateMsgDetector can be used to detect duplicate messages, either due to +// duplicate publishes or receives. +type DuplicateMsgDetector struct { + mu sync.Mutex + // Map of Pub/Sub message data and associated metadata. + msgs map[string]*msgMetadata + duplicatePublishCount int64 + duplicateReceiveCount int64 +} + +// NewDuplicateMsgDetector creates a new DuplicateMsgDetector. +func NewDuplicateMsgDetector() *DuplicateMsgDetector { + return &DuplicateMsgDetector{ + msgs: make(map[string]*msgMetadata), + } +} + +// Receive checks the given message data and offset. +func (dm *DuplicateMsgDetector) Receive(data string, offset int64) { + dm.mu.Lock() + defer dm.mu.Unlock() + + if metadata, exists := dm.msgs[data]; exists { + if metadata.ContainsOffset(offset) { + // If the message contains the same offset, it means it was received + // multiple times. This is not expected within a single test run. But it + // is normal when processes are stopped & restarted without committing + // cursors. + dm.duplicateReceiveCount++ + } else { + // If the message contains a different offset, it means a message was + // republished, which can occur when a publish stream reconnects with + // in-flight published messages. 
+ dm.duplicatePublishCount++ + metadata.AddOffset(offset) + } + } else { + metadata = newMsgMetadata() + metadata.AddOffset(offset) + dm.msgs[data] = metadata + } +} + +// Status returns a non-empty status string if there were duplicates detected. +func (dm *DuplicateMsgDetector) Status() string { + dm.mu.Lock() + defer dm.mu.Unlock() + + if (dm.duplicateReceiveCount + dm.duplicatePublishCount) == 0 { + return "" + } + return fmt.Sprintf("duplicate publish count = %d, receive count = %d", dm.duplicatePublishCount, dm.duplicateReceiveCount) +} + +// HasPublishDuplicates returns true if duplicate published messages were +// detected. +func (dm *DuplicateMsgDetector) HasPublishDuplicates() bool { + dm.mu.Lock() + defer dm.mu.Unlock() + return dm.duplicatePublishCount > 0 +} + +// HasReceiveDuplicates returns true if duplicate received messages were +// detected. +func (dm *DuplicateMsgDetector) HasReceiveDuplicates() bool { + dm.mu.Lock() + defer dm.mu.Unlock() + return dm.duplicateReceiveCount > 0 +} diff --git a/pubsublite/internal/test/msg_validators_test.go b/pubsublite/internal/test/msg_validators_test.go new file mode 100644 index 00000000000..5def0c620cb --- /dev/null +++ b/pubsublite/internal/test/msg_validators_test.go @@ -0,0 +1,140 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and + +package test + +import ( + "testing" +) + +func TestOrderingSender(t *testing.T) { + sender := NewOrderingSender() + if got, want := sender.Next("prefix"), "prefix/1"; got != want { + t.Errorf("OrderingSender.Next() got %q, want %q", got, want) + } + if got, want := sender.Next("prefix"), "prefix/2"; got != want { + t.Errorf("OrderingSender.Next() got %q, want %q", got, want) + } + if got, want := sender.Next("foo"), "foo/3"; got != want { + t.Errorf("OrderingSender.Next() got %q, want %q", got, want) + } +} + +func TestOrderingReceiver(t *testing.T) { + receiver := NewOrderingReceiver() + + t.Run("Invalid message", func(t *testing.T) { + if gotErr, wantMsg := receiver.Receive("invalid", "ignored"), "failed to parse index"; !ErrorHasMsg(gotErr, wantMsg) { + t.Errorf("OrderingReceiver.Receive() got err: %v, want msg: %q", gotErr, wantMsg) + } + }) + + t.Run("Key=foo", func(t *testing.T) { + if gotErr := receiver.Receive("foo/1", "foo"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr := receiver.Receive("foo/3", "foo"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr := receiver.Receive("foo/3", "foo"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr := receiver.Receive("foo/4", "foo"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr, wantMsg := receiver.Receive("foo/3", "foo"), "expected message idx >= 4, got 3"; !ErrorHasMsg(gotErr, wantMsg) { + t.Errorf("OrderingReceiver.Receive() got err: %v, want msg: %q", gotErr, wantMsg) + } + }) + + t.Run("Key=bar", func(t *testing.T) { + if gotErr := receiver.Receive("bar/30", "bar"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got err: %v", gotErr) + } + if gotErr := receiver.Receive("bar/30", "bar"); gotErr != nil { + t.Errorf("OrderingReceiver.Receive() got 
err: %v", gotErr) + } + if gotErr, wantMsg := receiver.Receive("bar/29", "bar"), "expected message idx >= 30, got 29"; !ErrorHasMsg(gotErr, wantMsg) { + t.Errorf("OrderingReceiver.Receive() got err: %v, want msg: %q", gotErr, wantMsg) + } + }) +} + +func TestDuplicateMsgDetector(t *testing.T) { + t.Run("No duplicates", func(t *testing.T) { + duplicateDetector := NewDuplicateMsgDetector() + duplicateDetector.Receive("foo", 10) + duplicateDetector.Receive("bar", 20) + + if got, want := duplicateDetector.duplicatePublishCount, int64(0); got != want { + t.Errorf("DuplicateMsgDetector.duplicatePublishCount() got %v, want %v", got, want) + } + if got, want := duplicateDetector.duplicateReceiveCount, int64(0); got != want { + t.Errorf("DuplicateMsgDetector.duplicateReceiveCount got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasPublishDuplicates(), false; got != want { + t.Errorf("DuplicateMsgDetector.HasPublishDuplicates() got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasReceiveDuplicates(), false; got != want { + t.Errorf("DuplicateMsgDetector.HasReceiveDuplicates() got %v, want %v", got, want) + } + if got, want := duplicateDetector.Status(), ""; got != want { + t.Errorf("DuplicateMsgDetector.Status() got %q, want %q", got, want) + } + }) + + t.Run("Duplicate publish", func(t *testing.T) { + duplicateDetector := NewDuplicateMsgDetector() + duplicateDetector.Receive("foo", 10) + duplicateDetector.Receive("foo", 11) + duplicateDetector.Receive("foo", 12) + + if got, want := duplicateDetector.duplicatePublishCount, int64(2); got != want { + t.Errorf("DuplicateMsgDetector.duplicatePublishCount() got %v, want %v", got, want) + } + if got, want := duplicateDetector.duplicateReceiveCount, int64(0); got != want { + t.Errorf("DuplicateMsgDetector.duplicateReceiveCount got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasPublishDuplicates(), true; got != want { + t.Errorf("DuplicateMsgDetector.HasPublishDuplicates() 
got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasReceiveDuplicates(), false; got != want { + t.Errorf("DuplicateMsgDetector.HasReceiveDuplicates() got %v, want %v", got, want) + } + if got := duplicateDetector.Status(); got == "" { + t.Errorf("DuplicateMsgDetector.Status() got %q, want status string", got) + } + }) + + t.Run("Duplicate receive", func(t *testing.T) { + duplicateDetector := NewDuplicateMsgDetector() + duplicateDetector.Receive("bar", 20) + duplicateDetector.Receive("bar", 20) + + if got, want := duplicateDetector.duplicatePublishCount, int64(0); got != want { + t.Errorf("DuplicateMsgDetector.duplicatePublishCount() got %v, want %v", got, want) + } + if got, want := duplicateDetector.duplicateReceiveCount, int64(1); got != want { + t.Errorf("DuplicateMsgDetector.duplicateReceiveCount got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasPublishDuplicates(), false; got != want { + t.Errorf("DuplicateMsgDetector.HasPublishDuplicates() got %v, want %v", got, want) + } + if got, want := duplicateDetector.HasReceiveDuplicates(), true; got != want { + t.Errorf("DuplicateMsgDetector.HasReceiveDuplicates() got %v, want %v", got, want) + } + if got := duplicateDetector.Status(); got == "" { + t.Errorf("DuplicateMsgDetector.Status() got %q, want status string", got) + } + }) +}