Skip to content

Commit

Permalink
check state param in oidc flow [#523]
Browse files Browse the repository at this point in the history
  • Loading branch information
roberlander2 committed Jan 24, 2023
1 parent 7970dac commit 7b33c9e
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 62 deletions.
6 changes: 3 additions & 3 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@
"filename": "core/auth/auth_type_oauth2.go",
"hashed_secret": "f8dabcd142517fb16418a3a2790f3da31392537a",
"is_verified": false,
"line_number": 122
"line_number": 96
}
],
"core/auth/auth_type_oidc.go": [
Expand All @@ -202,7 +202,7 @@
"filename": "core/auth/auth_type_oidc.go",
"hashed_secret": "f8dabcd142517fb16418a3a2790f3da31392537a",
"is_verified": false,
"line_number": 123
"line_number": 115
}
],
"driven/emailer/adapter.go": [
Expand Down Expand Up @@ -288,5 +288,5 @@
}
]
},
"generated_at": "2023-01-20T22:03:24Z"
"generated_at": "2023-01-24T00:06:16Z"
}
7 changes: 1 addition & 6 deletions core/auth/auth_type_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,7 @@ func (a *oauthAuthImpl) externalLogin(authType model.AuthType, appType model.App
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, typeAuthConfig, logutils.StringArgs(a.authType), err)
}

code, err := config.GetAuthorizationCode(creds, params)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, "authorization code", logutils.StringArgs(a.authType), err)
}

return a.loadTokensAndInfo(config, authType, appOrg, code, params, false, l)
return a.loadTokensAndInfo(config, authType, appOrg, creds, params, false, l)
}

func (a *oauthAuthImpl) refresh(params map[string]interface{}, authType model.AuthType, appType model.ApplicationType, appOrg model.ApplicationOrganization, l *logs.Log) (*model.ExternalSystemUser, map[string]interface{}, error) {
Expand Down
59 changes: 19 additions & 40 deletions core/auth/auth_type_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,32 +83,6 @@ func (o *oauth2AuthConfig) GetUserInfoURL() string {
return url
}

func (o *oauth2AuthConfig) GetAuthorizationCode(creds string, params string) (string, error) {
var loginParams oauth2LoginParams
if o.UseState {
err := json.Unmarshal([]byte(params), &loginParams)
if err != nil {
return "", errors.WrapErrorAction(logutils.ActionUnmarshal, typeOAuth2LoginParams, nil, err)
}
validate := validator.New()
err = validate.Struct(loginParams)
if err != nil {
return "", errors.WrapErrorAction(logutils.ActionValidate, typeOAuth2LoginParams, nil, err)
}
}

parsedCreds, err := utils.QueryValuesFromURL(creds)
if err != nil {
return "", errors.WrapErrorAction(logutils.ActionParse, "oauth2 creds", nil, err)
}
//state in creds must match state generated for login url (if used)
if o.UseState && loginParams.State != parsedCreds.Get("state") {
return "", errors.ErrorData(logutils.StatusInvalid, "oauth2 login", &logutils.FieldArgs{"state": parsedCreds.Get("state")})
}

return parsedCreds.Get("code"), nil
}

func (o *oauth2AuthConfig) BuildNewTokenRequest(creds string, params string, refresh bool) (*oauthprovider.OAuthRequest, map[string]interface{}, error) {
if refresh && !o.UseRefresh {
return nil, nil, nil
Expand All @@ -125,7 +99,24 @@ func (o *oauth2AuthConfig) BuildNewTokenRequest(creds string, params string, ref
body["refresh_token"] = creds
body["grant_type"] = "refresh_token"
} else {
body["code"] = creds
parsedCreds, err := utils.QueryValuesFromURL(creds)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionParse, "oauth2 creds", nil, err)
}

if o.UseState {
var loginParams oauth2LoginParams
err = json.Unmarshal([]byte(params), &loginParams)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionUnmarshal, typeOAuth2LoginParams, nil, err)
}
//state in creds must match state generated for login url (if used)
if loginParams.State != parsedCreds.Get("state") {
return nil, nil, errors.ErrorData(logutils.StatusInvalid, "oauth2 login", &logutils.FieldArgs{"state": parsedCreds.Get("state")})
}
}

body["code"] = parsedCreds.Get("code")
}

encoded := utils.EncodeQueryValues(body)
Expand Down Expand Up @@ -172,7 +163,7 @@ func (o *oauth2AuthConfig) BuildLoginURLResponse(redirectURI string) (string, ma

responseParams := make(map[string]interface{})
if o.UseState {
state, err := o.generateState()
state, err := utils.GenerateRandomString(50)
if err != nil {
return "", nil, errors.WrapErrorAction("generating", "random state", nil, err)
}
Expand All @@ -183,18 +174,6 @@ func (o *oauth2AuthConfig) BuildLoginURLResponse(redirectURI string) (string, ma
return o.GetAuthorizeURL() + "?" + utils.EncodeQueryValues(query), responseParams, nil
}

// --- Helper functions ---

// generateState generates and returns a randomized state string
func (o *oauth2AuthConfig) generateState() (string, error) {
state, err := utils.GenerateRandomString(50)
if err != nil {
return "", errors.WrapErrorAction("generating", "state string", nil, err)
}

return state, nil
}

type oauth2Token struct {
AccessToken string `json:"access_token" validate:"required"`
RefreshToken string `json:"refresh_token"`
Expand Down
30 changes: 20 additions & 10 deletions core/auth/auth_type_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type oidcAuthConfig struct {
Scopes string `json:"scopes"`
UseRefresh bool `json:"use_refresh"`
UsePKCE bool `json:"use_pkce"`
UseState bool `json:"use_state"`
ClientID string `json:"client_id" validate:"required"`
ClientSecret string `json:"client_secret"`
AuthorizeClaims string `json:"authorize_claims"`
Expand Down Expand Up @@ -90,15 +91,6 @@ func (o *oidcAuthConfig) GetUserInfoURL() string {
return url
}

func (o *oidcAuthConfig) GetAuthorizationCode(creds string, params string) (string, error) {
parsedCreds, err := utils.QueryValuesFromURL(creds)
if err != nil {
return "", errors.WrapErrorAction(logutils.ActionParse, "oidc creds", nil, err)
}

return parsedCreds.Get("code"), nil
}

func (o *oidcAuthConfig) BuildNewTokenRequest(creds string, params string, refresh bool) (*oauthprovider.OAuthRequest, map[string]interface{}, error) {
if refresh && !o.UseRefresh {
return nil, nil, nil
Expand Down Expand Up @@ -126,7 +118,16 @@ func (o *oidcAuthConfig) BuildNewTokenRequest(creds string, params string, refre
body["refresh_token"] = creds
body["grant_type"] = "refresh_token"
} else {
body["code"] = creds
parsedCreds, err := utils.QueryValuesFromURL(creds)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionParse, "oidc creds", nil, err)
}
//state in creds must match state generated for login url (if used)
if o.UseState && loginParams.State != parsedCreds.Get("state") {
return nil, nil, errors.ErrorData(logutils.StatusInvalid, "oidc login", &logutils.FieldArgs{"state": parsedCreds.Get("state")})
}

body["code"] = parsedCreds.Get("code")
body["grant_type"] = "authorization_code"

if len(loginParams.CodeVerifier) > 0 {
Expand Down Expand Up @@ -224,6 +225,14 @@ func (o *oidcAuthConfig) BuildLoginURLResponse(redirectURI string) (string, map[
query["code_challenge"] = codeChallenge
responseParams["pkce_verifier"] = codeVerifier
}
if o.UseState {
state, err := utils.GenerateRandomString(50)
if err != nil {
return "", nil, errors.WrapErrorAction("generating", "random state", nil, err)
}
query["state"] = state
responseParams["state"] = state
}

return o.GetAuthorizeURL() + "?" + utils.EncodeQueryValues(query), responseParams, nil
}
Expand Down Expand Up @@ -273,6 +282,7 @@ func (t *oidcToken) GetIDToken() string {

type oidcLoginParams struct {
CodeVerifier string `json:"pkce_verifier"`
State string `json:"state"`
RedirectURI string `json:"redirect_uri"`
}

Expand Down
3 changes: 0 additions & 3 deletions driven/oauthprovider/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,10 @@ type OAuthConfig interface {
GetTokenURL() string
GetUserInfoURL() string

GetAuthorizationCode(creds string, params string) (string, error)
BuildNewTokenRequest(creds string, params string, refresh bool) (*OAuthRequest, map[string]interface{}, error)
ParseTokenResponse(response []byte, params map[string]interface{}) (OAuthToken, map[string]interface{}, error)
BuildLoginURLResponse(redirectURI string) (string, map[string]interface{}, error)

// GetResponseParams(params string) (map[string]string, error)

CheckIDToken(token OAuthToken) (string, error)
CheckSubject(tokenSubject string, userSubject string) bool
}
Expand Down

0 comments on commit 7b33c9e

Please sign in to comment.