Skip to content

Commit

Permalink
use exact matching of allowed domain entries, issue #489 (#493)
Browse files Browse the repository at this point in the history
* use exact matching of allowed domain entries, issue #489

* update doc, add testcases from PR conversation

* introduce AllowedDomainFunc #489

* more tests, fix doc

* lowercase origin before checking cors
  • Loading branch information
emicklei committed Jun 6, 2022
1 parent c2c010a commit fd3c327
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 40 deletions.
64 changes: 26 additions & 38 deletions cors_filter.go
Expand Up @@ -18,9 +18,22 @@ import (
// http://enable-cors.org/server.html
// http://www.html5rocks.com/en/tutorials/cors/#toc-handling-a-not-so-simple-request
type CrossOriginResourceSharing struct {
ExposeHeaders []string // list of Header names
AllowedHeaders []string // list of Header names
AllowedDomains []string // list of allowed values for Http Origin. An allowed value can be a regular expression to support subdomain matching. If empty all are allowed.
ExposeHeaders []string // list of Header names

// AllowedHeaders is alist of Header names. Checking is case-insensitive.
// The list may contain the special wildcard string ".*" ; all is allowed
AllowedHeaders []string

// AllowedDomains is a list of allowed values for Http Origin.
// The list may contain the special wildcard string ".*" ; all is allowed
// If empty all are allowed.
AllowedDomains []string

// AllowedDomainFunc is optional and is a function that will do the check
// when the origin is not part of the AllowedDomains and it does not contain the wildcard ".*".
AllowedDomainFunc func(origin string) bool

// AllowedMethods is either empty or has a list of http methods names. Checking is case-insensitive.
AllowedMethods []string
MaxAge int // number of seconds before requiring new Options request
CookiesAllowed bool
Expand Down Expand Up @@ -119,36 +132,24 @@ func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool {
if len(origin) == 0 {
return false
}
lowerOrigin := strings.ToLower(origin)
if len(c.AllowedDomains) == 0 {
if c.AllowedDomainFunc != nil {
return c.AllowedDomainFunc(lowerOrigin)
}
return true
}

allowed := false
// exact match on each allowed domain
for _, domain := range c.AllowedDomains {
if domain == origin {
allowed = true
break
if domain == ".*" || strings.ToLower(domain) == lowerOrigin {
return true
}
}

if !allowed {
if len(c.allowedOriginPatterns) == 0 {
// compile allowed domains to allowed origin patterns
allowedOriginRegexps, err := compileRegexps(c.AllowedDomains)
if err != nil {
return false
}
c.allowedOriginPatterns = allowedOriginRegexps
}

for _, pattern := range c.allowedOriginPatterns {
if allowed = pattern.MatchString(origin); allowed {
break
}
}
if c.AllowedDomainFunc != nil {
return c.AllowedDomainFunc(origin)
}

return allowed
return false
}

func (c CrossOriginResourceSharing) setAllowOriginHeader(req *Request, resp *Response) {
Expand Down Expand Up @@ -190,16 +191,3 @@ func (c CrossOriginResourceSharing) isValidAccessControlRequestHeader(header str
}
return false
}

// Take a list of strings and compile them into a list of regular expressions.
func compileRegexps(regexpStrings []string) ([]*regexp.Regexp, error) {
regexps := []*regexp.Regexp{}
for _, regexpStr := range regexpStrings {
r, err := regexp.Compile(regexpStr)
if err != nil {
return regexps, err
}
regexps = append(regexps, r)
}
return regexps, nil
}
40 changes: 38 additions & 2 deletions cors_filter_test.go
Expand Up @@ -120,10 +120,46 @@ func TestCORSFilter_AllowedDomains(t *testing.T) {
DefaultContainer.Dispatch(httpWriter, httpRequest)
actual := httpWriter.Header().Get(HEADER_AccessControlAllowOrigin)
if actual != each.origin && each.allowed {
t.Fatal("expected to be accepted")
t.Error("expected to be accepted", each)
}
if actual == each.origin && !each.allowed {
t.Fatal("did not expect to be accepted")
t.Error("did not expect to be accepted")
}
}
}

func TestCORSFilter_AllowedDomainFunc(t *testing.T) {
cors := CrossOriginResourceSharing{
AllowedDomains: []string{"here", "there"},
AllowedDomainFunc: func(origin string) bool {
return "where" == origin
},
}
if got, want := cors.isOriginAllowed("here"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("HERE"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("there"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("where"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("nowhere"), false; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
// just func
cors.AllowedDomains = []string{}
if got, want := cors.isOriginAllowed("here"), false; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
if got, want := cors.isOriginAllowed("where"), true; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
// empty domain
if got, want := cors.isOriginAllowed(""), false; got != want {
t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want)
}
}

0 comments on commit fd3c327

Please sign in to comment.