Skip to content

Commit

Permalink
Fix group tests, add sign-in tests, and generate nonces
Browse files Browse the repository at this point in the history
  • Loading branch information
sporkmonger committed Nov 16, 2018
1 parent cc767b3 commit 7b9fb26
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 70 deletions.
18 changes: 16 additions & 2 deletions internal/auth/providers/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package providers

import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -99,7 +102,6 @@ func (p *AzureV2Provider) Redeem(redirectURL, code string) (s *sessions.SessionS
// TODO: test this w/ an account that uses an alias and compare email claim
// with UPN claim; UPN has usually been what you want, but I think it's not
// rendered as a full email address here.
// FIXME: validate nonce against session

s = &sessions.SessionState{
AccessToken: token.AccessToken,
Expand Down Expand Up @@ -245,12 +247,24 @@ func (p *AzureV2Provider) GetSignInURL(redirectURI, state string) string {
params.Add("scope", p.Scope)
params.Add("state", state)
params.Set("prompt", p.ApprovalPrompt)
params.Set("nonce", "FIXME") // FIXME, maybe change to session state struct
params.Set("nonce", p.calculateNonce(state)) // required parameter
a.RawQuery = params.Encode()

return a.String()
}

// calculateNonce generates a deterministic nonce from the state value.
// We don't have a session state pointer but we need to generate a nonce
// that we can verify statelessly later. We can only use what's in the
// params and provider struct to assemble a nonce. State is guaranteed to be
// indistinguishable from random and will always change.
func (p *AzureV2Provider) calculateNonce(state string) string {
key := []byte(p.ClientID + p.ClientSecret)
h := hmac.New(sha256.New, key)
h.Write([]byte(state))
return base64.URLEncoding.EncodeToString(h.Sum(nil))[:8]
}

// ValidateGroupMembership takes in an email and the allowed groups and returns the groups that the email is part of in that list.
// If `allGroups` is an empty list it returns all the groups that the user belongs to.
func (p *AzureV2Provider) ValidateGroupMembership(email string, allGroups []string) ([]string, error) {
Expand Down
171 changes: 103 additions & 68 deletions internal/auth/providers/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"testing"
"time"

"github.com/buzzfeed/sso/internal/pkg/groups"
"github.com/buzzfeed/sso/internal/pkg/sessions"
"github.com/buzzfeed/sso/internal/pkg/testutil"

Expand Down Expand Up @@ -369,96 +368,133 @@ func (c *groupsClientMock) Do(req *http.Request) (*http.Response, error) {
return &http.Response{}, nil
}

func TestAzureV2ValidateGroupMembers(t *testing.T) {
func TestAzureV2GetSignInURL(t *testing.T) {
testCases := []struct {
name string
inputAllowedGroups []string
groups []string
groupsError error
getMembersFunc func(string) (groups.MemberSet, bool)
expectedGroups []string
expectedErrorString string
name string
redirectURI string
state string
expectedParams url.Values
}{
{
name: "empty input groups should return an empty string",
inputAllowedGroups: []string{},
groups: []string{"group1"},
expectedGroups: []string{"group1"},
getMembersFunc: func(string) (groups.MemberSet, bool) { return nil, false },
name: "nonce values passed to azure should be deterministic, pass one",
redirectURI: "https://example.com/oauth/callback",
state: "1234",
expectedParams: url.Values{
"redirect_uri": []string{"https://example.com/oauth/callback"},
"response_mode": []string{"form_post"},
"response_type": []string{"id_token code"},
"scope": []string{"openid email profile offline_access"},
"state": []string{"1234"},
"client_id": []string{TestClientID},
"nonce": []string{"KEB9Aopa"},
"prompt": []string{"consent"},
},
},
{
name: "empty inputs and error on groups resource should return error",
inputAllowedGroups: []string{},
getMembersFunc: func(string) (groups.MemberSet, bool) { return nil, false },
groupsError: fmt.Errorf("error"),
expectedErrorString: "error",
name: "nonce values passed to azure should be deterministic, pass two",
redirectURI: "https://example.com/oauth/callback",
state: "1234",
expectedParams: url.Values{
"redirect_uri": []string{"https://example.com/oauth/callback"},
"response_mode": []string{"form_post"},
"response_type": []string{"id_token code"},
"scope": []string{"openid email profile offline_access"},
"state": []string{"1234"},
"client_id": []string{TestClientID},
"nonce": []string{"KEB9Aopa"},
"prompt": []string{"consent"},
},
},
{
name: "member exists in cache, should not call groups resource",
inputAllowedGroups: []string{"group1"},
groupsError: fmt.Errorf("should not get here"),
getMembersFunc: func(string) (groups.MemberSet, bool) { return groups.MemberSet{"email": {}}, true },
expectedGroups: []string{"group1"},
name: "nonce values passed to azure should be deterministic, pass three",
redirectURI: "https://example.com/oauth/callback",
state: "4321",
expectedParams: url.Values{
"redirect_uri": []string{"https://example.com/oauth/callback"},
"response_mode": []string{"form_post"},
"response_type": []string{"id_token code"},
"scope": []string{"openid email profile offline_access"},
"state": []string{"4321"},
"client_id": []string{TestClientID},
"nonce": []string{"x_PhEN0K"},
"prompt": []string{"consent"},
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
p := newAzureV2Provider(nil)
p.ClientID = TestClientID
p.ClientSecret = "456"
p.Scope = "openid email profile offline_access"
p.ApprovalPrompt = "consent"

signInURL := p.GetSignInURL(tc.redirectURI, tc.state)
parsedURL, err := url.Parse(signInURL)
if err != nil {
t.Error(err)
}

if !reflect.DeepEqual(tc.expectedParams, parsedURL.Query()) {
t.Logf("expected params %+v", tc.expectedParams)
t.Logf("got params %+v", parsedURL.Query())
t.Errorf("unexpected params returned")
}
})
}
}

func TestAzureV2ValidateGroupMembers(t *testing.T) {
testCases := []struct {
name string
allowedGroups []string
mockedGroups []string
mockedError error
expectedGroups []string
expectedErrorString string
}{
{
name: "member does not exist in cache, should still not call groups resource",
inputAllowedGroups: []string{"group1"},
groupsError: fmt.Errorf("should not get here"),
getMembersFunc: func(string) (groups.MemberSet, bool) { return groups.MemberSet{}, true },
expectedGroups: []string{},
name: "allowed groups and groups resource output exactly match should return all groups",
allowedGroups: []string{"group1", "group2", "group3"},
mockedGroups: []string{"group1", "group2", "group3"},
expectedGroups: []string{"group1", "group2", "group3"},
},
{
name: "subset of groups are not cached, calls groups resource",
inputAllowedGroups: []string{"group1", "group2"},
groups: []string{"group1", "group2", "group3"},
groupsError: nil,
getMembersFunc: func(group string) (groups.MemberSet, bool) {
switch group {
case "group1":
return groups.MemberSet{"email": {}}, true
default:
return groups.MemberSet{}, false
}
},
name: "allowed groups should restrict to subset of groups",
allowedGroups: []string{"group1", "group2"},
mockedGroups: []string{"group1", "group2", "group3"},
expectedGroups: []string{"group1", "group2"},
},
{
name: "subset of groups are not cached, calls groups resource with error",
inputAllowedGroups: []string{"group1", "group2"},
groupsError: fmt.Errorf("error"),
getMembersFunc: func(group string) (groups.MemberSet, bool) {
switch group {
case "group1":
return groups.MemberSet{"email": {}}, true
default:
return groups.MemberSet{}, false
}
},
expectedErrorString: "error",
name: "allowed groups superset should not restrict to subset of groups",
allowedGroups: []string{"group1", "group2", "group3"},
mockedGroups: []string{"group1", "group2"},
expectedGroups: []string{"group1", "group2"},
},
{
name: "subset of groups not there, does not call groups resource",
inputAllowedGroups: []string{"group1", "group2"},
groups: []string{"group1", "group2", "group3"},
groupsError: fmt.Errorf("should not get here"),
getMembersFunc: func(group string) (groups.MemberSet, bool) {
switch group {
case "group1":
return groups.MemberSet{"email": {}}, true
default:
return groups.MemberSet{}, true
}
},
name: "groups allowed zero value should default to return all groups",
allowedGroups: []string{},
mockedGroups: []string{"group1"},
expectedGroups: []string{"group1"},
},
{
name: "empty inputs and error on groups resource should return error",
allowedGroups: []string{},
mockedError: fmt.Errorf("error"),
expectedErrorString: "error",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
p := newAzureV2Provider(nil)
p.GraphService = &MockAzureGraphService{Groups: tc.groups, GroupsError: tc.groupsError}
p.GraphService = &MockAzureGraphService{
Groups: tc.mockedGroups,
GroupsError: tc.mockedError,
}

groups, err := p.ValidateGroupMembership("email", tc.inputAllowedGroups)
groups, err := p.ValidateGroupMembership("test@example.com", tc.allowedGroups)

if err != nil {
if tc.expectedErrorString != err.Error() {
Expand All @@ -470,7 +506,6 @@ func TestAzureV2ValidateGroupMembers(t *testing.T) {
t.Logf("got groups %v", groups)
t.Errorf("unexpected groups returned")
}

})
}
}

0 comments on commit 7b9fb26

Please sign in to comment.