diff --git a/README.md b/README.md
index 2ffe31e..924d4ee 100644
--- a/README.md
+++ b/README.md
@@ -115,6 +115,24 @@ Using the example given below, here is the request results:
my post data
```
+Middlewares
+-----------
+
+You can add some middlewares in your application by the following way:
+
+```go
+app.Use(func(context gofast.Context, next gofast.MiddlewareFunc) gofast.Handler {
+ // Some code before calling the next middleware
+ handler := next(context, next)
+ // Some code after calling the next middleware
+
+ return handler
+})
+```
+
+It allows you to access `context` (request, response, current route) and also
+allows to define a new `handler` function to update the application behavior.
+
Default CORS headers
--------------------
diff --git a/context.go b/context.go
index cf8e98e..98f471a 100644
--- a/context.go
+++ b/context.go
@@ -11,6 +11,7 @@ import (
type Context struct {
request *Request
response *Response
+ route *Route
}
// Creates a new context component instance
@@ -19,8 +20,8 @@ func NewContext() Context {
}
// Sets a HTTP request instance
-func (c *Context) SetRequest(req *http.Request, route Route) {
- request := NewRequest(req, route)
+func (c *Context) SetRequest(req *http.Request) {
+ request := NewRequest(req)
c.request = &request
}
@@ -29,6 +30,16 @@ func (c *Context) GetRequest() *Request {
return c.request
}
+// Sets a route instance
+func (c *Context) SetRoute(route *Route) {
+ c.route = route
+}
+
+// Returns a route instance
+func (c *Context) GetRoute() *Route {
+ return c.route
+}
+
// Sets a HTTP response instance
func (c *Context) SetResponse(res http.ResponseWriter) {
response := NewResponse(res)
diff --git a/gofast.go b/gofast.go
index 9d8e399..f6e76cc 100644
--- a/gofast.go
+++ b/gofast.go
@@ -21,6 +21,7 @@ const (
type Gofast struct {
*Router
*Templating
+ *Middleware
}
// Bootstraps a new instance
@@ -29,8 +30,9 @@ func Bootstrap() *Gofast {
router := NewRouter()
templating := NewTemplating()
+ middleware := NewMiddleware()
- return &Gofast{&router, &templating}
+ return &Gofast{&router, &templating, &middleware}
}
// Listens and handles HTTP requests
@@ -76,10 +78,12 @@ func (g *Gofast) HandleRoute(res http.ResponseWriter, req *http.Request, route R
startTime := time.Now()
context := NewContext()
- context.SetRequest(req, route)
+ context.SetRoute(&route)
+ context.SetRequest(req)
context.SetResponse(res)
- route.handler(context)
+ handler := g.HandleMiddlewares(context)
+ handler(context)
stopTime := time.Now()
diff --git a/middleware.go b/middleware.go
new file mode 100644
index 0000000..8d95e5d
--- /dev/null
+++ b/middleware.go
@@ -0,0 +1,40 @@
+// A fast web framework written in Go.
+//
+// Author: Vincent Composieux
+
+package gofast
+
+type Middleware struct {
+ middlewares []MiddlewareFunc
+}
+
+type MiddlewareFunc func(context Context, middleware MiddlewareFunc) Handler
+
+// Creates a new middleware component instance
+func NewMiddleware() Middleware {
+ return Middleware{middlewares: make([]MiddlewareFunc, 0)}
+}
+
+// Adds a new middleware
+func (m *Middleware) Use(middleware MiddlewareFunc) {
+ m.middlewares = append(m.middlewares, middleware)
+}
+
+// Handle middlewares and returns handler
+func (m *Middleware) HandleMiddlewares(context Context) Handler {
+ m.Use(func(context Context, next MiddlewareFunc) Handler {
+ return context.GetRoute().GetHandler()
+ })
+
+ handler := context.GetRoute().GetHandler()
+
+ for i := 0; i < len(m.middlewares)-1; i++ {
+ middleware := m.middlewares[i]
+ next := m.middlewares[i+1]
+
+ handler = middleware(context, next)
+ context.GetRoute().SetHandler(handler)
+ }
+
+ return handler
+}
diff --git a/middleware_test.go b/middleware_test.go
new file mode 100644
index 0000000..76c13fe
--- /dev/null
+++ b/middleware_test.go
@@ -0,0 +1,34 @@
+// A fast web framework written in Go.
+//
+// Author: Vincent Composieux
+
+package gofast
+
+import (
+ "testing"
+)
+
+// Tests initializing a new middleware component
+func TestMiddleware(t *testing.T) {
+ middleware := NewMiddleware()
+
+ if len(middleware.middlewares) > 0 {
+ t.Fail()
+ }
+}
+
+// Tests adding a new middlewares
+func TestUseNewMiddlewares(t *testing.T) {
+ middleware := NewMiddleware()
+ middleware.Use(func(context Context, next MiddlewareFunc) Handler {
+ return next(context, next)
+ })
+
+ middleware.Use(func(context Context, next MiddlewareFunc) Handler {
+ return next(context, next)
+ })
+
+ if len(middleware.middlewares) < 2 {
+ t.Fail()
+ }
+}
diff --git a/request.go b/request.go
index 3a2d3a8..d973d31 100644
--- a/request.go
+++ b/request.go
@@ -10,7 +10,6 @@ import (
type Request struct {
httpRequest *http.Request
- route *Route
parameters []Parameter
}
@@ -20,10 +19,10 @@ type Parameter struct {
}
// Creates a new Request component instance
-func NewRequest(req *http.Request, route Route) Request {
+func NewRequest(req *http.Request) Request {
req.ParseForm()
- return Request{req, &route, make([]Parameter, 0)}
+ return Request{req, make([]Parameter, 0)}
}
// Returs HTTP request
@@ -31,11 +30,6 @@ func (r *Request) GetHttpRequest() *http.Request {
return r.httpRequest
}
-// Returns current route
-func (r *Request) GetRoute() *Route {
- return r.route
-}
-
// Adds a request parameter
func (r *Request) AddParameter(name string, value interface{}) {
r.parameters = append(r.parameters, Parameter{name, value})
diff --git a/request_test.go b/request_test.go
index 16c6cc8..39c2a59 100644
--- a/request_test.go
+++ b/request_test.go
@@ -6,28 +6,13 @@ package gofast
import (
"net/http"
- "regexp"
"testing"
)
-// Tests setting and retrieving current route
-func TestRoute(t *testing.T) {
- httpRequest := new(http.Request)
- route := Route{"GET", "test", regexp.MustCompile("/test"), func(c Context) {}}
-
- request := NewRequest(httpRequest, route)
-
- if request.GetRoute().name != "test" {
- t.Fail()
- }
-}
-
// Tests setting and retrieving request parameters
func TestParameters(t *testing.T) {
httpRequest := new(http.Request)
- route := Route{"GET", "test", regexp.MustCompile("/test"), func(c Context) {}}
-
- request := NewRequest(httpRequest, route)
+ request := NewRequest(httpRequest)
request.AddParameter("test1", "value1")
request.AddParameter("test2", "value2")
@@ -46,9 +31,7 @@ func TestGetHeader(t *testing.T) {
httpRequest, _ := http.NewRequest("GET", "/", nil)
httpRequest.Header.Set("X-Test-Header", "yes")
- route := Route{"GET", "test", regexp.MustCompile("/test"), func(c Context) {}}
-
- request := NewRequest(httpRequest, route)
+ request := NewRequest(httpRequest)
if request.GetHeader("X-Test-Header") != "yes" {
t.Fail()
diff --git a/router.go b/router.go
index 9d53892..4e2a216 100644
--- a/router.go
+++ b/router.go
@@ -100,6 +100,16 @@ func (r *Route) GetPattern() *regexp.Regexp {
return r.pattern
}
+// Sets a route handler
+func (r *Route) SetHandler(handler Handler) {
+ r.handler = handler
+}
+
+// Returns a route handler
+func (r *Route) GetHandler() Handler {
+ return r.handler
+}
+
// Route sort functions
func (this RouteLen) Len() int {
return len(this)
diff --git a/router_test.go b/router_test.go
index ad09a9b..145edc6 100644
--- a/router_test.go
+++ b/router_test.go
@@ -72,3 +72,17 @@ func TestFallbackRoute(t *testing.T) {
t.Fail()
}
}
+
+// Tests route handling getter
+func TestRouteGetHandler(t *testing.T) {
+ handler := func(c Context) {}
+
+ router := NewRouter()
+ router.Get("get", "/get", handler)
+
+ route := router.GetRoute("get")
+
+ if route.GetHandler() == nil {
+ t.Fail()
+ }
+}