diff --git a/README.md b/README.md index 2ffe31e..924d4ee 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,24 @@ Using the example given below, here is the request results:

my post data

``` +Middlewares +----------- + +You can add some middlewares in your application by the following way: + +```go +app.Use(func(context gofast.Context, next gofast.MiddlewareFunc) gofast.Handler { + // Some code before calling the next middleware + handler := next(context, next) + // Some code after calling the next middleware + + return handler +}) +``` + +It allows you to access `context` (request, response, current route) and also +allows to define a new `handler` function to update the application behavior. + Default CORS headers -------------------- diff --git a/context.go b/context.go index cf8e98e..98f471a 100644 --- a/context.go +++ b/context.go @@ -11,6 +11,7 @@ import ( type Context struct { request *Request response *Response + route *Route } // Creates a new context component instance @@ -19,8 +20,8 @@ func NewContext() Context { } // Sets a HTTP request instance -func (c *Context) SetRequest(req *http.Request, route Route) { - request := NewRequest(req, route) +func (c *Context) SetRequest(req *http.Request) { + request := NewRequest(req) c.request = &request } @@ -29,6 +30,16 @@ func (c *Context) GetRequest() *Request { return c.request } +// Sets a route instance +func (c *Context) SetRoute(route *Route) { + c.route = route +} + +// Returns a route instance +func (c *Context) GetRoute() *Route { + return c.route +} + // Sets a HTTP response instance func (c *Context) SetResponse(res http.ResponseWriter) { response := NewResponse(res) diff --git a/gofast.go b/gofast.go index 9d8e399..f6e76cc 100644 --- a/gofast.go +++ b/gofast.go @@ -21,6 +21,7 @@ const ( type Gofast struct { *Router *Templating + *Middleware } // Bootstraps a new instance @@ -29,8 +30,9 @@ func Bootstrap() *Gofast { router := NewRouter() templating := NewTemplating() + middleware := NewMiddleware() - return &Gofast{&router, &templating} + return &Gofast{&router, &templating, &middleware} } // Listens and handles HTTP requests @@ -76,10 +78,12 @@ func (g *Gofast) HandleRoute(res http.ResponseWriter, req *http.Request, route R startTime := time.Now() context := NewContext() - context.SetRequest(req, route) + context.SetRoute(&route) + context.SetRequest(req) context.SetResponse(res) - route.handler(context) + handler := g.HandleMiddlewares(context) + handler(context) stopTime := time.Now() diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..8d95e5d --- /dev/null +++ b/middleware.go @@ -0,0 +1,40 @@ +// A fast web framework written in Go. +// +// Author: Vincent Composieux + +package gofast + +type Middleware struct { + middlewares []MiddlewareFunc +} + +type MiddlewareFunc func(context Context, middleware MiddlewareFunc) Handler + +// Creates a new middleware component instance +func NewMiddleware() Middleware { + return Middleware{middlewares: make([]MiddlewareFunc, 0)} +} + +// Adds a new middleware +func (m *Middleware) Use(middleware MiddlewareFunc) { + m.middlewares = append(m.middlewares, middleware) +} + +// Handle middlewares and returns handler +func (m *Middleware) HandleMiddlewares(context Context) Handler { + m.Use(func(context Context, next MiddlewareFunc) Handler { + return context.GetRoute().GetHandler() + }) + + handler := context.GetRoute().GetHandler() + + for i := 0; i < len(m.middlewares)-1; i++ { + middleware := m.middlewares[i] + next := m.middlewares[i+1] + + handler = middleware(context, next) + context.GetRoute().SetHandler(handler) + } + + return handler +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..76c13fe --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,34 @@ +// A fast web framework written in Go. +// +// Author: Vincent Composieux + +package gofast + +import ( + "testing" +) + +// Tests initializing a new middleware component +func TestMiddleware(t *testing.T) { + middleware := NewMiddleware() + + if len(middleware.middlewares) > 0 { + t.Fail() + } +} + +// Tests adding a new middlewares +func TestUseNewMiddlewares(t *testing.T) { + middleware := NewMiddleware() + middleware.Use(func(context Context, next MiddlewareFunc) Handler { + return next(context, next) + }) + + middleware.Use(func(context Context, next MiddlewareFunc) Handler { + return next(context, next) + }) + + if len(middleware.middlewares) < 2 { + t.Fail() + } +} diff --git a/request.go b/request.go index 3a2d3a8..d973d31 100644 --- a/request.go +++ b/request.go @@ -10,7 +10,6 @@ import ( type Request struct { httpRequest *http.Request - route *Route parameters []Parameter } @@ -20,10 +19,10 @@ type Parameter struct { } // Creates a new Request component instance -func NewRequest(req *http.Request, route Route) Request { +func NewRequest(req *http.Request) Request { req.ParseForm() - return Request{req, &route, make([]Parameter, 0)} + return Request{req, make([]Parameter, 0)} } // Returs HTTP request @@ -31,11 +30,6 @@ func (r *Request) GetHttpRequest() *http.Request { return r.httpRequest } -// Returns current route -func (r *Request) GetRoute() *Route { - return r.route -} - // Adds a request parameter func (r *Request) AddParameter(name string, value interface{}) { r.parameters = append(r.parameters, Parameter{name, value}) diff --git a/request_test.go b/request_test.go index 16c6cc8..39c2a59 100644 --- a/request_test.go +++ b/request_test.go @@ -6,28 +6,13 @@ package gofast import ( "net/http" - "regexp" "testing" ) -// Tests setting and retrieving current route -func TestRoute(t *testing.T) { - httpRequest := new(http.Request) - route := Route{"GET", "test", regexp.MustCompile("/test"), func(c Context) {}} - - request := NewRequest(httpRequest, route) - - if request.GetRoute().name != "test" { - t.Fail() - } -} - // Tests setting and retrieving request parameters func TestParameters(t *testing.T) { httpRequest := new(http.Request) - route := Route{"GET", "test", regexp.MustCompile("/test"), func(c Context) {}} - - request := NewRequest(httpRequest, route) + request := NewRequest(httpRequest) request.AddParameter("test1", "value1") request.AddParameter("test2", "value2") @@ -46,9 +31,7 @@ func TestGetHeader(t *testing.T) { httpRequest, _ := http.NewRequest("GET", "/", nil) httpRequest.Header.Set("X-Test-Header", "yes") - route := Route{"GET", "test", regexp.MustCompile("/test"), func(c Context) {}} - - request := NewRequest(httpRequest, route) + request := NewRequest(httpRequest) if request.GetHeader("X-Test-Header") != "yes" { t.Fail() diff --git a/router.go b/router.go index 9d53892..4e2a216 100644 --- a/router.go +++ b/router.go @@ -100,6 +100,16 @@ func (r *Route) GetPattern() *regexp.Regexp { return r.pattern } +// Sets a route handler +func (r *Route) SetHandler(handler Handler) { + r.handler = handler +} + +// Returns a route handler +func (r *Route) GetHandler() Handler { + return r.handler +} + // Route sort functions func (this RouteLen) Len() int { return len(this) diff --git a/router_test.go b/router_test.go index ad09a9b..145edc6 100644 --- a/router_test.go +++ b/router_test.go @@ -72,3 +72,17 @@ func TestFallbackRoute(t *testing.T) { t.Fail() } } + +// Tests route handling getter +func TestRouteGetHandler(t *testing.T) { + handler := func(c Context) {} + + router := NewRouter() + router.Get("get", "/get", handler) + + route := router.GetRoute("get") + + if route.GetHandler() == nil { + t.Fail() + } +}