Skip to content

Commit

Permalink
continue merging oauth2 and oidc functions [#523]
Browse files Browse the repository at this point in the history
  • Loading branch information
roberlander2 committed Aug 24, 2022
1 parent 4b0f706 commit 70acd64
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 357 deletions.
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@
{
"type": "Secret Keyword",
"filename": "core/auth/auth_type_oidc.go",
"hashed_secret": "0ade4f3edccc8888bef404fe6b3c92c13cdfad6b",
"hashed_secret": "f8dabcd142517fb16418a3a2790f3da31392537a",
"is_verified": false,
"line_number": 361
"line_number": 122
}
],
"driven/emailer/adapter.go": [
Expand Down
2 changes: 1 addition & 1 deletion core/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2062,7 +2062,7 @@ func (a *Auth) validateAuthTypeForAppOrg(authenticationType string, appID string
}

func (a *Auth) isValidAdminAuthType(authenticationType string) bool {
return authenticationType == AuthTypeOidc || authenticationType == AuthTypeOAuth2 || authenticationType == AuthTypeEmail || strings.HasSuffix(authenticationType, "_oidc") || strings.HasSuffix(authenticationType, "_oauth2")
return authenticationType == AuthTypeEmail || strings.HasSuffix("_"+authenticationType, "_oidc") || strings.HasSuffix("_"+authenticationType, "_oauth2")
}

func (a *Auth) queryValuesFromURL(urlStr string) (url.Values, error) {
Expand Down
124 changes: 30 additions & 94 deletions core/auth/auth_type_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (o *oauth2AuthConfig) getAuthorizationCode(auth *Auth, creds string, params
return parsedCreds.Get("code"), nil
}

func (o *oauth2AuthConfig) buildNewTokenRequest(auth *Auth, token string, refresh bool) (*http.Request, error) {
func (o *oauth2AuthConfig) buildNewTokenRequest(auth *Auth, creds string, params string, refresh bool) (*http.Request, error) {
body := map[string]string{
"client_id": o.ClientID,
"redirect_uri": o.RedirectURI,
Expand All @@ -130,10 +130,10 @@ func (o *oauth2AuthConfig) buildNewTokenRequest(auth *Auth, token string, refres
body["client_secret"] = o.ClientSecret
}
if refresh {
body["refresh_token"] = token
body["refresh_token"] = creds
body["grant_type"] = "refresh_token"
} else {
body["code"] = token
body["code"] = creds
}

encoded := auth.encodeQueryValues(body)
Expand Down Expand Up @@ -207,22 +207,22 @@ func (t *oauth2Token) getResponse() map[string]interface{} {
return params
}

type oauth2RefreshParams struct {
RefreshToken string `json:"refresh_token" bson:"refresh_token" validate:"required"`
func (t *oauth2Token) getIDToken() string {
return ""
}

func (a *oauth2AuthImpl) externalLogin(authType model.AuthType, appType model.ApplicationType, appOrg model.ApplicationOrganization, creds string, params string, l *logs.Log) (*model.ExternalSystemUser, map[string]interface{}, error) {
oauth2Config, err := a.getOAuth2AuthConfig(authType, appType)
config, err := a.auth.getOAuthConfig(authType, appType)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, typeOAuth2AuthConfig, nil, err)
}

code, err := oauth2Config.getAuthorizationCode(a.auth, creds, params)
code, err := config.getAuthorizationCode(a.auth, creds, params)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, "authorization code", nil, err)
}

externalUser, parameters, err := a.loadOAuth2TokensAndInfo(oauth2Config, authType, appType, appOrg, code, false, l)
externalUser, parameters, err := a.loadOAuth2TokensAndInfo(config, authType, appOrg, code, params, false, l)
if err != nil {
return nil, nil, err
}
Expand All @@ -231,44 +231,44 @@ func (a *oauth2AuthImpl) externalLogin(authType model.AuthType, appType model.Ap
}

func (a *oauth2AuthImpl) refresh(params map[string]interface{}, authType model.AuthType, appType model.ApplicationType, appOrg model.ApplicationOrganization, l *logs.Log) (*model.ExternalSystemUser, map[string]interface{}, error) {
oauth2Config, err := a.getOAuth2AuthConfig(authType, appType)
config, err := a.auth.getOAuthConfig(authType, appType)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, typeOAuth2AuthConfig, nil, err)
}
if !oauth2Config.UseRefresh {
return nil, nil, errors.Newf("oauth2 refresh tokens not enabled for org_id=%s, app_id=%s", appOrg.Organization.ID, appOrg.Application.ID)
}
// if !config.UseRefresh {
// return nil, nil, errors.Newf("oauth2 refresh tokens not enabled for org_id=%s, app_id=%s", appOrg.Organization.ID, appOrg.Application.ID)
// }

refreshParams, err := refreshParamsFromMap(params, AuthTypeOAuth2)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionParse, typeAuthRefreshParams, nil, err)
}

return a.loadOAuth2TokensAndInfo(oauth2Config, authType, appType, appOrg, refreshParams.RefreshToken, true, l)
return a.loadOAuth2TokensAndInfo(config, authType, appOrg, refreshParams.RefreshToken, "", true, l)
}

func (a *oauth2AuthImpl) getLoginURL(authType model.AuthType, appType model.ApplicationType, l *logs.Log) (string, map[string]interface{}, error) {
oauth2Config, err := a.getOAuth2AuthConfig(authType, appType)
config, err := a.auth.getOAuthConfig(authType, appType)
if err != nil {
return "", nil, errors.WrapErrorAction(logutils.ActionGet, typeOAuth2AuthConfig, nil, err)
}

return oauth2Config.buildLoginURLResponse(a.auth)
return config.buildLoginURLResponse(a.auth)
}

func (a *oauth2AuthImpl) loadOAuth2TokensAndInfo(oauth2Config *oauth2AuthConfig, authType model.AuthType, appType model.ApplicationType,
appOrg model.ApplicationOrganization, token string, refresh bool, l *logs.Log) (*model.ExternalSystemUser, map[string]interface{}, error) {
newToken, err := a.loadOAuth2TokenWithParams(oauth2Config, token, refresh)
func (a *oauth2AuthImpl) loadOAuth2TokensAndInfo(config oauthConfig, authType model.AuthType, appOrg model.ApplicationOrganization, creds string,
params string, refresh bool, l *logs.Log) (*model.ExternalSystemUser, map[string]interface{}, error) {
newToken, err := a.loadOAuth2TokenWithParams(config, creds, params, refresh)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, typeOAuth2Token, nil, err)
}

sub, err := oauth2Config.checkIDToken(newToken)
sub, err := config.checkIDToken(newToken)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionValidate, typeOidcToken, nil, err)
return nil, nil, errors.WrapErrorAction(logutils.ActionValidate, typeOAuth2Token, nil, err)
}

userInfo, err := a.loadOAuth2UserInfo(oauth2Config, newToken)
userInfo, err := a.loadOAuth2UserInfo(config, newToken)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, "user info", nil, err)
}
Expand All @@ -280,51 +280,20 @@ func (a *oauth2AuthImpl) loadOAuth2TokensAndInfo(oauth2Config *oauth2AuthConfig,
}

userClaimsSub, _ := userClaims["sub"].(string)
if !oauth2Config.checkSubject(sub, userClaimsSub) {
if !config.checkSubject(sub, userClaimsSub) {
return nil, nil, errors.Newf("mismatching user info sub %s and id token sub %s", userClaimsSub, sub)
}

identityProviderID, _ := authType.Params["identity_provider"].(string)
identityProviderSetting := appOrg.FindIdentityProviderSetting(identityProviderID)
if identityProviderSetting == nil {
return nil, nil, errors.ErrorData(logutils.StatusMissing, model.TypeIdentityProviderConfig, &logutils.FieldArgs{"app_org": appOrg.ID, "identity_provider_id": identityProviderID})
}

//identifier
identifier, _ := userClaims[identityProviderSetting.UserIdentifierField].(string)
//name
name, _ := userClaims[identityProviderSetting.NameField].(string)
names := strings.Split(name, " ")
//email
email, _ := userClaims[identityProviderSetting.EmailField].(string)
//system specific
systemSpecific := map[string]interface{}{}
userSpecificFields := identityProviderSetting.UserSpecificFields
if len(userSpecificFields) > 0 {
for _, field := range userSpecificFields {
fieldValue, _ := userClaims[field].(string)
systemSpecific[field] = fieldValue
}
}
//external ids
externalIDs := make(map[string]string)
for k, v := range identityProviderSetting.ExternalIDFields {
externalID, ok := userClaims[v].(string)
if !ok {
l.ErrorWithDetails("failed to parse external id", logutils.Fields{k: userClaims[v]})
continue
}
externalIDs[k] = externalID
externalUser, err := a.auth.getExternalUser(userClaims, authType, appOrg, l)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, model.TypeExternalSystemUser, nil, err)
}

externalUser := model.ExternalSystemUser{Identifier: identifier, ExternalIDs: externalIDs, FirstName: names[0],
LastName: names[len(names)-1], Email: email, SystemSpecific: systemSpecific}

return &externalUser, newToken.getResponse(), nil
return externalUser, newToken.getResponse(), nil
}

func (a *oauth2AuthImpl) loadOAuth2TokenWithParams(oauth2Config *oauth2AuthConfig, token string, refresh bool) (*oauth2Token, error) {
req, err := oauth2Config.buildNewTokenRequest(a.auth, token, refresh)
func (a *oauth2AuthImpl) loadOAuth2TokenWithParams(config oauthConfig, creds string, params string, refresh bool) (oauthToken, error) {
req, err := config.buildNewTokenRequest(a.auth, creds, params, refresh)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionCreate, "oauth2 token request", nil, err)
}
Expand Down Expand Up @@ -358,9 +327,9 @@ func (a *oauth2AuthImpl) loadOAuth2TokenWithParams(oauth2Config *oauth2AuthConfi
return &authToken, nil
}

func (a *oauth2AuthImpl) loadOAuth2UserInfo(oauth2Config *oauth2AuthConfig, token *oauth2Token) ([]byte, error) {
func (a *oauth2AuthImpl) loadOAuth2UserInfo(config oauthConfig, token oauthToken) ([]byte, error) {
client := &http.Client{}
req, err := http.NewRequest(http.MethodGet, oauth2Config.getUserInfoURL(), nil)
req, err := http.NewRequest(http.MethodGet, config.getUserInfoURL(), nil)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionCreate, logutils.TypeRequest, nil, err)
}
Expand All @@ -386,39 +355,6 @@ func (a *oauth2AuthImpl) loadOAuth2UserInfo(oauth2Config *oauth2AuthConfig, toke
return body, nil
}

func (a *oauth2AuthImpl) getOAuth2AuthConfig(authType model.AuthType, appType model.ApplicationType) (*oauth2AuthConfig, error) {
errFields := &logutils.FieldArgs{"auth_type_id": authType.ID, "app_type_id": appType}

identityProviderID, ok := authType.Params["identity_provider"].(string)
if !ok {
return nil, errors.ErrorData(logutils.StatusInvalid, "identity provider", errFields)
}
appTypeID := appType.ID
authConfig, err := a.auth.getCachedIdentityProviderConfig(identityProviderID, appTypeID)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionFind, model.TypeIdentityProviderConfig, errFields, err)
}

configBytes, err := json.Marshal(authConfig.Config)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionMarshal, model.TypeIdentityProviderConfig, errFields, err)
}

var oauth2Config oauth2AuthConfig
err = json.Unmarshal(configBytes, &oauth2Config)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionUnmarshal, model.TypeIdentityProviderConfig, errFields, err)
}

validate := validator.New()
err = validate.Struct(oauth2Config)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionValidate, model.TypeIdentityProviderConfig, errFields, err)
}

return &oauth2Config, nil
}

// --- Helper functions ---

// generateState generates and returns a randomized state string
Expand Down

0 comments on commit 70acd64

Please sign in to comment.