-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
ipfilter.go
84 lines (72 loc) · 2.05 KB
/
ipfilter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package ipfilter
import (
"fmt"
"net"
"net/http"
"github.com/jpillora/ipfilter"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
// Config defines the config for IPFilter middleware.
type Config struct {
	// Skipper defines a function to skip middleware.
	// default middleware.DefaultSkipper (MiddlewareWithConfig substitutes
	// DefaultConfig.Skipper when this field is nil).
	Skipper middleware.Skipper
	// WhiteList is an allowed ip list, forwarded to the underlying filter
	// as ipfilter.Options.AllowedIPs.
	WhiteList []string
	// BlackList is a disallowed ip list, forwarded to the underlying filter
	// as ipfilter.Options.BlockedIPs.
	BlackList []string
	// BlockByDefault is forwarded as ipfilter.Options.BlockByDefault.
	// Presumably, when true, addresses matching neither list are blocked
	// rather than allowed — confirm against the jpillora/ipfilter docs.
	BlockByDefault bool
	// CreatedFilter, if non-nil, is called by MiddlewareWithConfig with the
	// newly created filter object to allow for controlling the filter during
	// runtime. The underlying filter implementation is thankfully threadsafe
	// (per the original author; see jpillora/ipfilter).
	CreatedFilter func(*ipfilter.IPFilter)
}
// DefaultConfig is the configuration used by Middleware(). It sets Skipper to
// middleware.DefaultSkipper, leaves both IP lists empty, and leaves
// BlockByDefault at its zero value (false).
var DefaultConfig = Config{
	Skipper: middleware.DefaultSkipper,
}
// Middleware returns an IPFilter middleware that filters requests by client
// IP using DefaultConfig. It is shorthand for
// MiddlewareWithConfig(DefaultConfig).
func Middleware() echo.MiddlewareFunc {
	cfg := DefaultConfig
	return MiddlewareWithConfig(cfg)
}
// MiddlewareWithConfig returns an IPFilter middleware built from config.
// See: `Middleware()`.
//
// A nil config.Skipper is replaced with DefaultConfig.Skipper. A single
// jpillora/ipfilter instance is created here from config.WhiteList,
// config.BlackList and config.BlockByDefault. If config.CreatedFilter is set,
// it is called with that instance before the middleware is returned.
//
// For every request that is not skipped, the client address is taken from
// c.RealIP(), falling back to the host part of Request().RemoteAddr when
// RealIP returns "". The handler then:
//   - returns a 400 Bad Request HTTPError if that fallback address cannot be
//     split into host and port;
//   - returns a 403 Forbidden HTTPError if the filter does not allow the
//     address;
//   - otherwise calls next.
//
// NOTE(review): depending on the echo configuration, c.RealIP() may be derived
// from client-supplied headers such as X-Forwarded-For / X-Real-IP. For an
// access-control filter, make sure echo.Echo.IPExtractor is configured for the
// actual deployment (trusted proxy or direct) so clients cannot spoof their
// way past the lists.
func MiddlewareWithConfig(config Config) echo.MiddlewareFunc {
	// Defaults
	if config.Skipper == nil {
		config.Skipper = DefaultConfig.Skipper
	}

	// New jpillora/ipfilter instance, shared by all requests.
	filter := ipfilter.New(ipfilter.Options{
		AllowedIPs:     config.WhiteList,
		BlockedIPs:     config.BlackList,
		BlockByDefault: config.BlockByDefault,
		Logger:         nil,
	})
	if config.CreatedFilter != nil {
		config.CreatedFilter(filter)
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			ip := c.RealIP()
			if ip == "" {
				// err must be local to this request. A variable captured from
				// the enclosing function would be shared by all concurrently
				// served requests, which is a data race.
				host, _, err := net.SplitHostPort(c.Request().RemoteAddr)
				if err != nil {
					return echo.NewHTTPError(http.StatusBadRequest, err.Error())
				}
				ip = host
			}

			if !filter.Allowed(ip) {
				return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("IP address %s not allowed", ip))
			}
			return next(c)
		}
	}
}