Skip to content

Commit

Permalink
Added domain whitelisting
Browse files Browse the repository at this point in the history
  • Loading branch information
Philip Reichenberger committed Jun 12, 2018
1 parent a88f3e0 commit 88dd4a1
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 20 deletions.
6 changes: 3 additions & 3 deletions README.MD
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ OAuth2 Router [![GoDoc](http://img.shields.io/badge/godoc-reference-blue.svg)](h
========

Redirects an OAuth2 callback to a URL specified in the state parameter with all query parameters from the original request.
This allows having one endpoint as the destination for different OAuth2 callbacks destinations.
This allows having one endpoint as the destination for different OAuth2 callbacks destinations. Allows for whitelisting redirect domains.

Possible use cases:
- Pull Request builds
Expand All @@ -24,15 +24,15 @@ https://hub.docker.com/r/preichenberger/oauth2-router

## Start
```bash
$ oauth2-router -port 3000
$ oauth2-router -port 3000 -whitelist localhost,*.google.com
$ 2018/06/08 12:48:09 Starting OAuth2 Router on port: 3000
```

## Client Implementation
For the following info:
- Whitelist: localhost
- OAuth2 Router redirect_uri: http://localhost:8080
- Real redirect_uri: http://localhost:5000/random_endpoint?apple=pears
- Authorization code from callback: secretcode

1. Add the OAuth2 Router redirect_uri: http://localhost:8080 to your OAuth2 provider (Google, Facebook, Github, etc...)
2. Create a JSON with any date to be passed in the OAuth2 state parameter. Add a redirect parameter, with the real redirect_uri
Expand Down
19 changes: 15 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ import (
"github.com/preichenberger/oauth2-router/redirector"
)

var _redirector *redirector.Redirector

func OauthRouterServer(w http.ResponseWriter, req *http.Request) {
redirectUrl, err := redirector.CreateUrl(req.URL.Query())
redirectUrl, err := _redirector.CreateUrl(req.URL.Query())
if err != nil {
switch err.(type) {
case redirector.ValidationError:
case redirector.Error:
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf("400 - %s\n", err)))
default:
Expand All @@ -30,9 +32,11 @@ func OauthRouterServer(w http.ResponseWriter, req *http.Request) {
}

func main() {
var port int
var help bool
var port int
var whitelist string
flag.IntVar(&port, "port", 8080, "port to listen on")
flag.StringVar(&whitelist, "whitelist", "*", "comma-delimited list of whitelist domains i.e '*.github.com,pizza.com'")
flag.BoolVar(&help, "h", false, "help")
flag.Parse()

Expand All @@ -45,11 +49,18 @@ func main() {
port = env_port_int
}

env_whitelist := os.Getenv("WHITELIST")
if len(env_whitelist) != 0 {
whitelist = env_whitelist
}

if help {
println("Usage: oauth2-redirector [-port 8080]")
println("Usage: oauth2-redirector [-port 8080] [-whitelist *.github.com,pizza.com")
os.Exit(0)
}

_redirector = redirector.NewRedirector(whitelist)

http.HandleFunc("/", OauthRouterServer)
log.Printf("Starting OAuth2 Router on port: %d", port)
err := http.ListenAndServe(fmt.Sprintf(":%d", port), nil)
Expand Down
4 changes: 2 additions & 2 deletions redirector/error.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package redirector

type ValidationError struct {
type Error struct {
Message string
}

func (e ValidationError) Error() string {
func (e Error) Error() string {
return e.Message
}
53 changes: 53 additions & 0 deletions redirector/redirector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package redirector

import (
"strings"
)

type Redirector struct {
WhitelistDomains []string
}

func validateDomainParts(a, b []string) bool {
if len(a) != len(b) {
return false
}

for i, _ := range a {
if a[i] == "*" || b[i] == "*" {
continue
}

if a[i] != b[i] {
return false
}
}

return true
}

func NewRedirector(whitelist string) *Redirector {
whitelistDomains := strings.Split(whitelist, ",")

return &Redirector{
WhitelistDomains: whitelistDomains,
}
}

func (r *Redirector) ValidateDomain(host string) bool {
for _, whitelistDomain := range r.WhitelistDomains {
whitelistDomainParts := strings.Split(whitelistDomain, ".")
hostParts := strings.Split(host, ".")
if whitelistDomain == "" {
return true
}

if !validateDomainParts(whitelistDomainParts, hostParts) {
continue
}

return true
}

return false
}
40 changes: 40 additions & 0 deletions redirector/redirector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package redirector

import (
"errors"
"testing"
)

func TestValidateDomain(t *testing.T) {
r := NewRedirector("")
if !r.ValidateDomain("pizza") {
t.Error(errors.New("Empty whitelist failed"))
}
}

func TestValidateDomainWildcard(t *testing.T) {
r := NewRedirector("google.com,apple.com,*.github.com")
if !r.ValidateDomain("api.github.com") {
t.Error(errors.New("Wildcard domain validation failed"))
}
}

func TestValidateDomainNested(t *testing.T) {
r := NewRedirector("google.com,*.api.apple.com,*.github.com")
if !r.ValidateDomain("test.api.apple.com") {
t.Error(errors.New("Nested domain validation failed"))
}
}

func TestValidateDomainFail(t *testing.T) {
r := NewRedirector("google.com,*.api.apple.com,*.github.com")
if r.ValidateDomain("test.api.github.com") {
t.Error(errors.New("Validate domain did not fail"))
}
if r.ValidateDomain("api.apple.com") {
t.Error(errors.New("Validate domain did not fail"))
}
if r.ValidateDomain("test.google.com") {
t.Error(errors.New("Validate domain did not fail"))
}
}
18 changes: 11 additions & 7 deletions redirector/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,40 @@ import (
"net/url"
)

func CreateUrl(queryValues url.Values) (*url.URL, error) {
func (r *Redirector ) CreateUrl(queryValues url.Values) (*url.URL, error) {
var stateValues map[string]string

if _, ok := queryValues["state"]; !ok {
return nil, ValidationError{"Missing state field"}
return nil, Error{"Missing state field"}
}

stateQuery, err := base64.StdEncoding.DecodeString(queryValues["state"][0])
if err != nil {
return nil, ValidationError{"Could not base64 decode state value"}
return nil, Error{"Could not base64 decode state value"}
}

if err := json.Unmarshal(stateQuery, &stateValues); err != nil {
return nil, ValidationError{"Could not json decode state value"}
return nil, Error{"Could not json decode state value"}
}

if _, ok := stateValues["redirect"]; !ok {
return nil, ValidationError{"Query param redirect missing from state"}
return nil, Error{"Query param redirect missing from state"}
}

redirect, err := url.QueryUnescape(stateValues["redirect"])
if err != nil {
return nil, ValidationError{"Could not URL decode redirect value"}
return nil, Error{"Could not URL decode redirect value"}
}

redirectUrl, err := url.ParseRequestURI(redirect)
if err != nil {
return nil, ValidationError{"Could not parse redirect URL"}
return nil, Error{"Could not parse redirect URL"}
}

if !r.ValidateDomain(redirectUrl.Host) {
return nil, Error{"Domain is not whitelisted"}
}

redirectValues := redirectUrl.Query()
for key, values := range queryValues {
for _, value := range values {
Expand Down
13 changes: 9 additions & 4 deletions redirector/url_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ func TestCreateUrl(t *testing.T) {
}
queryValues.Add("state", state)

redirectUrl, err := CreateUrl(queryValues)
redirector := NewRedirector("www.github.com")
redirectUrl, err := redirector.CreateUrl(queryValues)
if err != nil {
t.Error(err)
}
Expand All @@ -49,7 +50,9 @@ func TestCreateUrl(t *testing.T) {
}

func TestCreateUrlMissingStateError(t *testing.T) {
_, err := CreateUrl(url.Values{})
redirector := NewRedirector("")

_, err := redirector.CreateUrl(url.Values{})
if err.Error() != "Missing state field" {
t.Error(errors.New("Missing not found state error"))
}
Expand All @@ -67,7 +70,8 @@ func TestCreateUrlMissingRedirectError(t *testing.T) {
}
queryValues.Add("state", state)

_, err = CreateUrl(queryValues)
redirector := NewRedirector("")
_, err = redirector.CreateUrl(queryValues)
if err.Error() != "Query param redirect missing from state" {
t.Error(errors.New("Missing query param redirect error"))
}
Expand All @@ -85,7 +89,8 @@ func TestCreateUrlRedirectParseError(t *testing.T) {
}
queryValues.Add("state", state)

_, err = CreateUrl(queryValues)
redirector := NewRedirector("")
_, err = redirector.CreateUrl(queryValues)
if err.Error() != "Could not parse redirect URL" {
t.Error(errors.New("Missing could not parse redirect error"))
}
Expand Down

0 comments on commit 88dd4a1

Please sign in to comment.