Skip to content

Commit

Permalink
refactor oauth to set up for testing externalAuthType interface [#523]
Browse files Browse the repository at this point in the history
  • Loading branch information
roberlander2 committed Aug 29, 2022
1 parent 9091faf commit 0bf5082
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 213 deletions.
16 changes: 8 additions & 8 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
"filename": "core/auth/auth.go",
"hashed_secret": "4d55af37dbbb6a42088d917caa1ca25428ec42c9",
"is_verified": false,
"line_number": 2429
"line_number": 2392
}
],
"core/auth/auth_type_email.go": [
Expand Down Expand Up @@ -193,7 +193,7 @@
"filename": "core/auth/auth_type_oauth2.go",
"hashed_secret": "f8dabcd142517fb16418a3a2790f3da31392537a",
"is_verified": false,
"line_number": 122
"line_number": 128
}
],
"core/auth/auth_type_oidc.go": [
Expand All @@ -202,7 +202,7 @@
"filename": "core/auth/auth_type_oidc.go",
"hashed_secret": "f8dabcd142517fb16418a3a2790f3da31392537a",
"is_verified": false,
"line_number": 113
"line_number": 119
}
],
"driven/emailer/adapter.go": [
Expand Down Expand Up @@ -263,30 +263,30 @@
"filename": "main.go",
"hashed_secret": "7ee294e78b33e56f1f32f05e03847461cb0008df",
"is_verified": false,
"line_number": 89
"line_number": 90
},
{
"type": "Secret Keyword",
"filename": "main.go",
"hashed_secret": "d22d26f750c4f64214b43ec38aaf6e3456407f0a",
"is_verified": false,
"line_number": 141
"line_number": 144
},
{
"type": "Secret Keyword",
"filename": "main.go",
"hashed_secret": "4f2da4710cdfd5c9f655ceb78bcb5ed03c73ddfd",
"is_verified": false,
"line_number": 153
"line_number": 156
},
{
"type": "Secret Keyword",
"filename": "main.go",
"hashed_secret": "24fe4e436894f6911a8234dbf81ad77417eac012",
"is_verified": false,
"line_number": 155
"line_number": 158
}
]
},
"generated_at": "2022-08-24T21:08:13Z"
"generated_at": "2022-08-29T05:09:11Z"
}
53 changes: 8 additions & 45 deletions core/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import (
"crypto/rsa"
"encoding/json"
"fmt"
"net/url"
"strconv"
"strings"
"sync"
"time"
Expand All @@ -37,7 +35,6 @@ import (
"github.com/rokwire/core-auth-library-go/v2/tokenauth"
"golang.org/x/sync/syncmap"
"gopkg.in/go-playground/validator.v9"
"gopkg.in/gomail.v2"

"github.com/rokwire/logging-library-go/errors"
"github.com/rokwire/logging-library-go/logs"
Expand Down Expand Up @@ -77,8 +74,9 @@ const (

// Auth represents the auth functionality unit
type Auth struct {
storage Storage
emailer Emailer
storage Storage
emailer Emailer
oauthProvider OAuthProvider

logger *logs.Logger

Expand All @@ -100,9 +98,6 @@ type Auth struct {

profileBB ProfileBuildingBlock

emailFrom string
emailDialer *gomail.Dialer

cachedIdentityProviders *syncmap.Map //cache identityProviders
identityProvidersLock *sync.RWMutex

Expand All @@ -115,7 +110,7 @@ type Auth struct {
}

// NewAuth creates a new auth instance
func NewAuth(serviceID string, host string, authPrivKey *rsa.PrivateKey, storage Storage, emailer Emailer, minTokenExp *int64, maxTokenExp *int64, twilioAccountSID string,
func NewAuth(serviceID string, host string, authPrivKey *rsa.PrivateKey, storage Storage, emailer Emailer, oauthProvider OAuthProvider, minTokenExp *int64, maxTokenExp *int64, twilioAccountSID string,
twilioToken string, twilioServiceSID string, profileBB *profilebb.Adapter, smtpHost string, smtpPortNum int, smtpUser string, smtpPassword string, smtpFrom string, logger *logs.Logger) (*Auth, error) {
if minTokenExp == nil {
var minTokenExpVal int64 = 5
Expand All @@ -126,8 +121,6 @@ func NewAuth(serviceID string, host string, authPrivKey *rsa.PrivateKey, storage
var maxTokenExpVal int64 = 60
maxTokenExp = &maxTokenExpVal
}
//maybe set up from config collection for diff types of auth
emailDialer := gomail.NewDialer(smtpHost, smtpPortNum, smtpUser, smtpPassword)

authTypes := map[string]authType{}
externalAuthTypes := map[string]externalAuthType{}
Expand All @@ -143,10 +136,10 @@ func NewAuth(serviceID string, host string, authPrivKey *rsa.PrivateKey, storage

timerDone := make(chan bool)

auth := &Auth{storage: storage, emailer: emailer, logger: logger, authTypes: authTypes, externalAuthTypes: externalAuthTypes, anonymousAuthTypes: anonymousAuthTypes,
serviceAuthTypes: serviceAuthTypes, mfaTypes: mfaTypes, authPrivKey: authPrivKey, ServiceRegManager: nil, serviceID: serviceID, host: host, minTokenExp: *minTokenExp,
maxTokenExp: *maxTokenExp, profileBB: profileBB, cachedIdentityProviders: cachedIdentityProviders, identityProvidersLock: identityProvidersLock,
timerDone: timerDone, emailDialer: emailDialer, emailFrom: smtpFrom, apiKeys: apiKeys, apiKeysLock: apiKeysLock}
auth := &Auth{storage: storage, emailer: emailer, oauthProvider: oauthProvider, logger: logger, authTypes: authTypes, externalAuthTypes: externalAuthTypes,
anonymousAuthTypes: anonymousAuthTypes, serviceAuthTypes: serviceAuthTypes, mfaTypes: mfaTypes, authPrivKey: authPrivKey, ServiceRegManager: nil, serviceID: serviceID, host: host,
minTokenExp: *minTokenExp, maxTokenExp: *maxTokenExp, profileBB: profileBB, cachedIdentityProviders: cachedIdentityProviders, identityProvidersLock: identityProvidersLock,
timerDone: timerDone, apiKeys: apiKeys, apiKeysLock: apiKeysLock}

err := auth.storeReg()
if err != nil {
Expand Down Expand Up @@ -2065,36 +2058,6 @@ func (a *Auth) isValidAdminAuthType(authenticationType string) bool {
return authenticationType == AuthTypeEmail || strings.HasSuffix("_"+authenticationType, "_oidc") || strings.HasSuffix("_"+authenticationType, "_oauth2")
}

func (a *Auth) queryValuesFromURL(urlStr string) (url.Values, error) {
unquotedCreds, err := strconv.Unquote(urlStr)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionParse, "raw url", nil, err)
}
parsedURL, err := url.Parse(unquotedCreds)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionParse, "unquoted url", nil, err)
}
unescapedQuery, err := url.QueryUnescape(parsedURL.RawQuery)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionParse, "raw url query", nil, err)
}
parsedCreds, err := url.ParseQuery(unescapedQuery)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionParse, "unescaped url query", nil, err)
}

return parsedCreds, nil
}

func (a *Auth) encodeQueryValues(values map[string]string) string {
data := url.Values{}
for k, v := range values {
data.Set(k, v)
}

return data.Encode()
}

func (a *Auth) getAuthTypeImpl(authType model.AuthType) (authType, error) {
if auth, ok := a.authTypes[authType.Code]; ok {
return auth, nil
Expand Down
124 changes: 13 additions & 111 deletions core/auth/auth_type_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ package auth

import (
"core-building-block/core/model"
"core-building-block/driven/oauthprovider"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/rokwire/logging-library-go/errors"
Expand Down Expand Up @@ -49,7 +48,7 @@ func (a *oauthAuthImpl) externalLogin(authType model.AuthType, appType model.App
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, typeAuthConfig, logutils.StringArgs(a.authType), err)
}

code, err := config.getAuthorizationCode(a.auth, creds, params)
code, err := config.GetAuthorizationCode(creds, params)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, "authorization code", logutils.StringArgs(a.authType), err)
}
Expand Down Expand Up @@ -82,28 +81,29 @@ func (a *oauthAuthImpl) getLoginURL(authType model.AuthType, appType model.Appli
return "", nil, errors.WrapErrorAction(logutils.ActionGet, typeAuthConfig, logutils.StringArgs(a.authType), err)
}

return config.buildLoginURLResponse(a.auth)
return config.BuildLoginURLResponse()
}

// --- Helper functions ---

func (a *oauthAuthImpl) loadTokensAndInfo(config oauthConfig, authType model.AuthType, appOrg model.ApplicationOrganization, creds string, params string,
func (a *oauthAuthImpl) loadTokensAndInfo(config oauthprovider.OAuthConfig, authType model.AuthType, appOrg model.ApplicationOrganization, creds string, params string,
refresh bool, l *logs.Log) (*model.ExternalSystemUser, map[string]interface{}, error) {
newToken, err := a.loadTokenWithParams(config, creds, params, refresh)
newToken := config.EmptyToken()
err := a.auth.oauthProvider.LoadToken(config, creds, params, refresh, newToken)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, logutils.TypeToken, nil, err)
}
if refresh && newToken == nil {
if refresh && strings.TrimSpace(newToken.GetAuthorizationHeader()) == "" {
l.Warnf("%s refresh tokens not enabled for app_id=%s, org_id=%s", a.authType, appOrg.Organization.ID, appOrg.Application.ID)
return nil, nil, nil
}

sub, err := config.checkIDToken(newToken)
sub, err := config.CheckIDToken(newToken)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionValidate, logutils.TypeToken, nil, err)
}

userInfo, err := a.loadUserInfo(config, newToken)
userInfo, err := a.auth.oauthProvider.LoadUserInfo(config, newToken)
if err != nil {
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, "user info", nil, err)
}
Expand All @@ -115,7 +115,7 @@ func (a *oauthAuthImpl) loadTokensAndInfo(config oauthConfig, authType model.Aut
}

userClaimsSub, _ := userClaims["sub"].(string)
if !config.checkSubject(sub, userClaimsSub) {
if !config.CheckSubject(sub, userClaimsSub) {
return nil, nil, errors.ErrorData(logutils.StatusInvalid, "subject claim", &logutils.FieldArgs{"id_token": sub, "user_claims": userClaimsSub})
}

Expand All @@ -124,87 +124,7 @@ func (a *oauthAuthImpl) loadTokensAndInfo(config oauthConfig, authType model.Aut
return nil, nil, errors.WrapErrorAction(logutils.ActionGet, model.TypeExternalSystemUser, nil, err)
}

return externalUser, newToken.getResponse(), nil
}

func (a *oauthAuthImpl) loadTokenWithParams(config oauthConfig, creds string, params string, refresh bool) (oauthToken, error) {
req, err := config.buildNewTokenRequest(a.auth, creds, params, refresh)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionCreate, "token request", nil, err)
}
if refresh && req == nil {
return nil, nil
}

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionSend, logutils.TypeRequest, nil, err)
}

defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionRead, logutils.TypeRequestBody, nil, err)
}
if resp.StatusCode != 200 {
return nil, errors.ErrorData(logutils.StatusInvalid, logutils.TypeResponse, &logutils.FieldArgs{"status_code": resp.StatusCode, "error": string(body)})
}

var token oauthToken
switch a.authType {
case AuthTypeOidc:
{
var oidc oidcToken
token = &oidc
}
case AuthTypeOAuth2:
{
var oauth2 oauth2Token
token = &oauth2
}
default:
return nil, errors.ErrorData(logutils.StatusInvalid, model.TypeAuthType, logutils.StringArgs(a.authType))
}
err = json.Unmarshal(body, &token)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionUnmarshal, logutils.TypeToken, nil, err)
}
validate := validator.New()
err = validate.Struct(token)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionValidate, logutils.TypeToken, nil, err)
}

return token, nil
}

func (a *oauthAuthImpl) loadUserInfo(config oauthConfig, token oauthToken) ([]byte, error) {
client := &http.Client{}
req, err := http.NewRequest(http.MethodGet, config.getUserInfoURL(), nil)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionCreate, logutils.TypeRequest, nil, err)
}
req.Header.Set("Authorization", token.getAuthorizationHeader())

resp, err := client.Do(req)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionSend, logutils.TypeRequest, nil, err)
}

defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.WrapErrorAction(logutils.ActionRead, logutils.TypeResponse, nil, err)
}
if resp.StatusCode != 200 {
return nil, errors.ErrorData(logutils.StatusInvalid, logutils.TypeResponse, &logutils.FieldArgs{"status_code": resp.StatusCode, "error": string(body)})
}
if len(body) == 0 {
return nil, errors.ErrorData(logutils.StatusMissing, logutils.TypeResponseBody, nil)
}

return body, nil
return externalUser, newToken.GetResponse(), nil
}

func (a *oauthAuthImpl) getExternalUser(claims map[string]interface{}, authType model.AuthType, appOrg model.ApplicationOrganization, l *logs.Log) (*model.ExternalSystemUser, error) {
Expand Down Expand Up @@ -271,7 +191,7 @@ func (a *oauthAuthImpl) getExternalUser(claims map[string]interface{}, authType
MiddleName: middleName, LastName: lastName, Email: email, Roles: roles, Groups: groups, SystemSpecific: systemSpecific}, nil
}

func (a *oauthAuthImpl) getOAuthConfig(authType model.AuthType, appType model.ApplicationType) (oauthConfig, error) {
func (a *oauthAuthImpl) getOAuthConfig(authType model.AuthType, appType model.ApplicationType) (oauthprovider.OAuthConfig, error) {
errFields := &logutils.FieldArgs{"auth_type_id": authType.ID, "app_type_id": appType}

identityProviderID, ok := authType.Params["identity_provider"].(string)
Expand All @@ -288,7 +208,7 @@ func (a *oauthAuthImpl) getOAuthConfig(authType model.AuthType, appType model.Ap
return nil, errors.WrapErrorAction(logutils.ActionMarshal, model.TypeIdentityProviderConfig, errFields, err)
}

var config oauthConfig
var config oauthprovider.OAuthConfig
switch a.authType {
case AuthTypeOidc:
{
Expand Down Expand Up @@ -329,21 +249,3 @@ func (a *oauthAuthImpl) refreshParamsFromMap(val map[string]interface{}) (*oauth

return &oauthRefreshParams{RefreshToken: refreshToken}, nil
}

type oauthConfig interface {
getAuthorizeURL() string
getTokenURL() string
getUserInfoURL() string

getAuthorizationCode(auth *Auth, creds string, params string) (string, error)
buildNewTokenRequest(auth *Auth, creds string, params string, refresh bool) (*http.Request, error)
checkIDToken(token oauthToken) (string, error)
checkSubject(tokenSubject string, userSubject string) bool
buildLoginURLResponse(auth *Auth) (string, map[string]interface{}, error)
}

type oauthToken interface {
getAuthorizationHeader() string
getResponse() map[string]interface{}
getIDToken() string
}

0 comments on commit 0bf5082

Please sign in to comment.