Skip to content

Commit

Permalink
Merge branch 'notfound' of github.com:nbjahan/martini into nbjahan-no…
Browse files Browse the repository at this point in the history
…tfound
  • Loading branch information
codegangsta committed Dec 5, 2013
2 parents 1102114 + 14a7cc3 commit b649d7b
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 11 deletions.
21 changes: 10 additions & 11 deletions router.go
Expand Up @@ -29,21 +29,21 @@ type Router interface {
// Any adds a route for any HTTP method request to the specified matching pattern.
Any(string, ...Handler) Route

// NotFound sets the handler that is called when a no route matches a request. Throws a basic 404 by default.
NotFound(Handler)
// NotFound sets the handlers that are called when a no route matches a request. Throws a basic 404 by default.
NotFound(...Handler)

// Handle is the entry point for routing. This is used as a martini.Handler
Handle(http.ResponseWriter, *http.Request, Context)
}

type router struct {
routes []*route
notFound Handler
routes []*route
notFounds []Handler
}

// NewRouter creates a new Router instance.
func NewRouter() Router {
return &router{notFound: http.NotFound}
return &router{notFounds: []Handler{http.NotFound}}
}

func (r *router) Get(pattern string, h ...Handler) Route {
Expand Down Expand Up @@ -91,14 +91,13 @@ func (r *router) Handle(res http.ResponseWriter, req *http.Request, context Cont
}

// no routes exist, 404
_, err := context.Invoke(r.notFound)
if err != nil {
panic(err)
}
c := &routeContext{context, 0, r.notFounds}
context.MapTo(c, (*Context)(nil))
c.run()
}

func (r *router) NotFound(handler Handler) {
r.notFound = handler
func (r *router) NotFound(handler ...Handler) {
r.notFounds = handler
}

func (r *router) addRoute(method string, pattern string, handlers []Handler) *route {
Expand Down
78 changes: 78 additions & 0 deletions router_test.go
Expand Up @@ -225,6 +225,84 @@ func Test_NotFound(t *testing.T) {
expect(t, recorder.Body.String(), "Nope\n")
}

func Test_NotFoundAsHandler(t *testing.T) {
router := NewRouter()
recorder := httptest.NewRecorder()

req, _ := http.NewRequest("GET", "http://localhost:3000/foo", nil)
context := New().createContext(recorder, req)

router.NotFound(func() string {
return "not found"
})

router.Handle(recorder, req, context)
expect(t, recorder.Code, http.StatusOK)
expect(t, recorder.Body.String(), "not found")

recorder = httptest.NewRecorder()

context = New().createContext(recorder, req)

router.NotFound(func() (int, string) {
return 404, "not found"
})

router.Handle(recorder, req, context)
expect(t, recorder.Code, http.StatusNotFound)
expect(t, recorder.Body.String(), "not found")

recorder = httptest.NewRecorder()

context = New().createContext(recorder, req)

router.NotFound(func() (int, string) {
return 200, ""
})

router.Handle(recorder, req, context)
expect(t, recorder.Code, http.StatusOK)
expect(t, recorder.Body.String(), "")
}

func Test_NotFoundStacking(t *testing.T) {
router := NewRouter()
recorder := httptest.NewRecorder()

req, err := http.NewRequest("GET", "http://localhost:3000/foo", nil)
if err != nil {
t.Error(err)
}
context := New().createContext(recorder, req)

result := ""

f1 := func() {
result += "foo"
}

f2 := func(c Context) {
result += "bar"
c.Next()
result += "bing"
}

f3 := func() string {
result += "bat"
return "Not Found"
}

f4 := func() {
result += "baz"
}

router.NotFound(f1, f2, f3, f4)

router.Handle(recorder, req, context)
expect(t, result, "foobarbatbing")
expect(t, recorder.Body.String(), "Not Found")
}

func Test_Any(t *testing.T) {
router := NewRouter()
router.Any("/foo", func(res http.ResponseWriter) {
Expand Down

0 comments on commit b649d7b

Please sign in to comment.