diff --git a/CHANGELOG.md b/CHANGELOG.md index aae4904..fcf6ded 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] +## [v0.5.0] +- Added cassandra row_id with TIMEUUID for long-polling [#25](https://github.com/xmidt-org/codex-db/pull/25) + ## [v0.4.0] - Modified retry package to use backoff package for exponential backoffs on retries [#21](https://github.com/xmidt-org/codex-db/pull/21) - Added automated releases using travis [#22](https://github.com/xmidt-org/codex-db/pull/22) @@ -40,7 +43,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [v0.1.0] - Initial creation, moved from: https://github.com/xmidt-org/codex-deploy -[Unreleased]: https://github.com/xmidt-org/codex-db/compare/v0.4.0..HEAD +[Unreleased]: https://github.com/xmidt-org/codex-db/compare/v0.5.0..HEAD +[v0.5.0]: https://github.com/xmidt-org/codex-db/compare/v0.4.0..v0.5.0 [v0.4.0]: https://github.com/xmidt-org/codex-db/compare/v0.3.3..v0.4.0 [v0.3.3]: https://github.com/xmidt-org/codex-db/compare/v0.3.2..v0.3.3 [v0.3.2]: https://github.com/xmidt-org/codex-db/compare/v0.3.1..v0.3.2 diff --git a/README.md b/README.md index 6d747aa..d854ea4 100644 --- a/README.md +++ b/README.md @@ -34,13 +34,14 @@ This repo is a library of packages. There is no installation. ```cassandraql CREATE KEYSPACE IF NOT EXISTS devices; CREATE TABLE devices.events (device_id varchar, - record_type int, - birthdate bigint, - deathdate bigint, - data blob, - nonce blob, - alg varchar, - kid varchar, + record_type INT, + birthdate BIGINT, + deathdate BIGINT, + data BLOB, + nonce BLOB, + alg VARCHAR, + kid VARCHAR, + row_id TIMEUUID, PRIMARY KEY (device_id, birthdate, record_type)) WITH CLUSTERING ORDER BY (birthdate DESC, record_type ASC) AND default_time_to_live = 2768400 @@ -50,6 +51,11 @@ CREATE INDEX search_by_record_type ON devices.events WITH CLUSTERING ORDER BY (record_type ASC, birthdate DESC) AND default_time_to_live = 2768400 AND transactions = {'enabled': 'false', 'consistency_level':'user_enforced'}; +CREATE INDEX search_by_row_id ON devices.events + (device_id, row_id) + WITH CLUSTERING ORDER BY (row_id DESC) + AND default_time_to_live = 2768400 + AND transactions = {'enabled': 'false', 'consistency_level':'user_enforced'}; CREATE TABLE devices.blacklist (device_id varchar PRIMARY KEY, reason varchar); ``` diff --git a/cassandra/README.md b/cassandra/README.md new file mode 100644 index 0000000..6f951bb --- /dev/null +++ b/cassandra/README.md @@ -0,0 +1,19 @@ +# Cassandra DB driver +This implementation is geared towards yugabyte. + +# Migration from v0.4.0 to v0.5.0 +The addition of row_id as a TIMEUUID as a simplistic version of state hash. +Since row_id can be null, gungnir will work with both database schemas. +In order to do long polling, gungnir db driver will need to be updated. +Svalinn is not backwards compatible as the insert statement has changed to include the +TIMEUUID. + +The following is the migration script from v0.4.0 to v0.5.0 +```cassandraql +ALTER TABLE devices.events ADD row_id TIMEUUID; +CREATE INDEX search_by_row_id ON devices.events + (device_id, row_id) + WITH CLUSTERING ORDER BY (row_id DESC) + AND default_time_to_live = 2768400 + AND transactions = {'enabled': 'false', 'consistency_level':'user_enforced'}; +``` diff --git a/cassandra/db.go b/cassandra/db.go index 2266755..826e168 100644 --- a/cassandra/db.go +++ b/cassandra/db.go @@ -50,7 +50,7 @@ type Config struct { // Database aka Keyspace for cassandra Database string - //OpTimeout + // OpTimeout OpTimeout time.Duration // SSLRootCert used for enabling tls to the cluster. SSLKey, and SSLCert must also be set. @@ -145,8 +145,14 @@ func validateConfig(config *Config) { } // GetRecords returns a list of records for a given device. -func (c *Connection) GetRecords(deviceID string, limit int) ([]db.Record, error) { - deviceInfo, err := c.finder.findRecords(limit, "WHERE device_id=?", deviceID) +func (c *Connection) GetRecords(deviceID string, limit int, stateHash string) ([]db.Record, error) { + filterString := "WHERE device_id=?" + items := []interface{}{deviceID} + if stateHash != "" { + filterString = "WHERE device_id = ? AND row_id > ?" + items = []interface{}{deviceID, stateHash} + } + deviceInfo, err := c.finder.findRecords(limit, filterString, items...) if err != nil { c.measures.SQLQueryFailureCount.With(db.TypeLabel, db.ReadType).Add(1.0) return []db.Record{}, emperror.WrapWith(err, "Getting records from database failed", "device id", deviceID) @@ -157,8 +163,14 @@ func (c *Connection) GetRecords(deviceID string, limit int) ([]db.Record, error) } // GetRecords returns a list of records for a given device and event type. -func (c *Connection) GetRecordsOfType(deviceID string, limit int, eventType db.EventType) ([]db.Record, error) { - deviceInfo, err := c.finder.findRecords(limit, "WHERE device_id = ? AND record_type = ?", deviceID, eventType) +func (c *Connection) GetRecordsOfType(deviceID string, limit int, eventType db.EventType, stateHash string) ([]db.Record, error) { + filterString := "WHERE device_id = ? AND record_type = ?" + items := []interface{}{deviceID, eventType} + if stateHash != "" { + filterString = "WHERE device_id = ? AND record_type = ? AND row_id > ?" + items = []interface{}{deviceID, eventType, stateHash} + } + deviceInfo, err := c.finder.findRecords(limit, filterString, items) if err != nil { c.measures.SQLQueryFailureCount.With(db.TypeLabel, db.ReadType).Add(1.0) return []db.Record{}, emperror.WrapWith(err, "Getting records from database failed", "device id", deviceID) @@ -168,6 +180,30 @@ func (c *Connection) GetRecordsOfType(deviceID string, limit int, eventType db.E return deviceInfo, nil } +// GetStateHash returns a hash for the latest record added to the database. +func (c *Connection) GetStateHash(records []db.Record) (string, error) { + if len(records) == 0 { + return "", errors.New("record slice is empty") + } else if len(records) == 1 && records[0].RowID != "" { + return records[0].RowID, nil + } + original := gocql.UUIDFromTime(time.Time{}) + latest := original + for _, elem := range records { + uuid, err := gocql.ParseUUID(elem.RowID) + if err != nil { + continue + } + if uuid.Time().UnixNano() > latest.Time().UnixNano() { + latest = uuid + } + } + if latest == original { + return "", errors.New("no hash found") + } + return latest.String(), nil +} + // GetBlacklist returns a list of blacklisted devices. func (c *Connection) GetBlacklist() (list []blacklist.BlackListedItem, err error) { list, err = c.findList.findBlacklist() diff --git a/cassandra/db_test.go b/cassandra/db_test.go index 49d839a..5571110 100644 --- a/cassandra/db_test.go +++ b/cassandra/db_test.go @@ -89,7 +89,7 @@ func TestGetRecords(t *testing.T) { p.Assert(t, SQLQuerySuccessCounter)(xmetricstest.Value(0.0)) p.Assert(t, SQLQueryFailureCounter)(xmetricstest.Value(0.0)) - records, err := dbConnection.GetRecords(tc.deviceID, 5) + records, err := dbConnection.GetRecords(tc.deviceID, 5, "") mockObj.AssertExpectations(t) p.Assert(t, SQLQuerySuccessCounter, db.TypeLabel, db.ReadType)(xmetricstest.Value(tc.expectedSuccessMetric)) p.Assert(t, SQLQueryFailureCounter, db.TypeLabel, db.ReadType)(xmetricstest.Value(tc.expectedFailureMetric)) @@ -157,7 +157,145 @@ func TestGetRecordsOfType(t *testing.T) { p.Assert(t, SQLQueryFailureCounter)(xmetricstest.Value(0.0)) p.Assert(t, SQLReadRecordsCounter)(xmetricstest.Value(0.0)) - records, err := dbConnection.GetRecordsOfType(tc.deviceID, 5, tc.eventType) + records, err := dbConnection.GetRecordsOfType(tc.deviceID, 5, tc.eventType, "") + mockObj.AssertExpectations(t) + p.Assert(t, SQLQuerySuccessCounter, db.TypeLabel, db.ReadType)(xmetricstest.Value(tc.expectedSuccessMetric)) + p.Assert(t, SQLQueryFailureCounter, db.TypeLabel, db.ReadType)(xmetricstest.Value(tc.expectedFailureMetric)) + p.Assert(t, SQLReadRecordsCounter)(xmetricstest.Value(float64(len(tc.expectedRecords)))) + if tc.expectedErr == nil || err == nil { + assert.Equal(tc.expectedErr, err) + } else { + assert.Contains(err.Error(), tc.expectedErr.Error()) + } + assert.Equal(tc.expectedRecords, records) + }) + } +} + +func TestGetLatestHash(t *testing.T) { + tests := []struct { + description string + expectedHash string + records []db.Record + hasError bool + }{ + { + description: "empty list", + expectedHash: "", + records: []db.Record{}, + hasError: true, + }, + { + description: "one record", + expectedHash: "cb9629c8-3256-11ea-91fe-6b6aedd62e7b", + records: []db.Record{{RowID: "cb9629c8-3256-11ea-91fe-6b6aedd62e7b"}}, + hasError: false, + }, + { + description: "multiple record", + expectedHash: "cb962de2-3256-11ea-91fe-6b6aedd62e7b", + records: []db.Record{{RowID: "cb962ad6-3256-11ea-91fe-6b6aedd62e7b"}, {RowID: "cb9629c8-3256-11ea-91fe-6b6aedd62e7b"}, {RowID: "cb962de2-3256-11ea-91fe-6b6aedd62e7b"}}, + hasError: false, + }, + { + description: "multiple record with last empty", + expectedHash: "cb962de2-3256-11ea-91fe-6b6aedd62e7b", + records: []db.Record{{RowID: "cb962ad6-3256-11ea-91fe-6b6aedd62e7b"}, {RowID: "cb9629c8-3256-11ea-91fe-6b6aedd62e7b"}, {RowID: "cb962de2-3256-11ea-91fe-6b6aedd62e7b"}, {}}, + hasError: false, + }, + { + description: "multiple record with first empty", + expectedHash: "cb962de2-3256-11ea-91fe-6b6aedd62e7b", + records: []db.Record{{}, {RowID: "cb962ad6-3256-11ea-91fe-6b6aedd62e7b"}, {RowID: "cb9629c8-3256-11ea-91fe-6b6aedd62e7b"}, {RowID: "cb962de2-3256-11ea-91fe-6b6aedd62e7b"}, {}}, + hasError: false, + }, + { + description: "empty record record", + expectedHash: "", + records: []db.Record{{}}, + hasError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + mockObj := new(mockFinder) + p := xmetricstest.NewProvider(nil, Metrics) + m := NewMeasures(p) + dbConnection := Connection{ + measures: m, + finder: mockObj, + } + + hash, err := dbConnection.GetStateHash(tc.records) + if tc.hasError { + assert.Error(err) + } else { + assert.NoError(err) + } + assert.Equal(tc.expectedHash, hash) + mockObj.AssertExpectations(t) + + }) + } + +} +func TestGetRecordsAfter(t *testing.T) { + tests := []struct { + description string + deviceID string + hash string + expectedRecords []db.Record + expectedSuccessMetric float64 + expectedFailureMetric float64 + expectedErr error + expectedCalls int + }{ + { + description: "Success", + deviceID: "1234", + hash: "123", + expectedRecords: []db.Record{ + { + Type: 1, + DeviceID: "1234", + }, + }, + expectedSuccessMetric: 1.0, + expectedErr: nil, + expectedCalls: 1, + }, + { + description: "Get Error", + deviceID: "1234", + expectedRecords: []db.Record{}, + expectedFailureMetric: 1.0, + expectedErr: errors.New("test Get error"), + expectedCalls: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + mockObj := new(mockFinder) + p := xmetricstest.NewProvider(nil, Metrics) + m := NewMeasures(p) + dbConnection := Connection{ + measures: m, + finder: mockObj, + } + if tc.expectedCalls > 0 { + marshaledRecords, err := json.Marshal(tc.expectedRecords) + assert.Nil(err) + mockObj.On("findRecords", mock.Anything, mock.Anything, mock.Anything).Return(marshaledRecords, tc.expectedErr).Times(tc.expectedCalls) + } + p.Assert(t, SQLQuerySuccessCounter)(xmetricstest.Value(0.0)) + p.Assert(t, SQLQueryFailureCounter)(xmetricstest.Value(0.0)) + p.Assert(t, SQLReadRecordsCounter)(xmetricstest.Value(0.0)) + + records, err := dbConnection.GetRecords(tc.deviceID, 5, tc.hash) mockObj.AssertExpectations(t) p.Assert(t, SQLQuerySuccessCounter, db.TypeLabel, db.ReadType)(xmetricstest.Value(tc.expectedSuccessMetric)) p.Assert(t, SQLQueryFailureCounter, db.TypeLabel, db.ReadType)(xmetricstest.Value(tc.expectedFailureMetric)) diff --git a/cassandra/executer.go b/cassandra/executer.go index deb36fc..fabbb3d 100644 --- a/cassandra/executer.go +++ b/cassandra/executer.go @@ -66,11 +66,12 @@ func (b *dbDecorator) findRecords(limit int, filter string, where ...interface{} nonce []byte alg string kid string + rowid string ) - iter := b.session.Query(fmt.Sprintf("SELECT device_id, record_type, birthdate, deathdate, data, nonce, alg, kid FROM devices.events %s LIMIT ?", filter), append(where, limit)...).Iter() + iter := b.session.Query(fmt.Sprintf("SELECT device_id, record_type, birthdate, deathdate, data, nonce, alg, kid, row_id FROM devices.events %s LIMIT ?", filter), append(where, limit)...).Iter() - for iter.Scan(&device, &eventType, &birthdate, &deathdate, &data, &nonce, &alg, &kid) { + for iter.Scan(&device, &eventType, &birthdate, &deathdate, &data, &nonce, &alg, &kid, &rowid) { records = append(records, db.Record{ DeviceID: device, Type: db.EventType(eventType), @@ -80,6 +81,7 @@ func (b *dbDecorator) findRecords(limit int, filter string, where ...interface{} Nonce: nonce, Alg: alg, KID: kid, + RowID: rowid, }) // clear out vars https://github.com/gocql/gocql/issues/1348 device = "" @@ -90,6 +92,7 @@ func (b *dbDecorator) findRecords(limit int, filter string, where ...interface{} nonce = []byte{} alg = "" kid = "" + rowid = "" } err := iter.Close() @@ -142,7 +145,7 @@ func (b *dbDecorator) insert(records []db.Record) (int, error) { for _, record := range records { // there can be no spaces for some weird reason. Otherwise the database returns and error. - batch.Query("INSERT INTO devices.events (device_id, record_type, birthdate, deathdate, data, nonce, alg, kid) VALUES (?, ?, ?, ?, ?, ?, ?, ?);", + batch.Query("INSERT INTO devices.events (device_id, record_type, birthdate, deathdate, data, nonce, alg, kid, row_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, now());", record.DeviceID, record.Type, record.BirthDate, diff --git a/db.go b/db.go index 11f4d46..17b3b89 100644 --- a/db.go +++ b/db.go @@ -45,6 +45,7 @@ type Record struct { Nonce []byte `json:"nonce" bson:"nonce"` Alg string `json:"alg" bson:"alg"` KID string `json:"kid" bson:"kid" gorm:"Column:kid"` + RowID string `json:"rowid"` } // RecordToDelete is the information needed to get out of the database in order @@ -75,6 +76,7 @@ type Pruner interface { // RecordGetter is something that can get records, including only getting records of a // certain type. type RecordGetter interface { - GetRecords(deviceID string, limit int) ([]Record, error) - GetRecordsOfType(deviceID string, limit int, eventType EventType) ([]Record, error) + GetRecords(deviceID string, limit int, stateHash string) ([]Record, error) + GetRecordsOfType(deviceID string, limit int, eventType EventType, stateHash string) ([]Record, error) + GetStateHash(records []Record) (string, error) } diff --git a/postgresql/db.go b/postgresql/db.go index 083263a..7f0fb40 100644 --- a/postgresql/db.go +++ b/postgresql/db.go @@ -249,7 +249,7 @@ func (c *Connection) setupMetrics() { } // GetRecords returns a list of records for a given device. -func (c *Connection) GetRecords(deviceID string, limit int) ([]db.Record, error) { +func (c *Connection) GetRecords(deviceID string, limit int, stateHash string) ([]db.Record, error) { var ( deviceInfo []db.Record ) @@ -264,7 +264,7 @@ func (c *Connection) GetRecords(deviceID string, limit int) ([]db.Record, error) } // GetRecords returns a list of records for a given device and event type. -func (c *Connection) GetRecordsOfType(deviceID string, limit int, eventType db.EventType) ([]db.Record, error) { +func (c *Connection) GetRecordsOfType(deviceID string, limit int, eventType db.EventType, stateHash string) ([]db.Record, error) { var ( deviceInfo []db.Record ) @@ -278,6 +278,11 @@ func (c *Connection) GetRecordsOfType(deviceID string, limit int, eventType db.E return deviceInfo, nil } +// GetStateHash returns a hash for the latest record added to the database +func (c *Connection) GetStateHash(records []db.Record) (string, error) { + panic("not implemented") +} + // GetRecordsToDelete returns a list of record ids and deathdates not past a // given date. func (c *Connection) GetRecordsToDelete(shard int, limit int, deathDate int64) ([]db.RecordToDelete, error) { diff --git a/postgresql/db_test.go b/postgresql/db_test.go index 6a5cfd3..350b96a 100644 --- a/postgresql/db_test.go +++ b/postgresql/db_test.go @@ -90,7 +90,7 @@ func TestGetRecords(t *testing.T) { p.Assert(t, SQLQuerySuccessCounter)(xmetricstest.Value(0.0)) p.Assert(t, SQLQueryFailureCounter)(xmetricstest.Value(0.0)) - records, err := dbConnection.GetRecords(tc.deviceID, 5) + records, err := dbConnection.GetRecords(tc.deviceID, 5, "") mockObj.AssertExpectations(t) p.Assert(t, SQLQuerySuccessCounter, db.TypeLabel, db.ReadType)(xmetricstest.Value(tc.expectedSuccessMetric)) p.Assert(t, SQLQueryFailureCounter, db.TypeLabel, db.ReadType)(xmetricstest.Value(tc.expectedFailureMetric)) @@ -158,7 +158,7 @@ func TestGetRecordsOfType(t *testing.T) { p.Assert(t, SQLQueryFailureCounter)(xmetricstest.Value(0.0)) p.Assert(t, SQLReadRecordsCounter)(xmetricstest.Value(0.0)) - records, err := dbConnection.GetRecordsOfType(tc.deviceID, 5, tc.eventType) + records, err := dbConnection.GetRecordsOfType(tc.deviceID, 5, tc.eventType, "") mockObj.AssertExpectations(t) p.Assert(t, SQLQuerySuccessCounter, db.TypeLabel, db.ReadType)(xmetricstest.Value(tc.expectedSuccessMetric)) p.Assert(t, SQLQueryFailureCounter, db.TypeLabel, db.ReadType)(xmetricstest.Value(tc.expectedFailureMetric))