From d2d0245d5867084244b12a62b522e0df74fdd85a Mon Sep 17 00:00:00 2001 From: Denis Arh Date: Wed, 23 Feb 2022 18:54:34 +0100 Subject: [PATCH] Password reset rate limiting & existing token invalidation --- system/service/auth.go | 201 ++++++++++++++++++++++------- system/service/auth_actions.gen.go | 36 ++++++ system/service/auth_actions.yaml | 4 + system/service/auth_test.go | 111 +++++++++++++++- 4 files changed, 301 insertions(+), 51 deletions(-) diff --git a/system/service/auth.go b/system/service/auth.go index 5d04136d2d..0ce55e7b5c 100644 --- a/system/service/auth.go +++ b/system/service/auth.go @@ -64,10 +64,17 @@ const ( passwordMinLength = 8 passwordMaxLength = 256 + + tokenReqMaxCount = 5 + tokenReqMaxWindow = time.Minute * 15 ) var ( reEmail = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$") + + oneTokenPerUser = map[string]bool{ + credentialsTypeResetPasswordToken: true, + } ) func defaultProviderValidator(provider string) error { @@ -1007,67 +1014,55 @@ func (svc auth) procLogin(ctx context.Context, s store.Storer, u *types.User, c return nil } -func (svc auth) loadUserFromToken(ctx context.Context, token, kind string) (u *types.User, err error) { +// Loads user from token and removes that token right after +func (svc auth) loadUserFromToken(ctx context.Context, token, kind string) (u *types.User, _ error) { var ( aam = &authActionProps{ credentials: &types.Credentials{Kind: kind}, } ) - credentialsID, credentials := svc.validateToken(token) - if credentialsID == 0 { - return nil, AuthErrInvalidToken(aam) - } - - c, err := store.LookupCredentialsByID(ctx, svc.store, credentialsID) - if errors.IsNotFound(err) { - return nil, AuthErrInvalidToken(aam) - } - - aam.setCredentials(c) - - if err != nil { - return - } + return u, svc.store.Tx(ctx, func(ctx context.Context, s store.Storer) (err error) { + credentialsID, credentials := validateToken(token) + if credentialsID == 0 { + return AuthErrInvalidToken(aam) + } - if err = store.DeleteCredentialsByID(ctx, svc.store, c.ID); err != nil { - return - } + c, err := store.LookupCredentialsByID(ctx, s, credentialsID) + if errors.IsNotFound(err) { + return AuthErrInvalidToken(aam) + } - if !c.Valid() || c.Credentials != credentials { - return nil, AuthErrInvalidToken(aam) - } + aam.setCredentials(c) - u, err = store.LookupUserByID(ctx, svc.store, c.OwnerID) - if err != nil { - return nil, err - } + if err != nil { + return + } - aam.setUser(u) + if err = store.DeleteCredentialsByID(ctx, s, c.ID); err != nil { + return + } - // context will be updated with new identity - // in the caller fn + if !c.Valid() || c.Credentials != credentials { + return AuthErrInvalidToken(aam) + } - if !u.Valid() { - return nil, AuthErrInvalidCredentials(aam) - } + u, err = store.LookupUserByID(ctx, s, c.OwnerID) + if err != nil { + return err + } - return u, nil -} + aam.setUser(u) -func (svc auth) validateToken(token string) (ID uint64, credentials string) { - // Token = <32 random chars> - if len(token) <= credentialsTokenLength { - return - } + // context will be updated with new identity + // in the caller fn - ID, _ = strconv.ParseUint(token[credentialsTokenLength:], 10, 64) - if ID == 0 { - return - } + if !u.Valid() { + return AuthErrInvalidCredentials(aam) + } - credentials = token[:credentialsTokenLength] - return + return nil + }) } // Generates & stores user token @@ -1081,7 +1076,35 @@ func (svc auth) createUserToken(ctx context.Context, u *types.User, kind string) } ) - err = func() error { + err = svc.store.Tx(ctx, func(ctx context.Context, s store.Storer) (err error) { + if u == nil || u.ID == 0 { + return AuthErrGeneric() + } + + // Rate limit requests + cc, _, err := store.SearchCredentials(ctx, s, types.CredentialsFilter{ + OwnerID: u.ID, + Kind: kind, + + // we want to count deleted tokens as well + Deleted: filter.StateInclusive, + }) + + if err != nil { + return err + } + + // gt/eq since this current request is not yet stored + if err = svc.checkTokenRate(cc, tokenReqMaxWindow, tokenReqMaxCount); err != nil { + return + } + + // removes expired and soft-deleted tokens + // and enforces one-token-per-user rule + if err = svc.cleanupCredentials(ctx, s, cc); err != nil { + return + } + switch kind { case credentialsTypeMFAEmailOTP: expSec := svc.settings.Auth.MultiFactor.EmailOTP.Expires @@ -1118,7 +1141,7 @@ func (svc auth) createUserToken(ctx context.Context, u *types.User, kind string) ExpiresAt: &expiresAt, } - err = store.CreateCredentials(ctx, svc.store, c) + err = store.CreateCredentials(ctx, s, c) if err != nil { return err @@ -1134,11 +1157,78 @@ func (svc auth) createUserToken(ctx context.Context, u *types.User, kind string) } return nil - }() + }) return token, svc.recordAction(ctx, aam, AuthActionIssueToken, err) } +// checks existing tokens and ensure that the creation rate is within limits +func (svc auth) checkTokenRate(cc types.CredentialsSet, window time.Duration, max int) error { + if len(cc) == 0 || window == 0 || max == 0 { + return nil + } + + var ( + cutoff = now().Add(window * -1) + count = 0 + ) + + for _, c := range cc { + if c.CreatedAt.Before(cutoff) { + // skip tokens created before cutoff + continue + } + + count++ + + if count > max { + break + } + } + + if count > max { + return AuthErrRateLimitExceeded() + } + + return nil +} + +func (svc auth) cleanupCredentials(ctx context.Context, s store.Credentials, cc types.CredentialsSet) (err error) { + var ( + update types.CredentialsSet + remove types.CredentialsSet + ) + + for _, c := range cc { + switch { + case oneTokenPerUser[c.Kind]: + // if token type is shortlisted in one-token-per-user + // mark all existing tokens as deleted if to + // + // only want to mark them as deleted ad + c.DeletedAt = now() + update = append(update, c) + + case false, // just a placeholder + c.DeletedAt.Add(tokenReqMaxWindow).Before(*now()), + c.ExpiresAt.Before(*now()): + // schedule all soft-deleted and expired token + // for removal + remove = append(remove, c) + } + } + + if err = store.UpdateCredentials(ctx, s, update...); err != nil { + return + } + + if err = store.DeleteCredentials(ctx, s, remove...); err != nil { + return + } + + return +} + // Automatically promotes user to super-administrator if it is the first non-system user in the database func (svc auth) autoPromote(ctx context.Context, u *types.User) (err error) { var ( @@ -1539,3 +1629,18 @@ func (svc auth) RemoveAccessTokens(ctx context.Context, user *types.User) error svc.store.DeleteAuthOA2TokenByUserID(ctx, user.ID), ) } + +func validateToken(token string) (ID uint64, credentials string) { + // Token = <32 random chars> + if len(token) <= credentialsTokenLength { + return + } + + ID, _ = strconv.ParseUint(token[credentialsTokenLength:], 10, 64) + if ID == 0 { + return + } + + credentials = token[:credentialsTokenLength] + return +} diff --git a/system/service/auth_actions.gen.go b/system/service/auth_actions.gen.go index 1dff16feb5..c2d4ba4246 100644 --- a/system/service/auth_actions.gen.go +++ b/system/service/auth_actions.gen.go @@ -1748,6 +1748,42 @@ func AuthErrInvalidEmailOTP(mm ...*authActionProps) *errors.Error { return e } +// AuthErrRateLimitExceeded returns "system:auth.rateLimitExceeded" as *errors.Error +// +// +// This function is auto-generated. +// +func AuthErrRateLimitExceeded(mm ...*authActionProps) *errors.Error { + var p = &authActionProps{} + if len(mm) > 0 { + p = mm[0] + } + + var e = errors.New( + errors.KindInternal, + + p.Format("rate limit exceeded", nil), + + errors.Meta("type", "rateLimitExceeded"), + errors.Meta("resource", "system:auth"), + + // action log entry; no formatting, it will be applied inside recordAction fn. + errors.Meta(authLogMetaKey{}, "rate limit exceeded for {{user}}"), + errors.Meta(authPropsMetaKey{}, p), + + // translation namespace & key + errors.Meta(locale.ErrorMetaNamespace{}, "system"), + errors.Meta(locale.ErrorMetaKey{}, "auth.errors.rateLimitExceeded"), + + errors.StackSkip(1), + ) + + if len(mm) > 0 { + } + + return e +} + // AuthErrMaxUserLimitReached returns "system:auth.maxUserLimitReached" as *errors.Error // // diff --git a/system/service/auth_actions.yaml b/system/service/auth_actions.yaml index 991e75fe69..4cd378a0cf 100644 --- a/system/service/auth_actions.yaml +++ b/system/service/auth_actions.yaml @@ -203,6 +203,10 @@ errors: message: "invalid code" severity: warning + - error: rateLimitExceeded + message: "rate limit exceeded" + log: "rate limit exceeded for {{user}}" + - error: maxUserLimitReached message: "you have reached your user limit, contact your Corteza administrator" severity: warning diff --git a/system/service/auth_test.go b/system/service/auth_test.go index a8dfd1fa6f..3ef5973c67 100644 --- a/system/service/auth_test.go +++ b/system/service/auth_test.go @@ -287,6 +287,112 @@ func TestAuth_InternalLogin(t *testing.T) { } } +func TestAuth_createUserToken(t *testing.T) { + var ( + req = require.New(t) + ctx = context.Background() + + validUser = &types.User{Email: "valid@test.cortezaproject.org", ID: nextID(), CreatedAt: *now(), EmailConfirmed: true} + + tests = []struct { + name string + user *types.User + kind string + err error + }{ + { + "no user", + nil, + "", + AuthErrGeneric(), + }, + { + "zero ID", + &types.User{}, + "", + AuthErrGeneric(), + }, + { + "valid user", + validUser, + credentialsTypeResetPasswordToken, + nil, + }, + } + ) + + svc := makeMockAuthService() + req.NoError(svc.store.TruncateUsers(ctx)) + req.NoError(svc.store.TruncateCredentials(ctx)) + req.NoError(store.CreateUser(ctx, svc.store, validUser)) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req = require.New(t) + + var ( + token, err = svc.createUserToken(ctx, tt.user, tt.kind) + ) + + if tt.err == nil { + req.NoError(err) + req.NotEmpty(token) + } else { + req.EqualError(err, tt.err.Error()) + } + + }) + } +} + +// ensure that existing password reset tokens are invalidated AND rate limiting kicks in +func TestAuth_multiCreateUserTokenForPasswordReset(t *testing.T) { + var ( + err error + pToken, token string + + svc = makeMockAuthService() + + req = require.New(t) + ctx = context.Background() + + validUser = &types.User{Email: "valid@test.cortezaproject.org", ID: nextID(), CreatedAt: *now(), EmailConfirmed: true} + + // load credentials from token + t2c = func(token string) *types.Credentials { + id, _ := validateToken(token) + req.NotZero(id) + c, err := store.LookupCredentialsByID(ctx, svc.store, id) + req.NoError(err) + return c + } + ) + + req.NoError(svc.store.TruncateUsers(ctx)) + req.NoError(svc.store.TruncateCredentials(ctx)) + req.NoError(store.CreateUser(ctx, svc.store, validUser)) + + for try := 0; try <= tokenReqMaxCount+1; try++ { + token, err = svc.createUserToken(ctx, validUser, credentialsTypeResetPasswordToken) + t.Log("got token", token) + + if try == tokenReqMaxCount+1 { + t.Log("rate limiting should kicked in") + req.EqualError(err, AuthErrRateLimitExceeded().Error()) + } else { + if try > 0 { + t.Log("checking if previous token", pToken, "is deleted") + req.NotNil(t2c(pToken).DeletedAt) + } + + req.NoError(err) + req.Nil(t2c(token).DeletedAt) + pToken = token + } + } + +} + func Test_auth_checkPassword(t *testing.T) { plainPassword := " ... plain password ... " hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(plainPassword), bcrypt.DefaultCost) @@ -347,7 +453,7 @@ func Test_auth_checkPassword(t *testing.T) { } } -func Test_auth_validateToken(t *testing.T) { +func TestValidateToken(t *testing.T) { type args struct { token string } @@ -380,8 +486,7 @@ func Test_auth_validateToken(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - svc := auth{} - gotID, gotCredentials := svc.validateToken(tt.args.token) + gotID, gotCredentials := validateToken(tt.args.token) if gotID != tt.wantID { t.Errorf("auth.validateToken() gotID = %v, want %v", gotID, tt.wantID)