Skip to content

Commit

Permalink
Merge pull request #582 from eternal-flame-AD/lastping
Browse files Browse the repository at this point in the history
add last seen field to client (fixes #400)
  • Loading branch information
jmattheis committed Aug 6, 2023
2 parents a444182 + a3ce298 commit 8c0f7a9
Show file tree
Hide file tree
Showing 20 changed files with 252 additions and 33 deletions.
3 changes: 2 additions & 1 deletion api/application_test.go
Expand Up @@ -91,8 +91,9 @@ func (s *ApplicationSuite) Test_ensureApplicationHasCorrectJsonRepresentation()
Description: "mydesc",
Image: "asd",
Internal: true,
LastUsed: nil,
}
test.JSONEquals(s.T(), actual, `{"id":1,"token":"Aasdasfgeeg","name":"myapp","description":"mydesc", "image": "asd", "internal":true, "defaultPriority":0}`)
test.JSONEquals(s.T(), actual, `{"id":1,"token":"Aasdasfgeeg","name":"myapp","description":"mydesc", "image": "asd", "internal":true, "defaultPriority":0, "lastUsed":null}`)
}

func (s *ApplicationSuite) Test_CreateApplication_expectBadRequestOnEmptyName() {
Expand Down
2 changes: 1 addition & 1 deletion api/client_test.go
Expand Up @@ -58,7 +58,7 @@ func (s *ClientSuite) AfterTest(suiteName, testName string) {

func (s *ClientSuite) Test_ensureClientHasCorrectJsonRepresentation() {
actual := &model.Client{ID: 1, UserID: 2, Token: "Casdasfgeeg", Name: "myclient"}
test.JSONEquals(s.T(), actual, `{"id":1,"token":"Casdasfgeeg","name":"myclient"}`)
test.JSONEquals(s.T(), actual, `{"id":1,"token":"Casdasfgeeg","name":"myclient","lastUsed":null}`)
}

func (s *ClientSuite) Test_CreateClient_mapAllParameters() {
Expand Down
25 changes: 25 additions & 0 deletions api/stream/stream.go
Expand Up @@ -37,6 +37,19 @@ func New(pingPeriod, pongTimeout time.Duration, allowedWebSocketOrigins []string
}
}

// CollectConnectedClientTokens returns all tokens of the connected clients.
func (a *API) CollectConnectedClientTokens() []string {
a.lock.RLock()
defer a.lock.RUnlock()
var clients []string
for _, cs := range a.clients {
for _, c := range cs {
clients = append(clients, c.token)
}
}
return uniq(clients)
}

// NotifyDeletedUser closes existing connections for the given user.
func (a *API) NotifyDeletedUser(userID uint) error {
a.lock.Lock()
Expand Down Expand Up @@ -155,6 +168,18 @@ func (a *API) Close() {
}
}

func uniq[T comparable](s []T) []T {
m := make(map[T]struct{})
for _, v := range s {
m[v] = struct{}{}
}
var r []T
for k := range m {
r = append(r, k)
}
return r
}

func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool {
origin := r.Header.Get("origin")
if origin == "" {
Expand Down
92 changes: 75 additions & 17 deletions api/stream/stream_test.go
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"sort"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -56,8 +57,8 @@ func TestWriteMessageFails(t *testing.T) {
wsURL := wsURL(server.URL)
user := testClient(t, wsURL)

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, 1)

clients := clients(api, 1)
assert.NotEmpty(t, clients)

Expand Down Expand Up @@ -86,13 +87,13 @@ func TestWritePingFails(t *testing.T) {
user := testClient(t, wsURL)
defer user.conn.Close()

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, 1)

clients := clients(api, 1)

assert.NotEmpty(t, clients)

time.Sleep(api.pingPeriod) // waiting for ping
time.Sleep(api.pingPeriod + (50 * time.Millisecond)) // waiting for ping

api.Notify(1, &model.MessageExternal{Message: "HI"})
user.expectNoMessage()
Expand Down Expand Up @@ -147,8 +148,8 @@ func TestCloseClientOnNotReading(t *testing.T) {
assert.Nil(t, err)
defer ws.Close()

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, 1)

assert.NotEmpty(t, clients(api, 1))

time.Sleep(api.pingPeriod + api.pongTimeout)
Expand All @@ -167,8 +168,9 @@ func TestMessageDirectlyAfterConnect(t *testing.T) {

user := testClient(t, wsURL)
defer user.conn.Close()
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)

waitForConnectedClients(api, 1)

api.Notify(1, &model.MessageExternal{Message: "msg"})
user.expectMessage(&model.MessageExternal{Message: "msg"})
}
Expand All @@ -184,8 +186,9 @@ func TestDeleteClientShouldCloseConnection(t *testing.T) {

user := testClient(t, wsURL)
defer user.conn.Close()
// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)

waitForConnectedClients(api, 1)

api.Notify(1, &model.MessageExternal{Message: "msg"})
user.expectMessage(&model.MessageExternal{Message: "msg"})

Expand Down Expand Up @@ -230,8 +233,7 @@ func TestDeleteMultipleClients(t *testing.T) {
defer userThreeAndroid.conn.Close()
userThree := []*testingClient{userThreeAndroid}

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, len(userOne)+len(userTwo)+len(userThree))

api.Notify(1, &model.MessageExternal{ID: 4, Message: "there"})
expectMessage(&model.MessageExternal{ID: 4, Message: "there"}, userOne...)
Expand Down Expand Up @@ -294,8 +296,7 @@ func TestDeleteUser(t *testing.T) {
defer userThreeAndroid.conn.Close()
userThree := []*testingClient{userThreeAndroid}

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, len(userOne)+len(userTwo)+len(userThree))

api.Notify(1, &model.MessageExternal{ID: 4, Message: "there"})
expectMessage(&model.MessageExternal{ID: 4, Message: "there"}, userOne...)
Expand All @@ -322,6 +323,43 @@ func TestDeleteUser(t *testing.T) {
api.Close()
}

func TestCollectConnectedClientTokens(t *testing.T) {
mode.Set(mode.TestDev)

defer leaktest.Check(t)()
userIDs := []uint{1, 1, 1, 2, 2}
tokens := []string{"1-1", "1-2", "1-2", "2-1", "2-2"}
i := 0
server, api := bootTestServer(func(context *gin.Context) {
auth.RegisterAuthentication(context, nil, userIDs[i], tokens[i])
i++
})
defer server.Close()

wsURL := wsURL(server.URL)
userOneConnOne := testClient(t, wsURL)
defer userOneConnOne.conn.Close()
userOneConnTwo := testClient(t, wsURL)
defer userOneConnTwo.conn.Close()
userOneConnThree := testClient(t, wsURL)
defer userOneConnThree.conn.Close()
waitForConnectedClients(api, 3)

ret := api.CollectConnectedClientTokens()
sort.Strings(ret)
assert.Equal(t, []string{"1-1", "1-2"}, ret)

userTwoConnOne := testClient(t, wsURL)
defer userTwoConnOne.conn.Close()
userTwoConnTwo := testClient(t, wsURL)
defer userTwoConnTwo.conn.Close()
waitForConnectedClients(api, 5)

ret = api.CollectConnectedClientTokens()
sort.Strings(ret)
assert.Equal(t, []string{"1-1", "1-2", "2-1", "2-2"}, ret)
}

func TestMultipleClients(t *testing.T) {
mode.Set(mode.TestDev)

Expand Down Expand Up @@ -354,8 +392,7 @@ func TestMultipleClients(t *testing.T) {
defer userThreeAndroid.conn.Close()
userThree := []*testingClient{userThreeAndroid}

// the server may take some time to register the client
time.Sleep(100 * time.Millisecond)
waitForConnectedClients(api, len(userOne)+len(userTwo)+len(userThree))

// there should not be messages at the beginning
expectNoMessage(userOne...)
Expand Down Expand Up @@ -474,6 +511,17 @@ func clients(api *API, user uint) []*client {
return api.clients[user]
}

func countClients(a *API) int {
a.lock.RLock()
defer a.lock.RUnlock()

var i int
for _, clients := range a.clients {
i += len(clients)
}
return i
}

func testClient(t *testing.T, url string) *testingClient {
client := createClient(t, url)
startReading(client)
Expand Down Expand Up @@ -560,3 +608,13 @@ func staticUserID() gin.HandlerFunc {
auth.RegisterAuthentication(context, nil, 1, "customtoken")
}
}

func waitForConnectedClients(api *API, count int) {
for i := 0; i < 10; i++ {
if countClients(api) == count {
// ok
return
}
time.Sleep(10 * time.Millisecond)
}
}
27 changes: 21 additions & 6 deletions auth/authentication.go
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"errors"
"strings"
"time"

"github.com/gin-gonic/gin"
"github.com/gotify/server/v2/auth/password"
Expand All @@ -20,6 +21,8 @@ type Database interface {
GetPluginConfByToken(token string) (*model.PluginConf, error)
GetUserByName(name string) (*model.User, error)
GetUserByID(id uint) (*model.User, error)
UpdateClientTokensLastUsed(tokens []string, t *time.Time) error
UpdateApplicationTokenLastUsed(token string, t *time.Time) error
}

// Auth is the provider for authentication middleware.
Expand Down Expand Up @@ -56,10 +59,16 @@ func (a *Auth) RequireClient() gin.HandlerFunc {
if user != nil {
return true, true, user.ID, nil
}
if token, err := a.DB.GetClientByToken(tokenID); err != nil {
if client, err := a.DB.GetClientByToken(tokenID); err != nil {
return false, false, 0, err
} else if token != nil {
return true, true, token.UserID, nil
} else if client != nil {
now := time.Now()
if client.LastUsed == nil || client.LastUsed.Add(5*time.Minute).Before(now) {
if err := a.DB.UpdateClientTokensLastUsed([]string{tokenID}, &now); err != nil {
return false, false, 0, err
}
}
return true, true, client.UserID, nil
}
return false, false, 0, nil
})
Expand All @@ -71,10 +80,16 @@ func (a *Auth) RequireApplicationToken() gin.HandlerFunc {
if user != nil {
return true, false, 0, nil
}
if token, err := a.DB.GetApplicationByToken(tokenID); err != nil {
if app, err := a.DB.GetApplicationByToken(tokenID); err != nil {
return false, false, 0, err
} else if token != nil {
return true, true, token.UserID, nil
} else if app != nil {
now := time.Now()
if app.LastUsed == nil || app.LastUsed.Add(5*time.Minute).Before(now) {
if err := a.DB.UpdateApplicationTokenLastUsed(tokenID, &now); err != nil {
return false, false, 0, err
}
}
return true, true, app.UserID, nil
}
return false, false, 0, nil
})
Expand Down
7 changes: 7 additions & 0 deletions database/application.go
@@ -1,6 +1,8 @@
package database

import (
"time"

"github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm"
)
Expand Down Expand Up @@ -56,3 +58,8 @@ func (d *GormDatabase) GetApplicationsByUser(userID uint) ([]*model.Application,
func (d *GormDatabase) UpdateApplication(app *model.Application) error {
return d.DB.Save(app).Error
}

// UpdateApplicationTokenLastUsed updates the last used time of the application token.
func (d *GormDatabase) UpdateApplicationTokenLastUsed(token string, t *time.Time) error {
return d.DB.Model(&model.Application{}).Where("token = ?", token).Update("last_used", t).Error
}
10 changes: 10 additions & 0 deletions database/application_test.go
@@ -1,6 +1,8 @@
package database

import (
"time"

"github.com/gotify/server/v2/model"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -40,6 +42,14 @@ func (s *DatabaseSuite) TestApplication() {
assert.Equal(s.T(), app, newApp)
}

lastUsed := time.Now().Add(-time.Hour)
s.db.UpdateApplicationTokenLastUsed(app.Token, &lastUsed)
newApp, err = s.db.GetApplicationByID(app.ID)
if assert.NoError(s.T(), err) {
assert.Equal(s.T(), lastUsed.Unix(), newApp.LastUsed.Unix())
}
app.LastUsed = &lastUsed

newApp.Image = "asdasd"
assert.NoError(s.T(), s.db.UpdateApplication(newApp))

Expand Down
7 changes: 7 additions & 0 deletions database/client.go
@@ -1,6 +1,8 @@
package database

import (
"time"

"github.com/gotify/server/v2/model"
"github.com/jinzhu/gorm"
)
Expand Down Expand Up @@ -55,3 +57,8 @@ func (d *GormDatabase) DeleteClientByID(id uint) error {
func (d *GormDatabase) UpdateClient(client *model.Client) error {
return d.DB.Save(client).Error
}

// UpdateClientTokensLastUsed updates the last used timestamp of clients.
func (d *GormDatabase) UpdateClientTokensLastUsed(tokens []string, t *time.Time) error {
return d.DB.Model(&model.Client{}).Where("token IN (?)", tokens).Update("last_used", t).Error
}
9 changes: 9 additions & 0 deletions database/client_test.go
@@ -1,6 +1,8 @@
package database

import (
"time"

"github.com/gotify/server/v2/model"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -44,6 +46,13 @@ func (s *DatabaseSuite) TestClient() {
assert.Equal(s.T(), updateClient, updatedClient)
}

lastUsed := time.Now().Add(-time.Hour)
s.db.UpdateClientTokensLastUsed([]string{client.Token}, &lastUsed)
newClient, err = s.db.GetClientByID(client.ID)
if assert.NoError(s.T(), err) {
assert.Equal(s.T(), lastUsed.Unix(), newClient.LastUsed.Unix())
}

s.db.DeleteClientByID(client.ID)

if clients, err := s.db.GetClientsByUser(user.ID); assert.NoError(s.T(), err) {
Expand Down

0 comments on commit 8c0f7a9

Please sign in to comment.