Skip to content

Commit

Permalink
callout: try to renew jwt when expire
Browse files Browse the repository at this point in the history
  • Loading branch information
ramonberrutti committed Nov 23, 2023
1 parent e8772d5 commit 28b2105
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 31 deletions.
2 changes: 1 addition & 1 deletion server/auth_callout.go
Expand Up @@ -289,7 +289,7 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options) (authorize
}

// Check if we need to set an auth timer if the user jwt expires.
c.setExpiration(arc.Claims(), expiration)
c.setRenewal(arc.Claims(), expiration)

respCh <- _EMPTY_
}
Expand Down
82 changes: 55 additions & 27 deletions server/auth_callout_test.go
Expand Up @@ -207,18 +207,24 @@ func TestAuthCalloutBasics(t *testing.T) {
}
`
callouts := uint32(0)
done := false
waitForCallout := make(chan struct{})
handler := func(m *nats.Msg) {
atomic.AddUint32(&callouts, 1)
calls := atomic.AddUint32(&callouts, 1)
user, si, ci, opts, _ := decodeAuthRequest(t, m.Data)
require_True(t, si.Name == "A")
require_True(t, ci.Host == "127.0.0.1")
// Allow dlc user.
if opts.Username == "dlc" && opts.Password == "zzz" {
if !done && opts.Username == "dlc" && opts.Password == "zzz" {
var j jwt.UserPermissionLimits
j.Pub.Allow.Add("$SYS.>")
if calls == 3 {
j.Pub.Allow.Add("ramon.>")
}
j.Payload = 1024
ujwt := createAuthUser(t, user, _EMPTY_, globalAccountName, "", nil, 10*time.Minute, &j)
ujwt := createAuthUser(t, user, _EMPTY_, globalAccountName, "", nil, 10*time.Second, &j)
m.Respond(serviceResponse(t, user, si.ID, ujwt, "", 0))
waitForCallout <- struct{}{}
} else {
// Nil response signals no authentication.
m.Respond(nil)
Expand All @@ -233,33 +239,55 @@ func TestAuthCalloutBasics(t *testing.T) {
// This one will use callout since not defined in server config.
nc := at.Connect(nats.UserInfo("dlc", "zzz"))

resp, err := nc.Request(userDirectInfoSubj, nil, time.Second)
require_NoError(t, err)
response := ServerAPIResponse{Data: &UserInfo{}}
err = json.Unmarshal(resp.Data, &response)
require_NoError(t, err)

userInfo := response.Data.(*UserInfo)
compareUserInfo := func(perm ...string) {
time.Sleep(100 * time.Millisecond)
resp, err := nc.Request(userDirectInfoSubj, nil, time.Second)
require_NoError(t, err)
response := ServerAPIResponse{Data: &UserInfo{}}
err = json.Unmarshal(resp.Data, &response)
require_NoError(t, err)

dlc := &UserInfo{
UserID: "dlc",
Account: globalAccountName,
Permissions: &Permissions{
Publish: &SubjectPermission{
Allow: []string{"$SYS.>"},
Deny: []string{AuthCalloutSubject}, // Will be auto-added since in auth account.
userInfo := response.Data.(*UserInfo)

dlc := &UserInfo{
UserID: "dlc",
Account: globalAccountName,
Permissions: &Permissions{
Publish: &SubjectPermission{
Allow: append([]string{"$SYS.>"}, perm...),
Deny: []string{AuthCalloutSubject}, // Will be auto-added since in auth account.
},
Subscribe: &SubjectPermission{},
},
Subscribe: &SubjectPermission{},
},
}
expires := userInfo.Expires
userInfo.Expires = 0
if !reflect.DeepEqual(dlc, userInfo) {
t.Fatalf("User info for %q did not match", "dlc")
}
if expires > 10*time.Minute || expires < (10*time.Minute-5*time.Second) {
t.Fatalf("Expected expires of ~%v, got %v", 10*time.Minute, expires)
}
expires := userInfo.Expires
userInfo.Expires = 0
if !reflect.DeepEqual(dlc, userInfo) {
dlcJson, _ := json.MarshalIndent(dlc, "", " ")
userInfoJson, _ := json.MarshalIndent(userInfo, "", " ")
t.Fatalf("User info for %q did not match %s %s", "dlc", dlcJson, userInfoJson)
}
if expires > 10*time.Second || expires < (10*time.Second-5*time.Second) {
t.Fatalf("Expected expires of ~%v, got %v", 10*time.Second, expires)
}
}

<-waitForCallout
compareUserInfo()

// Wait for a second valid callout with a new permission.
<-waitForCallout
compareUserInfo("ramon.>")

done = true
disconnected := make(chan struct{})
nc.SetErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) {
if err != nats.ErrAuthExpired {
t.Fatalf("Expected %v, got %v", nats.ErrAuthExpired, err)
}
close(disconnected)
})
<-disconnected
}

func TestAuthCalloutMultiAccounts(t *testing.T) {
Expand Down
69 changes: 66 additions & 3 deletions server/client.go
Expand Up @@ -254,6 +254,7 @@ type client struct {
darray []string
pcd map[*client]struct{}
atmr *time.Timer
rtmr *time.Timer // renew timer.
expires time.Time
ping pinfo
msgb [msgScratchSize]byte
Expand Down Expand Up @@ -1144,9 +1145,17 @@ func (c *client) mergeDenyPermissionsLocked(what denyType, denyPubs []string) {
// Check to see if we have an expiration for the user JWT via base claims.
// FIXME(dlc) - Clear on connect with new JWT.
func (c *client) setExpiration(claims *jwt.ClaimsData, validFor time.Duration) {
c.setTimer(claims, validFor, c.setExpirationTimer)
}

func (c *client) setRenewal(claims *jwt.ClaimsData, validFor time.Duration) {
c.setTimer(claims, validFor, c.setRenewalTimer)
}

func (c *client) setTimer(claims *jwt.ClaimsData, validFor time.Duration, f func(time.Duration)) {
if claims.Expires == 0 {
if validFor != 0 {
c.setExpirationTimer(validFor)
f(validFor)
}
return
}
Expand All @@ -1156,9 +1165,9 @@ func (c *client) setExpiration(claims *jwt.ClaimsData, validFor time.Duration) {
expiresAt = time.Duration(claims.Expires-tn) * time.Second
}
if validFor != 0 && validFor < expiresAt {
c.setExpirationTimer(validFor)
f(validFor)
} else {
c.setExpirationTimer(expiresAt)
f(expiresAt)
}
}

Expand Down Expand Up @@ -4872,6 +4881,10 @@ func (c *client) clearTlsToTimer() {

// Lock should be held
func (c *client) setAuthTimer(d time.Duration) {
if c.atmr != nil {
c.atmr.Stop()
}

c.atmr = time.AfterFunc(d, c.authTimeout)
}

Expand All @@ -4885,6 +4898,16 @@ func (c *client) clearAuthTimer() bool {
return stopped
}

// Lock should be held
func (c *client) clearRenewTimer() bool {
if c.rtmr == nil {
return true
}
stopped := c.rtmr.Stop()
c.rtmr = nil
return stopped
}

// We may reuse atmr for expiring user jwts,
// so check connectReceived.
// Lock assume held on entry.
Expand All @@ -4902,13 +4925,52 @@ func (c *client) setExpirationTimer(d time.Duration) {

// This will set the atmr for the JWT expiration time. client lock should be held before call
func (c *client) setExpirationTimerUnlocked(d time.Duration) {
// Stop any previous timer
if c.atmr != nil {
c.atmr.Stop()
}
c.atmr = time.AfterFunc(d, c.authExpired)
// This is an JWT expiration.
if c.flags.isSet(connectReceived) {
c.expires = time.Now().Add(d).Truncate(time.Second)
}
}

func (c *client) setRenewalTimer(d time.Duration) {
c.mu.Lock()
c.setRenewalTimerUnlocked(d)
c.mu.Unlock()
}

func (c *client) setRenewalTimerUnlocked(d time.Duration) {
// Stop any previous timer
if c.rtmr != nil {
c.rtmr.Stop()
}
c.rtmr = time.AfterFunc(d, func() {
c.mu.Lock()
srv := c.srv
c.mu.Unlock()
if srv == nil {
return
}

authorized, _ := srv.processClientOrLeafCallout(c, srv.getOpts())
// If we are authorized, we will set the renewal timer again.
// Deny the Callout subject.
if authorized {
c.mergeDenyPermissionsLocked(pub, []string{AuthCalloutSubject})
} else {
// If we are not authorized, we will close the connection in the expiration handler.
c.authExpired()
}
})
// This is an JWT expiration.
if c.flags.isSet(connectReceived) {
c.expires = time.Now().Add(d).Truncate(time.Second)
}
}

// Return when this client expires via a claim, or 0 if not set.
func (c *client) claimExpiration() time.Duration {
c.mu.Lock()
Expand Down Expand Up @@ -5075,6 +5137,7 @@ func (c *client) closeConnection(reason ClosedState) {
c.rref++
c.flags.set(closeConnection)
c.clearAuthTimer()
c.clearRenewTimer()
c.clearPingTimer()
c.clearTlsToTimer()
c.markConnAsClosed(reason)
Expand Down

0 comments on commit 28b2105

Please sign in to comment.