Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

callout: try to renew jwt when expire #4814

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/auth_callout.go
Expand Up @@ -289,7 +289,7 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options) (authorize
}

// Check if we need to set an auth timer if the user jwt expires.
c.setExpiration(arc.Claims(), expiration)
c.setRenewal(arc.Claims(), expiration)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have a buffer imo. So if expiration is say 30mins, we should ask for a renewal at some point before 30mins in case auth callout service is slow to respond. So some percentage, say 10-15% (so for 30mins would be 3mins, so timer would be 27m not 30m). We should also set a max offset for smaller ttls.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to retry for slow responds?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not believe so, at least to start. Just have the buffer so we are not cutting things too close.


respCh <- _EMPTY_
}
Expand Down
80 changes: 53 additions & 27 deletions server/auth_callout_test.go
Expand Up @@ -207,18 +207,23 @@ func TestAuthCalloutBasics(t *testing.T) {
}
`
callouts := uint32(0)
waitForCallout := make(chan struct{})
handler := func(m *nats.Msg) {
atomic.AddUint32(&callouts, 1)
calls := atomic.AddUint32(&callouts, 1)
user, si, ci, opts, _ := decodeAuthRequest(t, m.Data)
require_True(t, si.Name == "A")
require_True(t, ci.Host == "127.0.0.1")
// Allow dlc user.
if opts.Username == "dlc" && opts.Password == "zzz" {
if calls <= 3 && opts.Username == "dlc" && opts.Password == "zzz" {
var j jwt.UserPermissionLimits
j.Pub.Allow.Add("$SYS.>")
if calls == 3 {
j.Pub.Allow.Add("ramon.>")
}
j.Payload = 1024
ujwt := createAuthUser(t, user, _EMPTY_, globalAccountName, "", nil, 10*time.Minute, &j)
ujwt := createAuthUser(t, user, _EMPTY_, globalAccountName, "", nil, 10*time.Second, &j)
m.Respond(serviceResponse(t, user, si.ID, ujwt, "", 0))
waitForCallout <- struct{}{}
} else {
// Nil response signals no authentication.
m.Respond(nil)
Expand All @@ -233,33 +238,54 @@ func TestAuthCalloutBasics(t *testing.T) {
// This one will use callout since not defined in server config.
nc := at.Connect(nats.UserInfo("dlc", "zzz"))

resp, err := nc.Request(userDirectInfoSubj, nil, time.Second)
require_NoError(t, err)
response := ServerAPIResponse{Data: &UserInfo{}}
err = json.Unmarshal(resp.Data, &response)
require_NoError(t, err)

userInfo := response.Data.(*UserInfo)
compareUserInfo := func(perm ...string) {
time.Sleep(100 * time.Millisecond)
resp, err := nc.Request(userDirectInfoSubj, nil, time.Second)
require_NoError(t, err)
response := ServerAPIResponse{Data: &UserInfo{}}
err = json.Unmarshal(resp.Data, &response)
require_NoError(t, err)

dlc := &UserInfo{
UserID: "dlc",
Account: globalAccountName,
Permissions: &Permissions{
Publish: &SubjectPermission{
Allow: []string{"$SYS.>"},
Deny: []string{AuthCalloutSubject}, // Will be auto-added since in auth account.
userInfo := response.Data.(*UserInfo)

dlc := &UserInfo{
UserID: "dlc",
Account: globalAccountName,
Permissions: &Permissions{
Publish: &SubjectPermission{
Allow: append([]string{"$SYS.>"}, perm...),
Deny: []string{AuthCalloutSubject}, // Will be auto-added since in auth account.
},
Subscribe: &SubjectPermission{},
},
Subscribe: &SubjectPermission{},
},
}
expires := userInfo.Expires
userInfo.Expires = 0
if !reflect.DeepEqual(dlc, userInfo) {
t.Fatalf("User info for %q did not match", "dlc")
}
if expires > 10*time.Minute || expires < (10*time.Minute-5*time.Second) {
t.Fatalf("Expected expires of ~%v, got %v", 10*time.Minute, expires)
}
expires := userInfo.Expires
userInfo.Expires = 0
if !reflect.DeepEqual(dlc, userInfo) {
dlcJson, _ := json.MarshalIndent(dlc, "", " ")
userInfoJson, _ := json.MarshalIndent(userInfo, "", " ")
t.Fatalf("User info for %q did not match %s %s", "dlc", dlcJson, userInfoJson)
}
if expires > 10*time.Second || expires < (10*time.Second-5*time.Second) {
t.Fatalf("Expected expires of ~%v, got %v", 10*time.Second, expires)
}
}

<-waitForCallout
compareUserInfo()

// Wait for a second valid callout with a new permission.
<-waitForCallout
compareUserInfo("ramon.>")

disconnected := make(chan struct{})
nc.SetErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) {
if err != nats.ErrAuthExpired {
t.Fatalf("Expected %v, got %v", nats.ErrAuthExpired, err)
}
close(disconnected)
})
<-disconnected
}

func TestAuthCalloutMultiAccounts(t *testing.T) {
Expand Down
71 changes: 68 additions & 3 deletions server/client.go
Expand Up @@ -254,6 +254,7 @@ type client struct {
darray []string
pcd map[*client]struct{}
atmr *time.Timer
rtmr *time.Timer // renew timer.
expires time.Time
ping pinfo
msgb [msgScratchSize]byte
Expand Down Expand Up @@ -1144,9 +1145,17 @@ func (c *client) mergeDenyPermissionsLocked(what denyType, denyPubs []string) {
// Check to see if we have an expiration for the user JWT via base claims.
// FIXME(dlc) - Clear on connect with new JWT.
func (c *client) setExpiration(claims *jwt.ClaimsData, validFor time.Duration) {
c.setTimer(claims, validFor, c.setExpirationTimer)
}

func (c *client) setRenewal(claims *jwt.ClaimsData, validFor time.Duration) {
c.setTimer(claims, validFor, c.setRenewalTimer)
}

func (c *client) setTimer(claims *jwt.ClaimsData, validFor time.Duration, f func(time.Duration)) {
if claims.Expires == 0 {
if validFor != 0 {
c.setExpirationTimer(validFor)
f(validFor)
}
return
}
Expand All @@ -1156,9 +1165,9 @@ func (c *client) setExpiration(claims *jwt.ClaimsData, validFor time.Duration) {
expiresAt = time.Duration(claims.Expires-tn) * time.Second
}
if validFor != 0 && validFor < expiresAt {
c.setExpirationTimer(validFor)
f(validFor)
} else {
c.setExpirationTimer(expiresAt)
f(expiresAt)
}
}

Expand Down Expand Up @@ -4872,6 +4881,10 @@ func (c *client) clearTlsToTimer() {

// Lock should be held
func (c *client) setAuthTimer(d time.Duration) {
if c.atmr != nil {
c.atmr.Stop()
}

c.atmr = time.AfterFunc(d, c.authTimeout)
}

Expand All @@ -4885,6 +4898,16 @@ func (c *client) clearAuthTimer() bool {
return stopped
}

// Lock should be held
func (c *client) clearRenewTimer() bool {
if c.rtmr == nil {
return true
}
stopped := c.rtmr.Stop()
c.rtmr = nil
return stopped
}

// We may reuse atmr for expiring user jwts,
// so check connectReceived.
// Lock assume held on entry.
Expand All @@ -4902,13 +4925,54 @@ func (c *client) setExpirationTimer(d time.Duration) {

// This will set the atmr for the JWT expiration time. client lock should be held before call
func (c *client) setExpirationTimerUnlocked(d time.Duration) {
// Stop any previous timer
if c.atmr != nil {
c.atmr.Stop()
}
c.atmr = time.AfterFunc(d, c.authExpired)
// This is an JWT expiration.
if c.flags.isSet(connectReceived) {
c.expires = time.Now().Add(d).Truncate(time.Second)
}
}

func (c *client) setRenewalTimer(d time.Duration) {
c.mu.Lock()
c.setRenewalTimerUnlocked(d)
c.mu.Unlock()
}

func (c *client) setRenewalTimerUnlocked(d time.Duration) {
// Stop any previous timer
if c.rtmr != nil {
c.rtmr.Stop()
}
c.rtmr = time.AfterFunc(d, c.renewCallout)
// This is an JWT expiration.
if c.flags.isSet(connectReceived) {
c.expires = time.Now().Add(d).Truncate(time.Second)
}
}

func (c *client) renewCallout() {
c.mu.Lock()
srv := c.srv
c.mu.Unlock()
if srv == nil {
return
}

authorized, _ := srv.processClientOrLeafCallout(c, srv.getOpts())
// If we are authorized, we will set the renewal timer again.
// Deny the Callout subject.
if authorized {
c.mergeDenyPermissionsLocked(pub, []string{AuthCalloutSubject})
} else {
// If we are not authorized, we will close the connection in the expiration handler.
c.authExpired()
}
}

// Return when this client expires via a claim, or 0 if not set.
func (c *client) claimExpiration() time.Duration {
c.mu.Lock()
Expand Down Expand Up @@ -5075,6 +5139,7 @@ func (c *client) closeConnection(reason ClosedState) {
c.rref++
c.flags.set(closeConnection)
c.clearAuthTimer()
c.clearRenewTimer()
c.clearPingTimer()
c.clearTlsToTimer()
c.markConnAsClosed(reason)
Expand Down