-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial changes to rm wrp PID population logic * Draft for wrp pid validator * notes on wrp handler we need to use * finish wiring config and handler * clean up * clean up and order return values * perform file renames and more unit tests * add remaining unit tests * fix some linting warnings * fix local test errors * basic auth not allowed with WRP Check * prefer go idiomatic subset check * update changelog for release
- Loading branch information
Showing
15 changed files
with
723 additions
and
199 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
|
||
"github.com/go-kit/kit/metrics" | ||
"github.com/xmidt-org/bascule" | ||
"github.com/xmidt-org/webpa-common/basculechecks" | ||
"github.com/xmidt-org/webpa-common/xhttp" | ||
"github.com/xmidt-org/wrp-go/v2" | ||
) | ||
|
||
//partnerAuthority errors | ||
var ( | ||
ErrTokenMissing = &xhttp.Error{Code: http.StatusInternalServerError, Text: "No JWT Token was found in context"} | ||
ErrTokenTypeMismatch = &xhttp.Error{Code: http.StatusInternalServerError, Text: "Token must be a JWT"} | ||
ErrPIDMissing = &xhttp.Error{Code: http.StatusBadRequest, Text: "WRP PartnerIDs field must not be empty"} | ||
ErrInvalidAllowedPartners = &xhttp.Error{Code: http.StatusForbidden, Text: "AllowedPartners JWT claim must be a non-empty list of strings"} | ||
ErrPIDMismatch = &xhttp.Error{Code: http.StatusForbidden, Text: "Unauthorized partners credentials in WRP message"} | ||
) | ||
|
||
//WRPCheckConfig drives the WRP Access control configuration when enabled | ||
type WRPCheckConfig struct { | ||
Type string | ||
} | ||
|
||
// wrpAccessAuthority describes behavior for authorizing WRP messages | ||
// against defined access policies. | ||
type wrpAccessAuthority interface { | ||
authorizeWRP(context.Context, *wrp.Message) (bool, error) | ||
} | ||
|
||
//authorizeWRP should run the scytale partnerID checks against incoming WRP messages | ||
//It takes a pointer to the wrp message as it may modify it in some cases. It returns | ||
//true if such modification was made. An error is returned in cases the validator | ||
//check failed and they are go-kit HTTP response error encoder friendly | ||
|
||
// wrpPartnersAuthority defines the access policy for which WRP messages | ||
// are authorized against the partners credentials of the message creator | ||
type wrpPartnersAccess struct { | ||
strict bool | ||
receivedWRPMessageCount metrics.Counter | ||
} | ||
|
||
func (p *wrpPartnersAccess) withFailure(labelValues ...string) metrics.Counter { | ||
if !p.strict { | ||
return p.withSuccess(labelValues...) | ||
} | ||
return p.receivedWRPMessageCount.With(append(labelValues, OutcomeLabel, Rejected)...) | ||
} | ||
|
||
func (p *wrpPartnersAccess) withSuccess(labelValues ...string) metrics.Counter { | ||
return p.receivedWRPMessageCount.With(append(labelValues, OutcomeLabel, Accepted)...) | ||
} | ||
|
||
//authorizeWRP runs the partners access policy against the WRP and returns an error if the check fails. | ||
//When the policy is not strictly enforced, | ||
// Additionally, when the policy is not a boolean is returned for failure cases where the policy autocorrects the WRP contents | ||
func (p *wrpPartnersAccess) authorizeWRP(ctx context.Context, message *wrp.Message) (bool, error) { | ||
var ( | ||
auth, ok = bascule.FromContext(ctx) | ||
satClientID = "none" | ||
) | ||
|
||
if !ok { | ||
p.withFailure(ClientIDLabel, satClientID, ReasonLabel, TokenMissing).Add(1) | ||
|
||
if p.strict { | ||
return false, ErrTokenMissing | ||
} | ||
return false, nil | ||
} | ||
|
||
token := auth.Token | ||
|
||
if token.Type() != "jwt" { | ||
p.withFailure(ClientIDLabel, satClientID, ReasonLabel, TokenTypeMismatch).Add(1) | ||
|
||
if p.strict { | ||
return false, ErrTokenTypeMismatch | ||
} | ||
return false, nil | ||
} | ||
|
||
attributes := token.Attributes() | ||
|
||
if principal := token.Principal(); len(principal) > 0 { | ||
satClientID = principal | ||
} | ||
|
||
allowedPartners, ok := attributes.GetStringSlice(basculechecks.PartnerKey) | ||
|
||
if !ok || len(allowedPartners) < 1 { | ||
p.withFailure(ClientIDLabel, satClientID, ReasonLabel, JWTPIDInvalid).Add(1) | ||
|
||
if p.strict { | ||
return false, ErrInvalidAllowedPartners | ||
} | ||
|
||
return false, nil | ||
} | ||
|
||
if len(message.PartnerIDs) < 1 { | ||
p.withFailure(ClientIDLabel, satClientID, ReasonLabel, WRPPIDMissing).Add(1) | ||
|
||
if p.strict { | ||
return false, ErrPIDMissing | ||
} | ||
|
||
message.PartnerIDs = allowedPartners | ||
return true, nil | ||
} | ||
|
||
if contains(allowedPartners, "*") { | ||
p.withSuccess(ClientIDLabel, satClientID, ReasonLabel, JWTPIDWildcard).Add(1) | ||
return false, nil | ||
} | ||
|
||
if isSubset(message.PartnerIDs, allowedPartners) { | ||
p.withSuccess(ClientIDLabel, satClientID, ReasonLabel, WRPPIDMatch).Add(1) | ||
return false, nil | ||
} | ||
|
||
p.withFailure(ClientIDLabel, satClientID, ReasonLabel, WRPPIDMismatch).Add(1) | ||
if p.strict { | ||
return false, ErrPIDMismatch | ||
} | ||
|
||
message.PartnerIDs = allowedPartners | ||
return true, nil | ||
} | ||
|
||
//returns true if list contains str | ||
func contains(list []string, str string) bool { | ||
for _, e := range list { | ||
if e == str { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
//returns true if a is a subset of b | ||
func isSubset(a, b []string) bool { | ||
m := make(map[string]bool) | ||
|
||
for _, e := range b { | ||
m[e] = true | ||
} | ||
|
||
for _, e := range a { | ||
if !m[e] { | ||
return false | ||
} | ||
} | ||
return true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/go-kit/kit/metrics" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/xmidt-org/bascule" | ||
"github.com/xmidt-org/wrp-go/v2" | ||
) | ||
|
||
func TestAuthorizeWRP(t *testing.T) { | ||
testCases := []struct { | ||
Name string | ||
PartnerIDs []string | ||
AllowedPartners []string | ||
TokenType string | ||
InjectSecurityToken bool | ||
ExpectAutocorrect bool | ||
Error error | ||
BaseLabelPairs map[string]string | ||
ExpectedPartnerIDs []string | ||
}{ | ||
{ | ||
Name: "Bascule token Missing", | ||
Error: ErrTokenMissing, | ||
TokenType: "jwt", | ||
BaseLabelPairs: map[string]string{ | ||
ReasonLabel: TokenMissing, | ||
ClientIDLabel: "none", | ||
}, | ||
}, | ||
{ | ||
Name: "Bad bascule token type", | ||
Error: ErrTokenTypeMismatch, | ||
InjectSecurityToken: true, | ||
TokenType: "basic", | ||
AllowedPartners: []string{"partner0"}, | ||
BaseLabelPairs: map[string]string{ | ||
ReasonLabel: TokenTypeMismatch, | ||
ClientIDLabel: "none", | ||
}, | ||
}, | ||
|
||
{ | ||
Name: "Invalid AllowedPartners", | ||
Error: ErrInvalidAllowedPartners, | ||
InjectSecurityToken: true, | ||
TokenType: "jwt", | ||
AllowedPartners: []string{}, | ||
BaseLabelPairs: map[string]string{ | ||
ReasonLabel: JWTPIDInvalid, | ||
ClientIDLabel: "tester", | ||
}, | ||
}, | ||
|
||
{ | ||
Name: "PartnerIDs missing from WRP", | ||
Error: ErrPIDMissing, | ||
InjectSecurityToken: true, | ||
TokenType: "jwt", | ||
AllowedPartners: []string{"p0", "p1"}, | ||
ExpectAutocorrect: true, | ||
BaseLabelPairs: map[string]string{ | ||
ReasonLabel: WRPPIDMissing, | ||
ClientIDLabel: "tester", | ||
}, | ||
ExpectedPartnerIDs: []string{"p0", "p1"}, | ||
}, | ||
|
||
{ | ||
Name: "PartnerIDs is not subset of allowerPartners", | ||
InjectSecurityToken: true, | ||
TokenType: "jwt", | ||
PartnerIDs: []string{"p2"}, | ||
AllowedPartners: []string{"p0", "p1"}, | ||
Error: ErrPIDMismatch, | ||
BaseLabelPairs: map[string]string{ | ||
ReasonLabel: WRPPIDMismatch, | ||
ClientIDLabel: "tester", | ||
}, | ||
ExpectedPartnerIDs: []string{"p0", "p1"}, | ||
ExpectAutocorrect: true, | ||
}, | ||
|
||
{ | ||
Name: "Wildcard in allowedPartners", | ||
InjectSecurityToken: true, | ||
TokenType: "jwt", | ||
PartnerIDs: []string{"p2"}, //TODO: is this the behavior we actually want? '*' giving user superpowers! | ||
AllowedPartners: []string{"p0", "p1", "*"}, | ||
BaseLabelPairs: map[string]string{ | ||
ReasonLabel: JWTPIDWildcard, | ||
ClientIDLabel: "tester", | ||
}, | ||
ExpectedPartnerIDs: []string{"p2"}, | ||
}, | ||
|
||
{ | ||
Name: "Non-empty partnerIDs is subset of allowerPartners", | ||
InjectSecurityToken: true, | ||
TokenType: "jwt", | ||
PartnerIDs: []string{"p0"}, | ||
AllowedPartners: []string{"p0", "p1"}, | ||
BaseLabelPairs: map[string]string{ | ||
ReasonLabel: WRPPIDMatch, | ||
ClientIDLabel: "tester", | ||
}, | ||
ExpectedPartnerIDs: []string{"p0"}, | ||
}, | ||
} | ||
|
||
for _, testCase := range testCases { | ||
t.Run(testCase.Name, func(t *testing.T) { | ||
assert := assert.New(t) | ||
|
||
ctx := context.Background() | ||
if testCase.InjectSecurityToken { | ||
ctx = enrichWithBasculeToken(context.Background(), testCase.TokenType, testCase.AllowedPartners) | ||
} | ||
|
||
wrpMsg := &wrp.Message{ | ||
PartnerIDs: testCase.PartnerIDs, | ||
} | ||
|
||
var ( | ||
wrpAccessAuthority wrpAccessAuthority | ||
counter = newTestCounter() | ||
) | ||
|
||
expectedStrictLabels, expectedLenientLabels := createLabelMaps(testCase.Error != nil, testCase.BaseLabelPairs) | ||
|
||
//strict mode | ||
wrpAccessAuthority = &wrpPartnersAccess{ | ||
strict: true, | ||
receivedWRPMessageCount: counter, | ||
} | ||
modified, err := wrpAccessAuthority.authorizeWRP(ctx, wrpMsg) | ||
assert.False(modified) | ||
assert.Equal(testCase.Error, err) | ||
assert.Equal(float64(1), counter.count) | ||
assert.Equal(expectedStrictLabels, counter.labelPairs) | ||
|
||
//lenient mode | ||
counter = newTestCounter() | ||
wrpAccessAuthority = &wrpPartnersAccess{ | ||
strict: false, | ||
receivedWRPMessageCount: counter, | ||
} | ||
|
||
modified, err = wrpAccessAuthority.authorizeWRP(ctx, wrpMsg) | ||
assert.Equal(testCase.ExpectAutocorrect, modified) | ||
assert.Nil(err) | ||
assert.Equal(float64(1), counter.count) | ||
assert.Equal(expectedLenientLabels, counter.labelPairs) | ||
}) | ||
} | ||
} | ||
|
||
func createLabelMaps(rejected bool, baseLabelPairs map[string]string) (strict map[string]string, lenient map[string]string) { | ||
strict = make(map[string]string) | ||
lenient = make(map[string]string) | ||
|
||
for k, v := range baseLabelPairs { | ||
strict[k] = v | ||
lenient[k] = v | ||
} | ||
|
||
if rejected { | ||
strict[OutcomeLabel] = Rejected | ||
} else { | ||
strict[OutcomeLabel] = Accepted | ||
} | ||
lenient[OutcomeLabel] = Accepted | ||
|
||
return | ||
} | ||
|
||
func enrichWithBasculeToken(ctx context.Context, tokenType string, allowedPartners []string) context.Context { | ||
return bascule.WithAuthentication(ctx, bascule.Authentication{ | ||
Token: bascule.NewToken(tokenType, "tester", bascule.NewAttributesFromMap(map[string]interface{}{ | ||
"allowedResources": map[string]interface{}{"allowedPartners": allowedPartners}, | ||
})), | ||
}) | ||
} | ||
|
||
type testCounter struct { | ||
count float64 | ||
labelPairs map[string]string | ||
} | ||
|
||
func (c *testCounter) Add(delta float64) { | ||
c.count += delta | ||
} | ||
|
||
func (c *testCounter) With(labelValues ...string) metrics.Counter { | ||
for i := 0; i < len(labelValues)-1; i += 2 { | ||
c.labelPairs[labelValues[i]] = labelValues[i+1] | ||
} | ||
return c | ||
} | ||
|
||
func newTestCounter() *testCounter { | ||
return &testCounter{ | ||
labelPairs: make(map[string]string), | ||
} | ||
} |
Oops, something went wrong.