From fd3c327a379ce08c68ef18765bdc925f5d9bad10 Mon Sep 17 00:00:00 2001 From: Ernest Micklei Date: Mon, 6 Jun 2022 07:45:08 +0200 Subject: [PATCH] use exact matching of allowed domain entries, issue #489 (#493) * use exact matching of allowed domain entries, issue #489 * update doc, add testcases from PR conversation * introduce AllowedDomainFunc #489 * more tests, fix doc * lowercase origin before checking cors --- cors_filter.go | 64 ++++++++++++++++++--------------------------- cors_filter_test.go | 40 ++++++++++++++++++++++++++-- 2 files changed, 64 insertions(+), 40 deletions(-) diff --git a/cors_filter.go b/cors_filter.go index d6e7c857..9d18dfb7 100644 --- a/cors_filter.go +++ b/cors_filter.go @@ -18,9 +18,22 @@ import ( // http://enable-cors.org/server.html // http://www.html5rocks.com/en/tutorials/cors/#toc-handling-a-not-so-simple-request type CrossOriginResourceSharing struct { - ExposeHeaders []string // list of Header names - AllowedHeaders []string // list of Header names - AllowedDomains []string // list of allowed values for Http Origin. An allowed value can be a regular expression to support subdomain matching. If empty all are allowed. + ExposeHeaders []string // list of Header names + + // AllowedHeaders is alist of Header names. Checking is case-insensitive. + // The list may contain the special wildcard string ".*" ; all is allowed + AllowedHeaders []string + + // AllowedDomains is a list of allowed values for Http Origin. + // The list may contain the special wildcard string ".*" ; all is allowed + // If empty all are allowed. + AllowedDomains []string + + // AllowedDomainFunc is optional and is a function that will do the check + // when the origin is not part of the AllowedDomains and it does not contain the wildcard ".*". + AllowedDomainFunc func(origin string) bool + + // AllowedMethods is either empty or has a list of http methods names. Checking is case-insensitive. AllowedMethods []string MaxAge int // number of seconds before requiring new Options request CookiesAllowed bool @@ -119,36 +132,24 @@ func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool { if len(origin) == 0 { return false } + lowerOrigin := strings.ToLower(origin) if len(c.AllowedDomains) == 0 { + if c.AllowedDomainFunc != nil { + return c.AllowedDomainFunc(lowerOrigin) + } return true } - allowed := false + // exact match on each allowed domain for _, domain := range c.AllowedDomains { - if domain == origin { - allowed = true - break + if domain == ".*" || strings.ToLower(domain) == lowerOrigin { + return true } } - - if !allowed { - if len(c.allowedOriginPatterns) == 0 { - // compile allowed domains to allowed origin patterns - allowedOriginRegexps, err := compileRegexps(c.AllowedDomains) - if err != nil { - return false - } - c.allowedOriginPatterns = allowedOriginRegexps - } - - for _, pattern := range c.allowedOriginPatterns { - if allowed = pattern.MatchString(origin); allowed { - break - } - } + if c.AllowedDomainFunc != nil { + return c.AllowedDomainFunc(origin) } - - return allowed + return false } func (c CrossOriginResourceSharing) setAllowOriginHeader(req *Request, resp *Response) { @@ -190,16 +191,3 @@ func (c CrossOriginResourceSharing) isValidAccessControlRequestHeader(header str } return false } - -// Take a list of strings and compile them into a list of regular expressions. -func compileRegexps(regexpStrings []string) ([]*regexp.Regexp, error) { - regexps := []*regexp.Regexp{} - for _, regexpStr := range regexpStrings { - r, err := regexp.Compile(regexpStr) - if err != nil { - return regexps, err - } - regexps = append(regexps, r) - } - return regexps, nil -} diff --git a/cors_filter_test.go b/cors_filter_test.go index 09c5d330..5a620c01 100644 --- a/cors_filter_test.go +++ b/cors_filter_test.go @@ -120,10 +120,46 @@ func TestCORSFilter_AllowedDomains(t *testing.T) { DefaultContainer.Dispatch(httpWriter, httpRequest) actual := httpWriter.Header().Get(HEADER_AccessControlAllowOrigin) if actual != each.origin && each.allowed { - t.Fatal("expected to be accepted") + t.Error("expected to be accepted", each) } if actual == each.origin && !each.allowed { - t.Fatal("did not expect to be accepted") + t.Error("did not expect to be accepted") } } } + +func TestCORSFilter_AllowedDomainFunc(t *testing.T) { + cors := CrossOriginResourceSharing{ + AllowedDomains: []string{"here", "there"}, + AllowedDomainFunc: func(origin string) bool { + return "where" == origin + }, + } + if got, want := cors.isOriginAllowed("here"), true; got != want { + t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want) + } + if got, want := cors.isOriginAllowed("HERE"), true; got != want { + t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want) + } + if got, want := cors.isOriginAllowed("there"), true; got != want { + t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want) + } + if got, want := cors.isOriginAllowed("where"), true; got != want { + t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want) + } + if got, want := cors.isOriginAllowed("nowhere"), false; got != want { + t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want) + } + // just func + cors.AllowedDomains = []string{} + if got, want := cors.isOriginAllowed("here"), false; got != want { + t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want) + } + if got, want := cors.isOriginAllowed("where"), true; got != want { + t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want) + } + // empty domain + if got, want := cors.isOriginAllowed(""), false; got != want { + t.Errorf("got [%v:%T] want [%v:%T]", got, got, want, want) + } +}